Skip to content

Commit ee0a539

Browse files
authored (commit author)
Gemma 3 conversion improvements (#1991)
1 parent 6f2bd3e commit ee0a539

2 files changed

Lines changed: 10 additions & 2 deletions

File tree

python/ctranslate2/converters/transformers.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1859,8 +1859,12 @@ def get_model_spec(self, model):
18591859
"Quantization type '%s' is not yet implemented."
18601860
% quantization_config.quant_method
18611861
)
1862+
quant_group_size = quantization_config.group_size
1863+
quant_bits = quantization_config.bits
18621864
else:
18631865
quant_type = common_spec.Quantization.CT2
1866+
quant_group_size = None
1867+
quant_bits = None
18641868

18651869
# Create base spec using from_config
18661870
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
@@ -1881,6 +1885,9 @@ def get_model_spec(self, model):
18811885
head_dim=head_dim,
18821886
sliding_window=sliding_window, # Default to local sliding window
18831887
pre_post_layer_norm=True,
1888+
quant_type=quant_type,
1889+
quant_group_size=quant_group_size,
1890+
quant_bits=quant_bits,
18841891
qk_norm=True,
18851892
)
18861893

@@ -1933,7 +1940,8 @@ def set_config(self, config, model, tokenizer):
19331940
config.eos_token = tokenizer.eos_token
19341941

19351942
def set_layer_norm(self, spec, layer_norm):
1936-
spec.gamma = layer_norm.weight + 1.0
1943+
spec.gamma = layer_norm.weight
1944+
spec.layer_norm_use_residual = True
19371945

19381946
def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
19391947
spec.scale_embeddings = True

python/ctranslate2/specs/transformer_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def __init__(
275275
self.project_in = common_spec.LinearSpec()
276276
self.project_out = common_spec.LinearSpec()
277277

278-
if quant_type is not None:
278+
if quant_type:
279279
self._config["quantization_type"] = quant_type
280280
self._config["quantization_bits"] = quant_bits
281281
self._config["quantization_group_size"] = quant_group_size

0 commit comments

Comments (0)