@@ -866,8 +866,8 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
         self.f_input = self.Y
 
     def __repr__(self):
-        return '%s(%s, %s, %s, activation=%s)' % \
-            (type(self).__name__, self.N, self.d_in,
+        return '%s(%s, %s, %s, %s, activation=%s)' % \
+            (type(self).__name__, self.N, self.d_in, self.d,
             self.d_out, repr(self.activation))
 
     def reset(self):
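Note: Dense.__repr__ now also reports d, the number of positions the layer iterates over (1 by default). In plaintext terms, a Dense with d > 1 applies the same weight matrix at every position of an extra middle dimension; a minimal numpy sketch (illustrative names, not the MP-SPDZ API):

    import numpy as np

    N, d, d_in, d_out = 2, 16, 64, 32      # batch, loop dim (e.g. seq_len), features
    X = np.random.randn(N, d, d_in)
    W = np.random.randn(d_in, d_out)
    b = np.random.randn(d_out)
    Y = X @ W + b                          # same W applied at each of the d positions
    assert Y.shape == (N, d, d_out)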
@@ -1755,7 +1755,7 @@ class LayerNorm(Layer): # Changed class name
     thetas = lambda self: (self.weights, self.bias)
     nablas = lambda self: (self.nabla_weights, self.nabla_bias)
 
-    def __init__(self, shape, approx=False, layernorm_eps=None, args=None):
+    def __init__(self, shape, approx=True, layernorm_eps=None, args=None):
         if len(shape) == 2:
             shape = [shape[0], 1, shape[1]]  # Not sure why this extra dimension is added
         tensors = (Tensor(shape, sfix) for i in range(4))
@@ -1810,6 +1810,7 @@ def _(*arg):
             tmp = self.weights[:] * (sel_X[:] - mu_sel) * fac_sel  # Removed self.mu reference
             sel_Y[:] = self.bias[:] + tmp
 
+    @_layer_method_call_tape
     def forward(self, batch, training=False):
         d = self.X.sizes[1]
         d_in = self.X.sizes[2]
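Note: approx=True switches the inverse square root inside LayerNorm to the cheaper approximation by default. For reference, the operation implemented by the surrounding code is, in plaintext numpy terms:

    import numpy as np

    def layer_norm(x, weights, bias, eps=1e-12):
        mu = x.mean(axis=-1, keepdims=True)
        var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
        fac = 1.0 / np.sqrt(var + eps)   # the rsqrt that approx=True approximates
        return weights * (x - mu) * fac + bias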
@@ -2670,7 +2671,7 @@ class BertBase(BaseLayer, FixBase):
 class BertPooler(BertBase):
 
     thetas = lambda self: self.dense.thetas()
-    nablas = lambda self: self.dense.nablas() # refer to downstream layers?
+    nablas = lambda self: self.dense.nablas()
 
     def __init__(self, n_examples, seq_len, hidden_state):
         input_shape = [n_examples, seq_len, hidden_state]
@@ -2679,28 +2680,19 @@ def __init__(self, n_examples, seq_len, hidden_state):
         self.dense = Dense(n_examples, hidden_state, hidden_state)
         self.activation = Tanh(output_shape)
 
-        self.d_out = hidden_state
-
-
-    def _forward(self, batch):
-        # self.dense.X.address = self.X.address
         self.activation.X.address = self.dense.Y.address
         self.activation.Y.address = self.Y.address
 
-        # grab the first repr?
+        self.d_out = hidden_state
+
+    def _forward(self, batch):
         # batch contains [n_batch, n_heads, n_dim]
         @for_range(len(batch))
         def _(j):
             self.dense.X[j][:] = self.X[batch[j]][0][:]
 
-        # if self.debug_output:
-        #     print_ln("forward layer pooler.dense X %s", self.dense.X.reveal_nested())
-
         self.dense.forward(batch)
-        # print_ln("LINEAR Layer weights after bertpooler.dense: %s", self.opt.layers[-2].W.reveal_nested())
-
-        self.activation._forward(batch)
-        # print_ln("LINEAR Layer weights after bertpooler.activation: %s", self.opt.layers[-2].W.reveal_nested())
+        self.activation.forward(batch)
 
     def reset(self):
         self.dense.reset()
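Note: the pooler uses only the first token ([CLS]) of each sequence, which is what self.X[batch[j]][0] selects, followed by a dense layer and tanh. A plaintext numpy equivalent (illustrative names):

    import numpy as np

    def bert_pooler(hidden_states, W, b):     # hidden_states: (batch, seq_len, hidden)
        first_token = hidden_states[:, 0, :]  # [CLS] position only
        return np.tanh(first_token @ W + b)   # (batch, hidden)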
@@ -2767,44 +2759,28 @@ def __init__(self, n_examples, seq_len, hidden_state, intermediate_size, num_att
         self.intermediate = BertIntermediate(internal_shape, hidden_state, intermediate_size, seq_len)
         self.output = BertOutput(internal_shape, intermediate_size, hidden_state, seq_len, dropout, layernorm_eps, rsqrt_approx)
 
-        self.hidden_state = sfix.Tensor(input_shape)  # TODO: Could also make this smaller
-        # self.nabla_hidden_state = sfix.Tensor(input_shape)
-        # self.nabla_hidden_state.alloc()
-
-        # self.X.address = self.multi_head_attention.X.address
-        # self.Y.address = self.output.Y.address
-
         self.d_out = hidden_state
 
-        print("Init BertLayer", input_shape, output_shape)
-
+    @_layer_method_call_tape
     def forward(self, batch, training=False):
         if batch is None:
             batch = Array.create_from(regint(0))
 
         self.multi_head_attention._X.address = self.X.address
         self.output.Y.address = self.Y.address
-        self.hidden_state.address = self.X.address
-        # self.multi_head_attention.Y.address = self.Y.address
-
-        self.multi_head_attention.forward(batch, self.hidden_state, training)
-        # if self.debug_output:
-        #     print_ln("our layer X %s %s", self.X[0][0][0].reveal(), self.output.X[0][0][0].reveal())
 
+        self.multi_head_attention.forward(batch, self.X, training)
         if self.debug_output:
             print_ln("forward layer multi_head_attention %s %s", self.multi_head_attention.Y[0][1][0].reveal(), sum(sum(self.multi_head_attention.Y[0].reveal())))
             # print_ln("forward layer multi_head_attention full %s", self.multi_head_attention.Y.reveal())
 
-        print("Forward Attention")
-
         batch_inc = regint.Array(len(batch))
         batch_inc.assign(regint.inc(len(batch)))
         self.intermediate.X.address = self.multi_head_attention.Y.address
         self.intermediate.forward(batch_inc)
 
         if self.debug_output:
             print_ln("forward layer intermediate %s %s %s", self.intermediate.Y.shape, self.intermediate.Y[0][1][0:20].reveal(), sum(sum(self.intermediate.Y[0].reveal())))
-
             print_ln(" ")
 
         self.output.X.address = self.intermediate.Y.address
@@ -2813,14 +2789,7 @@ def forward(self, batch, training=False):
 
         if self.debug_output:
             print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
-            # print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
-            # print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
-
             print_ln("our layer output %s %s %s %s", self.output.Y.address, len(self.Y[0].reveal()), self.output.Y[0][0][0:20].reveal(), sum(sum(self.output.Y[0].reveal())))
-            # print_ln("shapes %s %s", self.Y.sizes, self.output.Y.sizes)
-            # print_ln("types %s %s %s %s %s %s", self.Y.value_type, self.output.Y.value_type, type(self.Y), type(self.output.Y), self, self.output)
-
-        print("Forward BertLayer")
 
     def reset(self):
         self.multi_head_attention.reset()
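Note: the layer now passes self.X directly as the residual input instead of first copying it into a separate hidden_state tensor. The data flow wired up by the address assignments is, as plaintext pseudocode with illustrative names:

    def bert_layer_forward(x):
        attn = multi_head_attention(x)            # intermediate.X aliases attention's Y
        inter = gelu(dense_intermediate(attn))
        return layer_norm(dropout(dense_output(inter)) + attn)  # residual add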
@@ -2864,10 +2833,8 @@ def backward(self, compute_nabla_X=True, batch=None):
         self.output.nabla_Y.address = self.nabla_Y.address
         self.intermediate.nabla_Y.address = self.output.nabla_X.address
         self.multi_head_attention.nabla_Y.address = self.intermediate.nabla_X.address
-        # self.multi_head_attention.nabla_X.address = self.nabla_X.address
 
         nabla_y_multi_head_attention_from_layernorm = self.output.backward(True, batch)
-        # print_ln("Backward BertLayer.output.nabla_X %s", self.output.nabla_X.reveal_nested()[:8])
         self.intermediate.backward(True, batch)
 
         # residual: add it to nabla_Y because the output of multi-head attention also fed the output sub-layer
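Note on the residual handled here: with out = layer_norm(g(a) + a), the attention output a receives gradients along two paths, so the value returned by output.backward for the skip connection must be added to the gradient flowing back through the intermediate layer. As a plaintext sketch:

    def residual_backward(nabla_sum, backward_through_g):
        # gradient of a in out = g(a) + a: branch gradient plus skip gradient
        return backward_through_g(nabla_sum) + nabla_sum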
@@ -2904,7 +2871,11 @@ def __init__(self, n_examples, hidden_size, intermediate_size, seq_len):
         self.dense = Dense(n_examples, hidden_size, intermediate_size, seq_len)
         self.activation = Gelu([n_examples, seq_len, intermediate_size])
 
+        self.dense.X.address = self.X.address
+        self.activation.X.address = self.dense.Y.address
+        self.activation.Y.address = self.Y.address
 
+    @_layer_method_call_tape
     def forward(self, batch=None, training=None):
         self.dense.X.address = self.X.address
         self.activation.X.address = self.dense.Y.address
@@ -2914,16 +2885,14 @@ def forward(self, batch=None, training=None):
 
         if self.debug_output:
             print_ln("forward layer intermediate.dense %s", self.dense.Y[0][0][0:20].reveal())
 
-        self.activation._forward(batch)
+        self.activation.forward(batch)
 
     def reset(self):
         self.dense.reset()
 
     def backward(self, compute_nabla_X=True, batch=None):
         self.activation.nabla_X.alloc()
 
-        # print_ln("Backward BertIntermediate.nabla_X %s", self.nabla_X.reveal_nested()[:8])
-
         self.activation.nabla_Y.address = self.nabla_Y.address
         self.dense.nabla_Y.address = self.activation.nabla_X.address
         self.dense.nabla_X.address = self.nabla_X.address
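Note: the .address assignments used throughout these layers alias two tensors to the same backing memory, so one sub-layer writes its output directly into the next sub-layer's input without a copy. A rough plaintext analogue using numpy aliasing rather than MP-SPDZ memory addresses:

    import numpy as np

    buf = np.empty((3, 4))
    dense_Y = buf                      # "dense.Y"
    activation_X = buf                 # "activation.X" shares the same storage
    dense_Y[:] = 1.0
    assert activation_X[0, 0] == 1.0   # visible through the alias, no copy made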
@@ -2941,13 +2910,13 @@ def __init__(self, n_examples, intermediate_size, hidden_size, seq_len, dropout=
         input_shape = [n_examples, seq_len, intermediate_size]
         output_shape = [n_examples, seq_len, hidden_size]
         self.input_shape = input_shape
-        print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx)
         super(BertOutput, self).__init__(input_shape, output_shape)
         self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len)
         self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx)
         self.dropout = Dropout([n_examples, seq_len, hidden_size], alpha=dropout)
 
 
+    @_layer_method_call_tape
     def forward(self, batch, input_tensor, training=False, input_tensor_batch=None):
         # Because input_tensor might be the full training data shape
         self.dense.X.address = self.X.address
@@ -2975,18 +2944,12 @@ def _(base, size):
             self.layer_norm.X.assign_part_vector(
                 self.layer_norm.X.get_part_vector(base, size) +
                 input_tensor.get_part_vector(base, size), base)
-            # if self.debug_output:
-            #     print_ln("input tensor %s", input_tensor.reveal())
-
-        # self.layer_norm.X[:] += input_tensor[:]  # TODO: is it maybe this addition since we take the last value? would be strange
 
         if self.debug_output:
             print_ln("forward layer layer_norm_add %s", self.layer_norm.X[0][0][0:20].reveal())
             print_ln("")
         self.layer_norm.forward(batch)
 
-
-
     def reset(self):
         self.dense.reset()
 
@@ -3049,6 +3012,7 @@ def __init__(self, n_examples, seq_len, hidden_size, num_attention_heads, dropou
         self.nabla_attention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix)
         self.nabla_preattention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix)
 
+    @_layer_method_call_tape
    def forward(self, batch=None, hidden_state=None, training=None):
        N = len(batch)
 
@@ -3068,32 +3032,24 @@ def forward(self, batch=None, hidden_state=None, training=None):
         inc_batch.assign(regint.inc(N))
 
         if self.debug_output:
-            # print_ln('forward layer wq full %s', self.wq.X.reveal())
             print_ln('forward layer wv %s %s', self.wv.Y[0][0][0:10].reveal(), sum(self.wv.Y[0][0].reveal()))
             print_ln('forward layer hidden_state %s', hidden_state[0][1][0:10].reveal())
-            # print_ln('forward layer wv full %s', self.wv.Y.reveal())
 
-        # max_size = program.budget // self.attention_head_size
         @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads])
         def _(i, j):
-            # for j in range(self.num_attention_heads):
             query_sub = sfix.Matrix(self.seq_len, self.attention_head_size)  # this is mem inefficient?
             key_sub = sfix.Matrix(self.seq_len, self.attention_head_size)
-            # print(self.wq.Y.shape, "wk Y shape", i, self.attention_head_size, j, self.wq.Y[i], self.wq.Y[i][:])
 
             @for_range_opt(self.seq_len)
             def _(k):
-                # for k in range(self.seq_len):
                 query_sub[k] = self.wq.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
                 key_sub[k] = self.wk.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
 
-            # print_ln("query_sub %s %s", i, j)
             res = query_sub.direct_mul_trans(key_sub)
             self.attention_scores[i].assign_part_vector(res, j)
 
         if self.debug_output:
             print_ln('forward layer attention_scores %s', self.attention_scores[0][0].reveal())
-            # print_ln('forward layer attention_scores full %s', self.attention_scores.reveal())
 
         @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads, self.seq_len])
         def _(i, j, k):
@@ -3113,7 +3069,6 @@ def _(i, j):
             @for_range_opt([self.seq_len])
             def _(k):
                 value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
-                # value_sub[k] = self.wv.Y[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size]
 
             res = sfix.Matrix(self.seq_len, self.attention_head_size)
             res.assign_vector(self.dropout.Y[i][j].direct_mul(value_sub))
@@ -3124,13 +3079,7 @@ def _(k):
                 self.context[i][k].assign_part_vector(res[k],
                     j * self.attention_head_size
                 )
-            # for k in range(self.seq_len):
-            #     self.context[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size] = res[k * self.attention_head_size:(k + 1) * self.attention_head_size]
-
-        # How to transfer to forward?
 
-        # missing half of the values?
-        # print_ln('forward layer old_context %s', self.old_context[0].get_vector().reveal())
         if self.debug_output:
             print_ln('forward layer multiheadattention before internal output %s', self.context[0][0][0:20].get_vector().reveal())
 
@@ -3142,8 +3091,6 @@ def _(k):
             print_ln('forward multiheadattention output %s', self.output.Y[0][0][0:20].reveal())
             print_ln("")
 
-        # return context
-
     def reset(self):
         self.wq.reset()
         self.wk.reset()
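Note: per example i and head j, the loops above slice the head's columns out of the Q/K/V projections, compute scores with direct_mul_trans, apply softmax and dropout (in code between the hunks shown), multiply by the value slice, and write the result into the matching columns of context. One head in plaintext numpy (scaling and dropout omitted, as they happen in code this diff does not show):

    import numpy as np

    def attention_head(Q, K, V, j, head_size):       # Q, K, V: (seq_len, hidden)
        q = Q[:, j * head_size:(j + 1) * head_size]  # query_sub
        k = K[:, j * head_size:(j + 1) * head_size]  # key_sub
        v = V[:, j * head_size:(j + 1) * head_size]  # value_sub
        scores = q @ k.T                             # direct_mul_trans
        e = np.exp(scores - scores.max(-1, keepdims=True))
        probs = e / e.sum(-1, keepdims=True)         # softmax over key positions
        return probs @ v                             # one head's slice of context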
@@ -3199,8 +3146,6 @@ def _(k):
                 nabla_value_sub[k],
                 j * self.attention_head_size)
 
-        print("RES MULTI BACK", self.dropout.Y, res, self.num_attention_heads, self.attention_head_size)
-
         self.dropout.nabla_X.alloc()
         self.dropout.backward(True, batch)
 
@@ -4413,10 +4358,20 @@ def process(item, inputs, input_shape, args, kwargs={}):
         raise CompilerError('multi-input layer %s not supported' % item)
     name = type(item).__name__
     if name == 'Linear':
-        assert mul(input_shape[1:]) == item.in_features
+        # Precondition: the item must have a bias
         assert item.bias is not None
-        layers.append(Dense(input_shape[0], item.in_features,
-                            item.out_features))
+        if mul(input_shape[1:]) == item.in_features:
+            layers.append(Dense(input_shape[0], item.in_features,
+                                item.out_features))
+        elif input_shape[-1] == item.in_features:
+            # we loop over all but the last dimension
+            assert len(input_shape) == 3, "Dense only supports one extra dimension to loop over"
+            d = input_shape[1]
+            layers.append(Dense(input_shape[0], item.in_features,
+                                item.out_features, d))
+        else:
+            assert False, f"input shape {input_shape} incompatible with in_features {item.in_features}"
+
         if input_via is not None:
             shapes = [x.shape for x in (layers[-1].W, layers[-1].b)]
             import numpy
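Note: with the new branch, a torch Linear applied to a rank-3 input such as [batch, seq_len, features] maps to a Dense that loops over the middle dimension instead of being rejected. A hedged usage sketch (shapes illustrative; assumes the surrounding layers_from_torch entry point):

    import torch.nn as nn
    from Compiler import ml

    net = nn.Sequential(nn.Linear(64, 32))
    # an input shape like (8, 16, 64) now yields Dense(8, 64, 32, 16);
    # previously only shapes with mul(shape[1:]) == 64 were accepted
    layers = ml.layers_from_torch(net, (8, 16, 64), batch_size=8)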
@@ -4487,6 +4442,15 @@ def process(item, inputs, input_shape, args, kwargs={}):
         input_shape = layers[-1].shape
     elif name == 'ReLU' or item == torch.nn.functional.relu:
         layers.append(Relu(input_shape))
+    elif name == 'GELU' or item == torch.nn.functional.gelu:
+        layers.append(Gelu(input_shape))
+    elif name == 'LayerNorm':
+        layers.append(LayerNorm(input_shape, True, item.eps))
+        if input_via is not None:
+            layers[-1].weights = sfix.input_tensor_via(
+                input_via, item.weight.detach())
+            layers[-1].beta = sfix.input_tensor_via(
+                input_via, item.bias.detach())
     elif name == 'Flatten':
         return
     elif name == 'BatchNorm2d' or name == 'BatchNorm1d':
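Note: for reference, the tanh approximation of GELU common in BERT-style models is shown below in plaintext numpy; whether the Gelu layer uses exactly this variant is not visible in this diff:

    import numpy as np

    def gelu_tanh(x):
        return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))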
@@ -4539,7 +4503,7 @@ def process(item, inputs, input_shape, args, kwargs={}):
         num_attention_heads = config.num_attention_heads
         layernorm_eps = config.layer_norm_eps
         seq_len = input_shape[1]
-        rsqrt_approx = False
+        rsqrt_approx = True
         layer = BertLayer(input_shape[0], seq_len, hidden_state, intermediate_size, num_attention_heads,
                           layernorm_eps, 0.125, rsqrt_approx, batch_size=batch_size)
         if input_via is not None:
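Note: rsqrt_approx=True makes the LayerNorm inside BertLayer use the approximate inverse square root, which is much cheaper under MPC than an exact computation. A classic plaintext analogue is Newton-Raphson refinement of an initial guess (the exact method used here is not shown in this diff):

    def rsqrt_newton(x, y0, iters=3):
        y = y0                             # initial approximation of 1/sqrt(x)
        for _ in range(iters):
            y = y * (1.5 - 0.5 * x * y * y)
        return y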