TCMVince commited on
Commit
20587cc
1 Parent(s): 93645f7

Update flaubert2_model.py

Browse files
Files changed (1) hide show
  1. flaubert2_model.py +4 -0
flaubert2_model.py CHANGED
@@ -390,6 +390,10 @@ class Flaubert2Model(RobertaModel):
390
 
391
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
392
 
 
 
 
 
393
  if not return_dict:
394
  return (sequence_output, pooled_output) + encoder_outputs[1:]
395
 
 
390
 
391
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
392
 
393
+ # Fairseq Linformer implementation works with transposed hidden states -> we transpose them back for HF implementation.
394
+ if output_hidden_states:
395
+ encoder_outputs.hidden_states = [h.transpose(0,1) for h in encoder_outputs.hidden_states]
396
+
397
  if not return_dict:
398
  return (sequence_output, pooled_output) + encoder_outputs[1:]
399