jupyterjazz committed on
Commit
3eceb33
1 Parent(s): e860caa

Update modeling_xlm_roberta.py

Browse files
Files changed (1) hide show
  1. modeling_xlm_roberta.py +4 -6
modeling_xlm_roberta.py CHANGED
@@ -210,12 +210,10 @@ class XLMRobertaEncoder(nn.Module):
210
  subset_mask: (batch, seqlen), dtype=torch.bool
211
  """
212
  if key_padding_mask is None or not self.use_flash_attn:
213
- mixer_kwargs = (
214
- {"key_padding_mask": key_padding_mask.bool()}
215
- if key_padding_mask is not None
216
- else None
217
- )
218
- mixer_kwargs['task_type'] = task_type
219
  for layer in self.layers:
220
  if self._grad_checkpointing:
221
  hidden_states = torch.utils.checkpoint.checkpoint(
 
210
  subset_mask: (batch, seqlen), dtype=torch.bool
211
  """
212
  if key_padding_mask is None or not self.use_flash_attn:
213
+ mixer_kwargs = {'task_type': task_type}
214
+ if key_padding_mask is not None:
215
+ mixer_kwargs['key_padding_mask'] = key_padding_mask.bool()
216
+
 
 
217
  for layer in self.layers:
218
  if self._grad_checkpointing:
219
  hidden_states = torch.utils.checkpoint.checkpoint(