
Commit 2ee81da

Merge pull request #1768 from pps-lab/bert_layer_improvements_pr
Reduced compilation time of BERT and other improvements
2 parents bf7f8f4 + 7b72a46

3 files changed

Lines changed: 154 additions & 183 deletions


Compiler/ml.py

Lines changed: 41 additions & 77 deletions
@@ -866,8 +866,8 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
         self.f_input = self.Y
 
     def __repr__(self):
-        return '%s(%s, %s, %s, activation=%s)' % \
-            (type(self).__name__, self.N, self.d_in,
+        return '%s(%s, %s, %s, %s, activation=%s)' % \
+            (type(self).__name__, self.N, self.d_in, self.d,
              self.d_out, repr(self.activation))
 
     def reset(self):
@@ -1755,7 +1755,7 @@ class LayerNorm(Layer): # Changed class name
     thetas = lambda self: (self.weights, self.bias)
     nablas = lambda self: (self.nabla_weights, self.nabla_bias)
 
-    def __init__(self, shape, approx=False, layernorm_eps=None, args=None):
+    def __init__(self, shape, approx=True, layernorm_eps=None, args=None):
         if len(shape) == 2:
             shape = [shape[0], 1, shape[1]] # Not sure why this extra dimension is added
         tensors = (Tensor(shape, sfix) for i in range(4))
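Note: this default now matches the BERT importer further down, which sets rsqrt_approx = True. As a point of reference, a minimal NumPy sketch of the layer-norm formula, with a comment marking the only place the flag plausibly touches (assumption: approx selects MP-SPDZ's cheaper approximate reciprocal square root; this is an illustration, not the MP-SPDZ implementation):

    import numpy as np

    def layer_norm(x, weight, bias, eps=1e-12):
        # y = weight * (x - mean) * rsqrt(var + eps) + bias
        mu = x.mean(-1, keepdims=True)
        var = x.var(-1, keepdims=True)
        inv = 1.0 / np.sqrt(var + eps)   # approx=True would replace this exact
                                         # rsqrt with an iterative approximation
        return weight * (x - mu) * inv + bias

    out = layer_norm(np.random.rand(2, 4, 8), np.ones(8), np.zeros(8))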
@@ -1810,6 +1810,7 @@ def _(*arg):
             tmp = self.weights[:] * (sel_X[:] - mu_sel) * fac_sel # Removed self.mu reference
             sel_Y[:] = self.bias[:] + tmp
 
+    @_layer_method_call_tape
     def forward(self, batch, training=False):
         d = self.X.sizes[1]
         d_in = self.X.sizes[2]
@@ -2670,7 +2671,7 @@ class BertBase(BaseLayer, FixBase):
 class BertPooler(BertBase):
 
     thetas = lambda self: self.dense.thetas()
-    nablas = lambda self: self.dense.nablas() # refer to downstream layers?
+    nablas = lambda self: self.dense.nablas()
 
     def __init__(self, n_examples, seq_len, hidden_state):
         input_shape = [n_examples, seq_len, hidden_state]
@@ -2679,28 +2680,19 @@ def __init__(self, n_examples, seq_len, hidden_state):
         self.dense = Dense(n_examples, hidden_state, hidden_state)
         self.activation = Tanh(output_shape)
 
-        self.d_out = hidden_state
-
-
-    def _forward(self, batch):
-        # self.dense.X.address = self.X.address
         self.activation.X.address = self.dense.Y.address
         self.activation.Y.address = self.Y.address
 
-        # grab the first repr?
+        self.d_out = hidden_state
+
+    def _forward(self, batch):
         # batch contains [n_batch, n_heads, n_dim]
         @for_range(len(batch))
         def _(j):
             self.dense.X[j][:] = self.X[batch[j]][0][:]
 
-        # if self.debug_output:
-        #     print_ln("forward layer pooler.dense X %s", self.dense.X.reveal_nested())
-
         self.dense.forward(batch)
-        # print_ln("LINEAR Layer weights after bertpooler.dense: %s", self.opt.layers[-2].W.reveal_nested())
-
-        self.activation._forward(batch)
-        # print_ln("LINEAR Layer weights after bertpooler.activation: %s", self.opt.layers[-2].W.reveal_nested())
+        self.activation.forward(batch)
 
     def reset(self):
         self.dense.reset()
@@ -2767,44 +2759,28 @@ def __init__(self, n_examples, seq_len, hidden_state, intermediate_size, num_att
         self.intermediate = BertIntermediate(internal_shape, hidden_state, intermediate_size, seq_len)
         self.output = BertOutput(internal_shape, intermediate_size, hidden_state, seq_len, dropout, layernorm_eps, rsqrt_approx)
 
-        self.hidden_state = sfix.Tensor(input_shape) # TODO: Could also make this smaller
-        # self.nabla_hidden_state = sfix.Tensor(input_shape)
-        # self.nabla_hidden_state.alloc()
-
-        # self.X.address = self.multi_head_attention.X.address
-        # self.Y.address = self.output.Y.address
-
         self.d_out = hidden_state
 
-        print("Init BertLayer", input_shape, output_shape)
-
+    @_layer_method_call_tape
     def forward(self, batch, training=False):
         if batch is None:
             batch = Array.create_from(regint(0))
 
         self.multi_head_attention._X.address = self.X.address
         self.output.Y.address = self.Y.address
-        self.hidden_state.address = self.X.address
-        # self.multi_head_attention.Y.address = self.Y.address
-
-        self.multi_head_attention.forward(batch, self.hidden_state, training)
-        # if self.debug_output:
-        #     print_ln("our layer X %s %s", self.X[0][0][0].reveal(), self.output.X[0][0][0].reveal())
 
+        self.multi_head_attention.forward(batch, self.X, training)
         if self.debug_output:
             print_ln("forward layer multi_head_attention %s %s", self.multi_head_attention.Y[0][1][0].reveal(), sum(sum(self.multi_head_attention.Y[0].reveal())))
             # print_ln("forward layer multi_head_attention full %s", self.multi_head_attention.Y.reveal())
 
-        print("Forward Attention")
-
         batch_inc = regint.Array(len(batch))
         batch_inc.assign(regint.inc(len(batch)))
         self.intermediate.X.address = self.multi_head_attention.Y.address
         self.intermediate.forward(batch_inc)
 
         if self.debug_output:
             print_ln("forward layer intermediate %s %s %s", self.intermediate.Y.shape, self.intermediate.Y[0][1][0:20].reveal(), sum(sum(self.intermediate.Y[0].reveal())))
-
             print_ln(" ")
 
         self.output.X.address = self.intermediate.Y.address
@@ -2813,14 +2789,7 @@ def forward(self, batch, training=False):
 
         if self.debug_output:
             print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
-            # print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
-            # print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
-
             print_ln("our layer output %s %s %s %s", self.output.Y.address, len(self.Y[0].reveal()), self.output.Y[0][0][0:20].reveal(), sum(sum(self.output.Y[0].reveal())))
-            # print_ln("shapes %s %s", self.Y.sizes, self.output.Y.sizes)
-            # print_ln("types %s %s %s %s %s %s", self.Y.value_type, self.output.Y.value_type, type(self.Y), type(self.output.Y), self, self.output)
-
-        print("Forward BertLayer")
 
     def reset(self):
         self.multi_head_attention.reset()
@@ -2864,10 +2833,8 @@ def backward(self, compute_nabla_X=True, batch=None):
         self.output.nabla_Y.address = self.nabla_Y.address
         self.intermediate.nabla_Y.address = self.output.nabla_X.address
         self.multi_head_attention.nabla_Y.address = self.intermediate.nabla_X.address
-        # self.multi_head_attention.nabla_X.address = self.nabla_X.address
 
         nabla_y_multi_head_attention_from_layernorm = self.output.backward(True, batch)
-        # print_ln("Backward BertLayer.output.nabla_X %s", self.output.nabla_X.reveal_nested()[:8])
         self.intermediate.backward(True, batch)
 
         # residual, add it to Y because it gave the output of multihadattention to output
@@ -2904,7 +2871,11 @@ def __init__(self, n_examples, hidden_size, intermediate_size, seq_len):
         self.dense = Dense(n_examples, hidden_size, intermediate_size, seq_len)
         self.activation = Gelu([n_examples, seq_len, intermediate_size])
 
+        self.dense.X.address = self.X.address
+        self.activation.X.address = self.dense.Y.address
+        self.activation.Y.address = self.Y.address
 
+    @_layer_method_call_tape
     def forward(self, batch=None, training=None):
         self.dense.X.address = self.X.address
         self.activation.X.address = self.dense.Y.address
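Note: wiring the X/Y addresses once in the constructor (and defensively again in forward) makes the chained sub-layers operate on shared storage instead of copying activations between them. A rough NumPy analogy, assuming the address aliasing behaves like two views of one buffer (illustration only):

    import numpy as np

    buf = np.zeros((4, 8))     # plays the role of dense.Y / activation.X
    dense_Y = buf              # the dense layer writes its output here
    activation_X = buf         # the activation reads the same memory, no copy
    dense_Y[:] = np.arange(32).reshape(4, 8)
    assert np.shares_memory(dense_Y, activation_X)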
@@ -2914,16 +2885,14 @@ def forward(self, batch=None, training=None):
         if self.debug_output:
             print_ln("forward layer intermediate.dense %s", self.dense.Y[0][0][0:20].reveal())
 
-        self.activation._forward(batch)
+        self.activation.forward(batch)
 
     def reset(self):
         self.dense.reset()
 
     def backward(self, compute_nabla_X=True, batch=None):
         self.activation.nabla_X.alloc()
 
-        # print_ln("Backward BertIntermediate.nabla_X %s", self.nabla_X.reveal_nested()[:8])
-
         self.activation.nabla_Y.address = self.nabla_Y.address
         self.dense.nabla_Y.address = self.activation.nabla_X.address
         self.dense.nabla_X.address = self.nabla_X.address
@@ -2941,13 +2910,13 @@ def __init__(self, n_examples, intermediate_size, hidden_size, seq_len, dropout=
         input_shape = [n_examples, seq_len, intermediate_size]
         output_shape = [n_examples, seq_len, hidden_size]
         self.input_shape = input_shape
-        print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx)
         super(BertOutput, self).__init__(input_shape, output_shape)
         self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len)
         self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx)
         self.dropout = Dropout([n_examples, seq_len, hidden_size], alpha=dropout)
 
 
+    @_layer_method_call_tape
     def forward(self, batch, input_tensor, training=False, input_tensor_batch=None):
         # Because input_tensor might be the full training data shape
         self.dense.X.address = self.X.address
@@ -2975,18 +2944,12 @@ def _(base, size):
             self.layer_norm.X.assign_part_vector(
                 self.layer_norm.X.get_part_vector(base, size) +
                 input_tensor.get_part_vector(base, size), base)
-        # if self.debug_output:
-        #     print_ln("input tensor %s", input_tensor.reveal())
-
-        # self.layer_norm.X[:] += input_tensor[:] # TODO: is it maybe this addition since we take the last value? would be strange
 
         if self.debug_output:
             print_ln("forward layer layer_norm_add %s", self.layer_norm.X[0][0][0:20].reveal())
             print_ln("")
         self.layer_norm.forward(batch)
 
-
-
     def reset(self):
         self.dense.reset()
 
@@ -3049,6 +3012,7 @@ def __init__(self, n_examples, seq_len, hidden_size, num_attention_heads, dropou
         self.nabla_attention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix)
         self.nabla_preattention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix)
 
+    @_layer_method_call_tape
     def forward(self, batch=None, hidden_state=None, training=None):
         N = len(batch)
 
@@ -3068,32 +3032,24 @@ def forward(self, batch=None, hidden_state=None, training=None):
         inc_batch.assign(regint.inc(N))
 
         if self.debug_output:
-            # print_ln('forward layer wq full %s', self.wq.X.reveal())
             print_ln('forward layer wv %s %s', self.wv.Y[0][0][0:10].reveal(), sum(self.wv.Y[0][0].reveal()))
             print_ln('forward layer hidden_state %s', hidden_state[0][1][0:10].reveal())
-            # print_ln('forward layer wv full %s', self.wv.Y.reveal())
 
-        # max_size = program.budget // self.attention_head_size
         @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads])
         def _(i, j):
-        # for j in range(self.num_attention_heads):
             query_sub = sfix.Matrix(self.seq_len, self.attention_head_size) # this is mem inefficient?
             key_sub = sfix.Matrix(self.seq_len, self.attention_head_size)
-            # print(self.wq.Y.shape, "wk Y shape", i, self.attention_head_size, j, self.wq.Y[i], self.wq.Y[i][:])
 
             @for_range_opt(self.seq_len)
             def _(k):
-            # for k in range(self.seq_len):
                 query_sub[k] = self.wq.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
                 key_sub[k] = self.wk.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
 
-            # print_ln("query_sub %s %s", i, j)
             res = query_sub.direct_mul_trans(key_sub)
             self.attention_scores[i].assign_part_vector(res, j)
 
         if self.debug_output:
             print_ln('forward layer attention_scores %s', self.attention_scores[0][0].reveal())
-            # print_ln('forward layer attention_scores full %s', self.attention_scores.reveal())
 
         @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads, self.seq_len])
         def _(i, j, k):
@@ -3113,7 +3069,6 @@ def _(i, j):
             @for_range_opt([self.seq_len])
             def _(k):
                 value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
-                # value_sub[k] = self.wv.Y[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size]
 
             res = sfix.Matrix(self.seq_len, self.attention_head_size)
             res.assign_vector(self.dropout.Y[i][j].direct_mul(value_sub))
@@ -3124,13 +3079,7 @@ def _(k):
                 self.context[i][k].assign_part_vector(res[k],
                                                       j * self.attention_head_size
                                                       )
-            # for k in range(self.seq_len):
-            #     self.context[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size] = res[k * self.attention_head_size:(k + 1) * self.attention_head_size]
-
-            # How to transfer to forward?
 
-        # missing half of the values ?
-        # print_ln('forward layer old_context %s', self.old_context[0].get_vector().reveal())
         if self.debug_output:
             print_ln('forward layer multiheadattention before internal output %s', self.context[0][0][0:20].get_vector().reveal())
 
@@ -3142,8 +3091,6 @@ def _(k):
             print_ln('forward multiheadattention output %s', self.output.Y[0][0][0:20].reveal())
             print_ln("")
 
-        # return context
-
     def reset(self):
         self.wq.reset()
         self.wk.reset()
@@ -3199,8 +3146,6 @@ def _(k):
                     nabla_value_sub[k],
                     j * self.attention_head_size)
 
-        print("RES MULTI BACK", self.dropout.Y, res, self.num_attention_heads, self.attention_head_size)
-
         self.dropout.nabla_X.alloc()
         self.dropout.backward(True, batch)
 
@@ -4413,10 +4358,20 @@ def process(item, inputs, input_shape, args, kwargs={}):
             raise CompilerError('multi-input layer %s not supported' % item)
         name = type(item).__name__
         if name == 'Linear':
-            assert mul(input_shape[1:]) == item.in_features
+            # Precondition: the item
             assert item.bias is not None
-            layers.append(Dense(input_shape[0], item.in_features,
-                                item.out_features))
+            if mul(input_shape[1:]) == item.in_features:
+                layers.append(Dense(input_shape[0], item.in_features,
+                                    item.out_features))
+            elif input_shape[-1] == item.in_features:
+                # we loop over all but last dimension
+                assert len(input_shape) == 3, "Dense only supports one extra dimension to loop over"
+                d = input_shape[1]
+                layers.append(Dense(input_shape[0], item.in_features,
+                                    item.out_features, d))
+            else:
+                assert False, f"input shape {input_shape} incompatible with in_features {item.in_features}"
+
             if input_via is not None:
                 shapes = [x.shape for x in (layers[-1].W, layers[-1].b)]
                 import numpy
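Note: the new branch lets a torch Linear that acts position-wise on a [N, seq_len, features] input map to a Dense layer that loops over the sequence dimension. A hedged sketch of just the shape dispatch (the helper name is illustrative, not part of the code above):

    from math import prod

    def dense_args_for_linear(input_shape, in_features, out_features):
        if prod(input_shape[1:]) == in_features:      # flat input, e.g. [N, 768]
            return (input_shape[0], in_features, out_features)
        if input_shape[-1] == in_features:            # per-position, e.g. [N, seq, 768]
            assert len(input_shape) == 3
            return (input_shape[0], in_features, out_features, input_shape[1])
        raise ValueError(f"{input_shape} incompatible with in_features={in_features}")

    print(dense_args_for_linear([8, 768], 768, 3072))       # (8, 768, 3072)
    print(dense_args_for_linear([8, 128, 768], 768, 3072))  # (8, 768, 3072, 128)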
@@ -4487,6 +4442,15 @@ def process(item, inputs, input_shape, args, kwargs={}):
                 input_shape = layers[-1].shape
         elif name == 'ReLU' or item == torch.nn.functional.relu:
             layers.append(Relu(input_shape))
+        elif name == 'GeLU' or item == torch.nn.functional.gelu:
+            layers.append(Gelu(input_shape))
+        elif name == 'LayerNorm':
+            layers.append(LayerNorm(input_shape, True, item.eps))
+            if input_via is not None:
+                layers[-1].weights = sfix.input_tensor_via(
+                    input_via, item.weight.detach())
+                layers[-1].beta = sfix.input_tensor_via(
+                    input_via, item.bias.detach())
         elif name == 'Flatten':
             return
         elif name == 'BatchNorm2d' or name == 'BatchNorm1d':
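Note: with these branches, a torch model containing LayerNorm (and gelu via torch.nn.functional.gelu) can be imported like the existing layer types. A hedged usage sketch, assuming the usual layers_from_torch entry point and that this runs inside an .mpc program; the exact signature may differ between versions:

    import torch.nn as nn
    from Compiler import ml

    net = nn.Sequential(
        nn.Linear(768, 3072),
        nn.Linear(3072, 768),
        nn.LayerNorm(768, eps=1e-12),   # mapped to LayerNorm(input_shape, True, eps)
    )

    # input shape [n_examples, seq_len, hidden]; the Linear layers take the
    # per-position branch shown in the previous hunk
    layers = ml.layers_from_torch(net, [8, 128, 768], 8, input_via=0)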
@@ -4539,7 +4503,7 @@ def process(item, inputs, input_shape, args, kwargs={}):
             num_attention_heads = config.num_attention_heads
             layernorm_eps = config.layer_norm_eps
             seq_len = input_shape[1]
-            rsqrt_approx = False
+            rsqrt_approx = True
             layer = BertLayer(input_shape[0], seq_len, hidden_state, intermediate_size, num_attention_heads,
                               layernorm_eps, 0.125, rsqrt_approx, batch_size=batch_size)
             if input_via is not None:

Compiler/types.py

Lines changed: 5 additions & 2 deletions
@@ -7029,7 +7029,10 @@ def __add__(self, other):
         :return: container of same shape and type as :py:obj:`self` """
         if is_zero(other):
             return self
-        assert self.sizes == other.sizes
+        if hasattr(other, 'sizes'):
+            assert self.sizes == other.sizes
+        if hasattr(other, 'size'):
+            assert self.total_size() == other.size
         return self.from_vector(
             self.sizes, self.get_vector() + other.get_vector())
 
@@ -7604,7 +7607,7 @@ def reveal_to_binary_output(self, player=None):
     def __str__(self):
         return '%s multi-array of lengths %s at %s' % (
             self.value_type, self.sizes,
-            '<unallocated>' if self.array._address is None else self.address)
+            '<unallocated>' if self.address is None else self.address)
     __repr__ = __str__
 
 class MultiArray(SubMultiArray):
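Note on the __add__ change above: the relaxed check lets a multi-array be added to an operand that only exposes a matching total size (such as a flat vector), not just another multi-array of identical shape. A hedged sketch of the intended use inside an .mpc program (shapes illustrative; assumes the right-hand operand supports get_vector, as the following line of __add__ requires):

    A = sfix.Tensor([4, 8])
    B = sfix.Tensor([2, 16])       # different shape, same 32 elements
    A.assign_all(1)
    B.assign_all(2)
    C = A + B.get_vector()         # old code required B.sizes == A.sizes;
                                   # now only the total size has to match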
