@@ -15,6 +15,15 @@ namespace at {
                                                const T* beta,
                                                T* Y);

+    template <typename T, typename SizeT>
+    __global__ void LayerNormAxisForwardCUDAKernel(SizeT N,
+                                                   SizeT inner_size,
+                                                   float eps,
+                                                   const T* X,
+                                                   const T* gamma,
+                                                   const T* beta,
+                                                   T* Y);
+
   }
 }

@@ -30,19 +39,29 @@ namespace ctranslate2 {
                             const dim_t axis,
                             const dim_t outer_size,
                             const dim_t axis_size,
-                            const dim_t,
+                            const dim_t inner_size,
                             StorageView& output) const {
-      if (axis != input.rank() - 1 || !beta || !gamma)
-        throw std::invalid_argument("Generalized LayerNorm is currently not implemented on GPU");
-
-      at::native::LayerNormForwardCUDAKernel<cuda::device_type<T>, cuda::index_t>
-        <<<outer_size, CUDA_NUM_THREADS, 0, cuda::get_cuda_stream()>>>(
-          axis_size,
-          _epsilon,
-          cuda::device_cast(input.data<T>()),
-          cuda::device_cast(gamma->data<T>()),
-          cuda::device_cast(beta->data<T>()),
-          cuda::device_cast(output.data<T>()));
+      if (axis == input.rank() - 1) {
+        at::native::LayerNormForwardCUDAKernel<cuda::device_type<T>, cuda::index_t>
+          <<<outer_size, CUDA_NUM_THREADS, 0, cuda::get_cuda_stream()>>>(
+            axis_size,
+            _epsilon,
+            cuda::device_cast(input.data<T>()),
+            gamma ? cuda::device_cast(gamma->data<T>()) : nullptr,
+            beta ? cuda::device_cast(beta->data<T>()) : nullptr,
+            cuda::device_cast(output.data<T>()));
+      } else {
+        const dim_t blocks = std::min(outer_size * inner_size, cuda::max_blocks);
+        at::native::LayerNormAxisForwardCUDAKernel<cuda::device_type<T>, cuda::index_t>
+          <<<blocks, CUDA_NUM_THREADS, 0, cuda::get_cuda_stream()>>>(
+            axis_size,
+            inner_size,
+            _epsilon,
+            cuda::device_cast(input.data<T>()),
+            gamma ? cuda::device_cast(gamma->data<T>()) : nullptr,
+            beta ? cuda::device_cast(beta->data<T>()) : nullptr,
+            cuda::device_cast(output.data<T>()));
+      }
     }

 #define DECLARE_IMPL(T) \
@@ -181,7 +200,55 @@ namespace at {

       for (SizeT j = threadIdx.x; j < N; j += blockDim.x) {
         const SizeT index = i * N + j;
-        Y[index] = (float(X[index]) - s_mean) * s_variance * float(gamma[j]) + float(beta[j]);
+        const float gamma_v = gamma == nullptr ? float(1) : float(gamma[j]);
+        const float beta_v = beta == nullptr ? float(0) : float(beta[j]);
+        Y[index] = T((float(X[index]) - s_mean) * s_variance * gamma_v + beta_v);
+      }
+    }
+
+    template <typename T, typename SizeT>
+    __global__ void LayerNormAxisForwardCUDAKernel(SizeT N,
+                                                   SizeT inner_size,
+                                                   float eps,
+                                                   const T* X,
+                                                   const T* gamma,
+                                                   const T* beta,
+                                                   T* Y) {
+      typedef cub::BlockReduce<float, CUDA_NUM_THREADS> BlockReduce;
+      __shared__ typename BlockReduce::TempStorage m_temp_storage;
+      __shared__ typename BlockReduce::TempStorage v_temp_storage;
+      __shared__ float s_mean;
+      __shared__ float s_variance;
+
+      // Each block normalizes one (outer, inner) position over the axis of length N.
+      const SizeT i = blockIdx.x / inner_size;
+      const SizeT j = blockIdx.x % inner_size;
+
+      float sum1 = 0;
+      float sum2 = 0;
+      for (SizeT k = threadIdx.x; k < N; k += blockDim.x) {
+        const SizeT index = (i * N + k) * inner_size + j;
+        sum1 += float(X[index]);
+        sum2 += float(X[index]) * float(X[index]);
+      }
+      sum1 = BlockReduce(m_temp_storage).Sum(sum1);
+      sum2 = BlockReduce(v_temp_storage).Sum(sum2);
+      if (threadIdx.x == 0) {
+        const float scale = float(1) / float(N);
+        sum1 *= scale;
+        sum2 = fmaxf(sum2 * scale - sum1 * sum1, float(0));
+        s_mean = sum1;
+        s_variance = rsqrtf(sum2 + eps);
+      }
+
+      __syncthreads();
+
+      for (SizeT k = threadIdx.x; k < N; k += blockDim.x) {
+        const SizeT index = (i * N + k) * inner_size + j;
+        // gamma and beta hold one value per position along the normalized axis, so index by k.
+        const float gamma_v = gamma == nullptr ? float(1) : float(gamma[k]);
+        const float beta_v = beta == nullptr ? float(0) : float(beta[k]);
+        Y[index] = T((float(X[index]) - s_mean) * s_variance * gamma_v + beta_v);
       }
     }

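For reference, below is a minimal CPU sketch (not part of this commit; the helper name layer_norm_axis_reference and its signature are made up for illustration) of the normalization and indexing scheme the new axis kernel implements: the tensor is viewed as [outer_size, axis_size, inner_size], each (outer, inner) pair is normalized independently over the axis dimension, and gamma/beta carry one value per axis position.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for LayerNorm over an arbitrary axis.
// x and y are flat buffers of size outer_size * axis_size * inner_size.
// gamma and beta have axis_size elements each and may be nullptr.
void layer_norm_axis_reference(std::size_t outer_size,
                               std::size_t axis_size,
                               std::size_t inner_size,
                               float eps,
                               const std::vector<float>& x,
                               const float* gamma,
                               const float* beta,
                               std::vector<float>& y) {
  for (std::size_t i = 0; i < outer_size; ++i) {
    for (std::size_t j = 0; j < inner_size; ++j) {
      // Accumulate sum and sum of squares over the normalized axis.
      float sum = 0.f;
      float sum_sq = 0.f;
      for (std::size_t k = 0; k < axis_size; ++k) {
        const float v = x[(i * axis_size + k) * inner_size + j];
        sum += v;
        sum_sq += v * v;
      }
      const float mean = sum / axis_size;
      const float variance = std::max(sum_sq / axis_size - mean * mean, 0.f);
      const float rstd = 1.f / std::sqrt(variance + eps);
      // Normalize, then apply the optional per-axis scale and shift.
      for (std::size_t k = 0; k < axis_size; ++k) {
        const std::size_t index = (i * axis_size + k) * inner_size + j;
        const float g = gamma ? gamma[k] : 1.f;
        const float b = beta ? beta[k] : 0.f;
        y[index] = (x[index] - mean) * rstd * g + b;
      }
    }
  }
}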