zhihan1996 commited on
Commit
25abaf0
1 Parent(s): 1d020b8

Update bert_layers.py

Browse files
Files changed (1) hide show
  1. bert_layers.py +1 -1
bert_layers.py CHANGED
@@ -413,7 +413,7 @@ class BertEncoder(nn.Module):
413
 
414
  extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
415
  extended_attention_mask = extended_attention_mask.to(
416
- dtype=next(self.parameters()).dtype) # fp16 compatibility
417
  extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
418
 
419
  attention_mask_bool = attention_mask.bool()
 
413
 
414
  extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
415
  extended_attention_mask = extended_attention_mask.to(
416
+ dtype=torch.float32) # fp16 compatibility
417
  extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
418
 
419
  attention_mask_bool = attention_mask.bool()