Skip to content

Commit 4ed5c2c

Browse files
committed
Remove unnecessary check from WavLM (refs #1977)
1 parent e786a90 commit 4ed5c2c

3 files changed

Lines changed: 270 additions & 0 deletions

File tree

include/ctranslate2/layers/wavlm.h

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#pragma once
2+
3+
#include <optional>
4+
#include "ctranslate2/layers/transformer.h"
5+
6+
namespace ctranslate2 {
7+
namespace layers {
8+
9+
class WavLMLayerNormConvLayer : public Layer {
10+
public:
11+
WavLMLayerNormConvLayer(const models::Model& model,
12+
const std::string& scope,
13+
dim_t stride,
14+
dim_t padding);
15+
16+
void operator()(const StorageView& input, StorageView& output) const;
17+
18+
DataType output_type() const override {
19+
return _conv.output_type();
20+
}
21+
22+
dim_t output_size() const override {
23+
return _conv.output_size();
24+
}
25+
26+
private:
27+
dim_t _stride;
28+
dim_t _padding;
29+
const Conv1D _conv;
30+
const LayerNorm _output_norm;
31+
const ops::Transpose _transpose;
32+
const ops::GELU _gelu;
33+
};
34+
35+
class WavLMPosConvLayer : public Layer {
36+
public:
37+
WavLMPosConvLayer(const models::Model& model, const std::string& scope);
38+
39+
void operator()(const StorageView& input, StorageView& output) const;
40+
41+
DataType output_type() const override {
42+
return _conv.output_type();
43+
}
44+
45+
dim_t output_size() const override {
46+
return _conv.output_size();
47+
}
48+
49+
private:
50+
const Conv1D _conv;
51+
const ops::Transpose _transpose;
52+
const ops::GELU _gelu;
53+
};
54+
55+
class WavLMEncoder : public Layer {
56+
public:
57+
WavLMEncoder(const models::Model& model, const std::string& scope);
58+
59+
void operator()(const StorageView& features, StorageView& output);
60+
61+
DataType output_type() const override {
62+
if (_lm_head) {
63+
return (*_lm_head).output_type();
64+
}
65+
else {
66+
return _output_norm.output_type();
67+
}
68+
}
69+
70+
dim_t output_size() const override {
71+
if (_lm_head) {
72+
return (*_lm_head).output_size();
73+
}
74+
else {
75+
return _output_norm.output_size();
76+
}
77+
}
78+
79+
dim_t input_size() const {
80+
return 1024;
81+
}
82+
83+
const StorageView* _upgraded_model;
84+
85+
private:
86+
const StorageView* _return_logits;
87+
std::optional<WavLMLayerNormConvLayer> _feat_layer0;
88+
std::optional<std::vector<std::unique_ptr<const WavLMLayerNormConvLayer>>> _feat_layers;
89+
std::optional<LayerNorm> _fp_norm;
90+
std::optional<Dense> _fp_ff;
91+
std::optional<WavLMPosConvLayer> _pos_conv_embed;
92+
const ops::Transpose _transpose;
93+
const ops::GELU _gelu;
94+
const dim_t _num_heads;
95+
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
96+
const LayerNorm _output_norm;
97+
std::optional<Dense> _lm_head;
98+
};
99+
100+
}
101+
}

include/ctranslate2/models/wavlm.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#pragma once
2+
3+
//#include "ctranslate2/generation.h"
4+
#include "ctranslate2/layers/wavlm.h"
5+
#include "ctranslate2/models/model.h"
6+
#include "ctranslate2/replica_pool.h"
7+
8+
namespace ctranslate2 {
9+
namespace models {
10+
11+
struct WavLMOptions {
12+
// Maximum generation length.
13+
size_t max_length = 448;
14+
15+
// Randomly sample from the top K candidates (set 0 to sample from the full distribution).
16+
size_t sampling_topk = 1;
17+
18+
// Maximum index of the first predicted timestamp.
19+
size_t max_initial_timestamp_index = 50;
20+
21+
// Suppress blank outputs at the beginning of the sampling.
22+
bool suppress_blank = true;
23+
24+
// List of token IDs to suppress.
25+
// -1 will suppress a default set of symbols as defined in the model config.json file.
26+
std::vector<int> suppress_tokens = {-1};
27+
};
28+
29+
30+
class WavLMModel : public Model {
31+
public:
32+
const Vocabulary& get_vocabulary() const;
33+
size_t current_spec_revision() const override;
34+
bool is_quantizable(const std::string& variable_name) const override;
35+
bool is_linear_weight(const std::string& variable_name) const override;
36+
std::unique_ptr<Model> clone() const override;
37+
38+
bool use_global_int16_scale() const override {
39+
return false;
40+
}
41+
42+
protected:
43+
void initialize(ModelReader& model_reader) override;
44+
private:
45+
std::shared_ptr<const Vocabulary> _vocabulary;
46+
};
47+
48+
class WavLMReplica : public ModelReplica {
49+
public:
50+
static std::unique_ptr<WavLMReplica> create_from_model(const Model& model);
51+
52+
WavLMReplica(const std::shared_ptr<const WavLMModel>& model);
53+
StorageView encode(StorageView features, const bool to_cpu);
54+
private:
55+
const std::shared_ptr<const WavLMModel> _model;
56+
const std::unique_ptr<layers::WavLMEncoder> _encoder;
57+
};
58+
59+
class WavLM : public ReplicaPool<WavLMReplica> {
60+
public:
61+
using ReplicaPool::ReplicaPool;
62+
std::future<StorageView> encode(const StorageView& features, const bool to_cpu);
63+
};
64+
65+
}
66+
}

src/models/wavlm.cc

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#include "ctranslate2/models/wavlm.h"
2+
3+
#include <algorithm>
4+
5+
#include "ctranslate2/decoding.h"
6+
7+
#include "dispatch.h"
8+
#include "dtw.h"
9+
10+
#ifdef CT2_WITH_CUDA
11+
# include "cuda/utils.h"
12+
#endif
13+
14+
15+
namespace ctranslate2 {
16+
namespace models {
17+
18+
const Vocabulary& WavLMModel::get_vocabulary() const {
19+
return *_vocabulary;
20+
}
21+
22+
size_t WavLMModel::current_spec_revision() const {
23+
return 3;
24+
}
25+
26+
void WavLMModel::initialize(ModelReader& model_reader) {
27+
VocabularyInfo vocab_info;
28+
vocab_info.unk_token = "[UNK]";
29+
vocab_info.bos_token = "<s>";
30+
vocab_info.eos_token = "</s>";
31+
32+
_vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info));
33+
if (!_vocabulary)
34+
throw std::runtime_error("Cannot load the vocabulary from the model directory");
35+
}
36+
37+
bool WavLMModel::is_quantizable(const std::string& variable_name) const {
38+
return Model::is_quantizable(variable_name);
39+
}
40+
41+
bool WavLMModel::is_linear_weight(const std::string& variable_name) const {
42+
return is_quantizable(variable_name) && variable_name.find("embeddings") == std::string::npos;
43+
}
44+
45+
std::unique_ptr<Model> WavLMModel::clone() const {
46+
return std::make_unique<WavLMModel>(*this);
47+
}
48+
49+
50+
std::unique_ptr<WavLMReplica> WavLMReplica::create_from_model(const Model& model) {
51+
if (!dynamic_cast<const WavLMModel*>(&model))
52+
throw std::invalid_argument("The model is not a WavLM model");
53+
54+
const auto scoped_device_setter = model.get_scoped_device_setter();
55+
const auto model_ptr = model.shared_from_this();
56+
const auto concrete_model = std::static_pointer_cast<const WavLMModel>(model_ptr);
57+
return std::make_unique<WavLMReplica>(concrete_model);
58+
}
59+
60+
WavLMReplica::WavLMReplica(const std::shared_ptr<const WavLMModel>& model)
61+
: ModelReplica(model)
62+
, _model(model)
63+
, _encoder(std::make_unique<layers::WavLMEncoder>(*model, "encoder"))
64+
{
65+
}
66+
67+
StorageView WavLMReplica::encode(StorageView features, const bool to_cpu) {
68+
PROFILE("WavLMReplica::encode");
69+
70+
#ifdef CT2_WITH_CUDA
71+
const cuda::UseTrueFp16GemmInScope use_true_fp16_gemm(false);
72+
#endif
73+
74+
const auto scoped_device_setter = _model->get_scoped_device_setter();
75+
const Device device = _model->device();
76+
const DataType dtype = _encoder->output_type();
77+
features.move_to(device, dtype);
78+
79+
StorageView encoder_output(dtype, device);
80+
(*_encoder)(features, encoder_output);
81+
82+
if (to_cpu) {
83+
if (device != Device::CPU)
84+
encoder_output = encoder_output.to(Device::CPU);
85+
86+
return encoder_output;
87+
}
88+
89+
// Ensure all operations are finished before returning the output.
90+
synchronize_stream(device);
91+
92+
return encoder_output;
93+
}
94+
95+
std::future<StorageView> WavLM::encode(const StorageView& features, const bool to_cpu) {
96+
return post<StorageView>(
97+
[features = features.sync_copy(), to_cpu](WavLMReplica& replica) mutable {
98+
return replica.encode(std::move(features), to_cpu);
99+
});
100+
}
101+
102+
}
103+
}

0 commit comments

Comments
 (0)