Skip to content

Commit e786a90

Browse files
authored
Merge branch 'OpenNMT:master' into master
2 parents 520c8e5 + 6e9b3ac commit e786a90

6 files changed

Lines changed: 93 additions & 27 deletions

File tree

include/ctranslate2/ops/flash_attention.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace ctranslate2 {
66
namespace ops {
77
class FlashAttention : public Op {
88
public:
9-
FlashAttention(float queries_scale, dim_t sliding_window);
9+
FlashAttention(float queries_scale, dim_t sliding_window, bool is_causal = true);
1010

1111
void operator()(StorageView& queries,
1212
StorageView& keys,
@@ -25,6 +25,7 @@ namespace ctranslate2 {
2525
private:
2626
const float _queries_scale;
2727
const dim_t _sliding_window;
28+
const bool _is_causal;
2829
template <Device D>
2930
void compute(StorageView& queries,
3031
StorageView& keys,

src/models/whisper.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -389,17 +389,18 @@ namespace ctranslate2 {
389389
const ops::MedianFilter median_filter_op(median_filter_width);
390390
const dim_t batch_size = attention_probs.dim(0);
391391

392-
// The remaining operations are not implemented on GPU, so move back to CPU.
393-
attention_probs.move_to(Device::CPU, DataType::FLOAT32);
394-
395392
ops::LayerNorm(-2, 0)(attention_probs);
396393

397-
StorageView median_filter;
394+
StorageView median_filter(attention_probs.dtype(), attention_probs.device());
398395
median_filter_op(attention_probs, median_filter);
399396

400-
StorageView weights;
397+
StorageView weights(median_filter.dtype(), median_filter.device());
401398
ops::Mean(1)(median_filter, weights);
402399

400+
// The remaining operations are not implemented on GPU, so move back to CPU.
401+
synchronize_stream(weights.device());
402+
weights.move_to(Device::CPU, DataType::FLOAT32);
403+
403404
std::vector<std::vector<std::pair<dim_t, dim_t>>> alignments;
404405
alignments.reserve(batch_size);
405406

src/ops/flash_attention.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44

55
namespace ctranslate2 {
66
namespace ops {
7-
FlashAttention::FlashAttention(float queries_scale, dim_t sliding_window)
7+
FlashAttention::FlashAttention(float queries_scale, dim_t sliding_window, bool is_causal)
88
: _queries_scale(queries_scale)
9-
,_sliding_window(sliding_window)
9+
, _sliding_window(sliding_window)
10+
, _is_causal(is_causal)
1011
{
1112
}
1213

src/ops/flash_attention_gpu.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ namespace ctranslate2 {
232232
num_heads_k = cached_keys->dim(2);
233233
}
234234

235+
bool is_causal = _is_causal;
235236
// causal=true is the same as causal=false in this case
236-
bool is_causal = true;
237237
if (seqlen_q == 1 && !alibi) { is_causal = false; }
238238
if (is_causal) { window_size_right = 0; }
239239

src/ops/layer_norm_gpu.cu

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ namespace at {
1515
const T* beta,
1616
T* Y);
1717

18+
template <typename T, typename SizeT>
19+
__global__ void LayerNormAxisForwardCUDAKernel(SizeT N,
20+
SizeT inner,
21+
float eps,
22+
const T* X,
23+
const T* gamma,
24+
const T* beta,
25+
T* Y);
26+
1827
}
1928
}
2029

@@ -30,19 +39,29 @@ namespace ctranslate2 {
3039
const dim_t axis,
3140
const dim_t outer_size,
3241
const dim_t axis_size,
33-
const dim_t,
42+
const dim_t inner_size,
3443
StorageView& output) const {
35-
if (axis != input.rank() - 1 || !beta || !gamma)
36-
throw std::invalid_argument("Generalized LayerNorm is currently not implemented on GPU");
37-
38-
at::native::LayerNormForwardCUDAKernel<cuda::device_type<T>, cuda::index_t>
39-
<<<outer_size, CUDA_NUM_THREADS, 0, cuda::get_cuda_stream()>>>(
40-
axis_size,
41-
_epsilon,
42-
cuda::device_cast(input.data<T>()),
43-
cuda::device_cast(gamma->data<T>()),
44-
cuda::device_cast(beta->data<T>()),
45-
cuda::device_cast(output.data<T>()));
44+
if (axis == input.rank() - 1) {
45+
at::native::LayerNormForwardCUDAKernel<cuda::device_type<T>, cuda::index_t>
46+
<<<outer_size, CUDA_NUM_THREADS, 0, cuda::get_cuda_stream()>>>(
47+
axis_size,
48+
_epsilon,
49+
cuda::device_cast(input.data<T>()),
50+
gamma ? cuda::device_cast(gamma->data<T>()) : nullptr,
51+
beta ? cuda::device_cast(beta->data<T>()) : nullptr,
52+
cuda::device_cast(output.data<T>()));
53+
} else {
54+
const dim_t blocks = std::min(outer_size * inner_size, cuda::max_blocks);
55+
at::native::LayerNormAxisForwardCUDAKernel<cuda::device_type<T>, cuda::index_t>
56+
<<<blocks, CUDA_NUM_THREADS, 0, cuda::get_cuda_stream()>>>(
57+
axis_size,
58+
inner_size,
59+
_epsilon,
60+
cuda::device_cast(input.data<T>()),
61+
gamma ? cuda::device_cast(gamma->data<T>()) : nullptr,
62+
beta ? cuda::device_cast(beta->data<T>()) : nullptr,
63+
cuda::device_cast(output.data<T>()));
64+
}
4665
}
4766

4867
#define DECLARE_IMPL(T) \
@@ -181,7 +200,53 @@ namespace at {
181200

182201
for (SizeT j = threadIdx.x; j < N; j += blockDim.x) {
183202
const SizeT index = i * N + j;
184-
Y[index] = (float(X[index]) - s_mean) * s_variance * float(gamma[j]) + float(beta[j]);
203+
const float gamma_v = gamma == nullptr ? float(1) : float(gamma[j]);
204+
const float beta_v = beta == nullptr ? float(0) : float(beta[j]);
205+
Y[index] = T((float(X[index]) - s_mean) * s_variance * gamma_v + beta_v);
206+
}
207+
}
208+
209+
template <typename T, typename SizeT>
210+
__global__ void LayerNormAxisForwardCUDAKernel(SizeT N,
211+
SizeT inner_size,
212+
float eps,
213+
const T* X,
214+
const T* gamma,
215+
const T* beta,
216+
T* Y) {
217+
typedef cub::BlockReduce<float, CUDA_NUM_THREADS> BlockReduce;
218+
__shared__ typename BlockReduce::TempStorage m_temp_storage;
219+
__shared__ typename BlockReduce::TempStorage v_temp_storage;
220+
__shared__ float s_mean;
221+
__shared__ float s_variance;
222+
223+
const SizeT i = blockIdx.x / inner_size;
224+
const SizeT j = blockIdx.x % inner_size;
225+
226+
float sum1 = 0;
227+
float sum2 = 0;
228+
for (SizeT k = threadIdx.x; k < N; k += blockDim.x) {
229+
const SizeT index = (i * N + k) * inner_size + j;
230+
sum1 += float(X[index]);
231+
sum2 += float(X[index]) * float(X[index]);
232+
}
233+
sum1 = BlockReduce(m_temp_storage).Sum(sum1);
234+
sum2 = BlockReduce(v_temp_storage).Sum(sum2);
235+
if (threadIdx.x == 0) {
236+
const float scale = float(1) / float(N);
237+
sum1 *= scale;
238+
sum2 = fmaxf(sum2 * scale - sum1 * sum1, float(0));
239+
s_mean = sum1;
240+
s_variance = rsqrtf(sum2 + eps);
241+
}
242+
243+
__syncthreads();
244+
245+
for (SizeT k = threadIdx.x; k < N; k += blockDim.x) {
246+
const SizeT index = (i * N + k) * inner_size + j;
247+
const float gamma_v = gamma == nullptr ? float(1) : float(gamma[j]);
248+
const float beta_v = beta == nullptr ? float(0) : float(beta[j]);
249+
Y[index] = T((float(X[index]) - s_mean) * s_variance * gamma_v + beta_v);
185250
}
186251
}
187252

tests/ops_test.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -724,9 +724,6 @@ TEST_P(OpDeviceFPTest, LayerNorm) {
724724

725725
TEST_P(OpDeviceFPTest, LayerNormAxis) {
726726
const Device device = GetParam().device;
727-
if (device == Device::CUDA) {
728-
GTEST_SKIP() << "Generalized LayerNorm is not implemented on GPU";
729-
}
730727
const DataType dtype = GetParam().dtype;
731728
const float error = GetParam().error;
732729
StorageView x({2, 3, 2}, std::vector<float>{
@@ -745,7 +742,7 @@ TEST_P(OpDeviceFPTest, LayerNormAxis) {
745742
1.4136513471603394, -1.3856042623519897}, device);
746743
StorageView y(dtype, device);
747744
ops::LayerNorm(1, 0)(x.to(dtype), y);
748-
expect_storage_eq(y.to_float32(), expected, error);
745+
expect_storage_eq(y.to_float32(), expected, error * 10);
749746
}
750747

751748
TEST_P(OpDeviceFPTest, RMSNorm) {
@@ -780,7 +777,8 @@ TEST_P(OpDeviceTest, QuantizeINT8) {
780777
}
781778

782779
// With rounding before cast and shift to uint8.
783-
{
780+
// Shift to uint8_t is not defined on CUDA
781+
if (device != Device::CUDA) {
784782
StorageView expected_qa(a.shape(), std::vector<int8_t>{1, 90, -64, -103, -98, -1, 110, -128});
785783
ops::Quantize(ops::Quantize::ScaleType::GLOBAL, true, true)(a, qa, scale);
786784
expect_storage_eq(scale, expected_scale);

0 commit comments

Comments
 (0)