@@ -1859,8 +1859,12 @@ def get_model_spec(self, model):
18591859 "Quantization type '%s' is not yet implemented."
18601860 % quantization_config .quant_method
18611861 )
1862+ quant_group_size = quantization_config .group_size
1863+ quant_bits = quantization_config .bits
18621864 else :
18631865 quant_type = common_spec .Quantization .CT2
1866+ quant_group_size = None
1867+ quant_bits = None
18641868
18651869 # Create base spec using from_config
18661870 spec = transformer_spec .TransformerDecoderModelSpec .from_config (
@@ -1881,6 +1885,9 @@ def get_model_spec(self, model):
18811885 head_dim = head_dim ,
18821886 sliding_window = sliding_window , # Default to local sliding window
18831887 pre_post_layer_norm = True ,
1888+ quant_type = quant_type ,
1889+ quant_group_size = quant_group_size ,
1890+ quant_bits = quant_bits ,
18841891 qk_norm = True ,
18851892 )
18861893
@@ -1933,7 +1940,8 @@ def set_config(self, config, model, tokenizer):
19331940 config .eos_token = tokenizer .eos_token
19341941
19351942 def set_layer_norm (self , spec , layer_norm ):
1936- spec .gamma = layer_norm .weight + 1.0
1943+ spec .gamma = layer_norm .weight
1944+ spec .layer_norm_use_residual = True
19371945
19381946 def set_decoder (self , spec , module , quant_type = common_spec .Quantization .CT2 ):
19391947 spec .scale_embeddings = True
0 commit comments