zhihan1996
commited on
Commit
•
25abaf0
1
Parent(s):
1d020b8
Update bert_layers.py
Browse files- bert_layers.py +1 -1
bert_layers.py
CHANGED
@@ -413,7 +413,7 @@ class BertEncoder(nn.Module):
|
|
413 |
|
414 |
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
415 |
extended_attention_mask = extended_attention_mask.to(
|
416 |
-
dtype=
|
417 |
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
418 |
|
419 |
attention_mask_bool = attention_mask.bool()
|
|
|
413 |
|
414 |
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
415 |
extended_attention_mask = extended_attention_mask.to(
|
416 |
+
dtype=torch.float32) # fp16 compatibility
|
417 |
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
418 |
|
419 |
attention_mask_bool = attention_mask.bool()
|