Update bert_layers.py
Browse files- bert_layers.py +2 -2
bert_layers.py
CHANGED
@@ -199,9 +199,9 @@ class BertUnpadSelfAttention(nn.Module):
|
|
199 |
|
200 |
# print(f'PROBLEM HERE: UNDERSTAND IT!!')
|
201 |
rearranged_attention = rearrange(attention, 'nnz h d -> nnz (h d)')
|
202 |
-
try:
|
203 |
# print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
|
204 |
-
except:
|
205 |
# print(f'REARRANGED ATTENTION: {rearranged_attention[0].shape}')
|
206 |
return rearrange(attention, 'nnz h d -> nnz (h d)'), attention_probs
|
207 |
|
|
|
199 |
|
200 |
# print(f'PROBLEM HERE: UNDERSTAND IT!!')
|
201 |
rearranged_attention = rearrange(attention, 'nnz h d -> nnz (h d)')
|
202 |
+
# try:
|
203 |
# print(f'REARRANGED ATTENTION: {rearranged_attention.shape}')
|
204 |
+
# except:
|
205 |
# print(f'REARRANGED ATTENTION: {rearranged_attention[0].shape}')
|
206 |
return rearrange(attention, 'nnz h d -> nnz (h d)'), attention_probs
|
207 |
|