
davda54 committed
Commit f8e6164
1 Parent(s): ca3c03c

Update modeling_deberta.py

Files changed (1):
  1. modeling_deberta.py +2 -3
modeling_deberta.py CHANGED

@@ -455,9 +455,6 @@ class DebertaV2Encoder(nn.Module):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.triu(diagonal=-510).tril(diagonal=510)
-            attention_mask[:, :, :, 0] = 1
-            attention_mask[:, :, :, -1] = 1
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)
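The three removed lines had restricted self-attention to a local band: triu(diagonal=-510).tril(diagonal=510) zeroes every mask entry more than 510 positions off the diagonal, and the two index assignments then re-enable attention to the first and last position for every query token. A minimal standalone sketch of that (now removed) windowing, with made-up batch size and sequence length and an all-ones mask standing in for a real padding mask:

import torch

# Hypothetical shapes; the real mask is built inside DebertaV2Encoder.
batch_size, seq_len = 1, 2048
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

extended = attention_mask.unsqueeze(1).unsqueeze(2)       # (B, 1, 1, L)
mask = extended * extended.squeeze(-2).unsqueeze(-1)      # (B, 1, L, L)

# Behaviour removed by this commit: zero every entry more than 510
# positions off the diagonal, so token i may only attend to tokens j
# with |i - j| <= 510 ...
windowed = mask.triu(diagonal=-510).tril(diagonal=510)
# ... while the first and last positions stay visible to every token.
windowed[:, :, :, 0] = 1
windowed[:, :, :, -1] = 1

print(windowed[0, 0, 1024, 400].item())   # 0: outside the local window
print(windowed[0, 0, 1024, 1400].item())  # 1: inside the local window
print(windowed[0, 0, 1024, 0].item())     # 1: first token stays visible

With the lines removed, the encoder falls back to full bidirectional attention over the whole sequence.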
@@ -868,6 +865,8 @@ class DebertaV2Embeddings(nn.Module):
                 ], dim=1)
             else:
                 position_ids = self.position_ids[:, :seq_length]
+        elif position_ids.size(1) > self.position_ids.size(1):
+            position_ids = (position_ids + self.position_ids.size(1) - position_ids.size(1)).clamp(min=0)
 
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
 
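The two added lines handle caller-supplied position_ids that run past the registered position table: every id is shifted down by the overflow so the tail of the sequence still maps to valid positions, and the head ids pushed below zero are clamped to 0. A small sketch of that shift-and-clamp, assuming a hypothetical 512-entry table and a 600-token input:

import torch

# Hypothetical sizes: a 512-entry table standing in for
# self.position_ids, and user-supplied ids for a 600-token input.
max_positions = 512
registered_position_ids = torch.arange(max_positions).unsqueeze(0)  # (1, 512)
position_ids = torch.arange(600).unsqueeze(0)                       # (1, 600)

if position_ids.size(1) > registered_position_ids.size(1):
    # Shift every id by the (negative) overflow 512 - 600 = -88 so the
    # end of the sequence lands on valid positions, then clamp the
    # head ids that fell below zero.
    shift = registered_position_ids.size(1) - position_ids.size(1)
    position_ids = (position_ids + shift).clamp(min=0)

print(position_ids[0, :3].tolist())   # [0, 0, 0]        -- clamped head
print(position_ids[0, -3:].tolist())  # [509, 510, 511]  -- valid tail

The trade-off is that the earliest overflow tokens all collapse onto position 0, while the most recent max_positions tokens keep distinct, in-range position ids.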