Update modeling_mpt.py

#55
by daking - opened
Files changed (1) hide show
  1. modeling_mpt.py +1 -1
modeling_mpt.py CHANGED
@@ -181,7 +181,7 @@ class MPTModel(MPTPreTrainedModel):
181
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
182
  assert isinstance(self.emb_drop, nn.Module)
183
  x = self.emb_drop(x_shrunk)
184
- (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
185
  if use_cache and past_key_values is None:
186
  past_key_values = [() for _ in range(self.config.n_layers)]
187
  all_hidden_states = () if output_hidden_states else None
 
181
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
182
  assert isinstance(self.emb_drop, nn.Module)
183
  x = self.emb_drop(x_shrunk)
184
+ (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
185
  if use_cache and past_key_values is None:
186
  past_key_values = [() for _ in range(self.config.n_layers)]
187
  all_hidden_states = () if output_hidden_states else None