jupyterjazz committed on
Commit
bda5e8d
1 Parent(s): a709b51

fix mixed precision loading with recent transformers versions (#39)

Browse files

- fix mixed precision loading with recent transformers versions (553cf70617db7379e93ec5a92ba8d75ebb3cda66)

Files changed (1) hide show
  1. modeling_xlm_roberta.py +1 -0
modeling_xlm_roberta.py CHANGED
@@ -404,6 +404,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
404
  config_class = XLMRobertaFlashConfig
405
  base_model_prefix = "roberta"
406
  supports_gradient_checkpointing = True
 
407
 
408
  def _set_gradient_checkpointing(self, module, value=False):
409
  if isinstance(module, XLMRobertaEncoder):
 
404
  config_class = XLMRobertaFlashConfig
405
  base_model_prefix = "roberta"
406
  supports_gradient_checkpointing = True
407
+ _supports_param_buffer_assignment = False
408
 
409
  def _set_gradient_checkpointing(self, module, value=False):
410
  if isinstance(module, XLMRobertaEncoder):