Update modeling_deberta.py
modeling_deberta.py  CHANGED  (+2 -3)
@@ -455,9 +455,6 @@ class DebertaV2Encoder(nn.Module):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.triu(diagonal=-510).tril(diagonal=510)
-            attention_mask[:, :, :, 0] = 1
-            attention_mask[:, :, :, -1] = 1
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)
 
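For reference, the three removed lines constrained the pairwise mask to a sliding window of ±510 positions around the diagonal and then re-enabled the first and last key positions for every query. A minimal standalone sketch of that effect, using a toy half-width of 2 in place of 510 (all sizes here are hypothetical, not from the commit):

import torch

batch, seq, window = 1, 8, 2  # toy sizes; the removed code used window = 510

# A (batch, 1, seq, seq) pairwise mask with every position attendable.
mask = torch.ones(batch, 1, seq, seq, dtype=torch.long)

# triu(-w) zeroes entries below the band, tril(w) zeroes entries above it,
# leaving a diagonal band of half-width `window`.
banded = mask.triu(diagonal=-window).tril(diagonal=window)

# Re-open the first and last key columns so every query can still attend
# to the first and last tokens ("global" positions).
banded[:, :, :, 0] = 1
banded[:, :, :, -1] = 1

print(banded[0, 0])  # banded matrix whose first and last columns are all ones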
@@ -868,6 +865,8 @@ class DebertaV2Embeddings(nn.Module):
                 ], dim=1)
             else:
                 position_ids = self.position_ids[:, :seq_length]
+        elif position_ids.size(1) > self.position_ids.size(1):
+            position_ids = (position_ids + self.position_ids.size(1) - position_ids.size(1)).clamp(min=0)
 
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
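The added branch covers caller-supplied position_ids that are longer than the registered position_ids buffer; as indented above, the dedented elif pairs with an outer check such as `if position_ids is None:`, not with the nested else directly before it, which is how the hunk stays syntactically valid. The new code shifts every id down by the overflow so the tail of the sequence lands on valid positions, saturating the head at 0. A small illustration with a hypothetical buffer length of 512 and an input of 516 tokens:

import torch

max_positions = 512  # hypothetical length of the registered position_ids buffer
seq_length = 516     # caller passes a longer sequence

position_ids = torch.arange(seq_length).unsqueeze(0)  # shape (1, 516): ids 0..515

# Shift down by the overflow (512 - 516 = -4) and clamp at 0: the last 512
# tokens keep ascending positions 0..511, the first few all collapse to 0.
shifted = (position_ids + max_positions - position_ids.size(1)).clamp(min=0)

print(shifted[0, :8])   # tensor([0, 0, 0, 0, 0, 1, 2, 3])
print(shifted[0, -1])   # tensor(511)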