update return dict.
modeling_mpt.py CHANGED (+12 -2)
@@ -134,8 +134,8 @@ class MPTModel(MPTPreTrainedModel):
             attention_mask = attention_mask.bool()
         if prefix_mask is not None:
             prefix_mask = prefix_mask.bool()
-        if not return_dict:
-            raise NotImplementedError('return_dict False is not implemented yet for MPT')
+        # if not return_dict:
+        #     raise NotImplementedError('return_dict False is not implemented yet for MPT')
         if output_attentions:
             raise NotImplementedError('output_attentions is not implemented yet for MPT')
         if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
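Note: commenting out this guard is what lets return_dict=False fall through to the tuple-returning branch added in the next hunk; before this change the call raised NotImplementedError('return_dict False is not implemented yet for MPT') at this point.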
@@ -184,6 +184,9 @@ class MPTModel(MPTPreTrainedModel):
             if past_key_values is not None:
                 past_key_values[b_idx] = past_key_value
         x = self.norm_f(x)
+        if not return_dict:
+            output = (x,) + (tuple(past_key_values),)
+            return output
         return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
 
     def param_init_fn(self, module):
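With the guard above disabled, MPTModel.forward returns a plain 2-tuple on the return_dict=False path: (last_hidden_state, tuple(past_key_values)). As written, tuple(past_key_values) requires the cache to be iterable (not None). A toy sketch of the tuple layout, using placeholder values rather than real tensors (the names below are illustrative only, not part of the diff):

# Illustration of the value built by the new `if not return_dict:` branch.
x = 'LAST_HIDDEN_STATE'                           # stands in for the final hidden-state tensor
past_key_values = [('k0', 'v0'), ('k1', 'v1')]    # stands in for per-layer (key, value) caches
output = (x,) + (tuple(past_key_values),)
# output == ('LAST_HIDDEN_STATE', (('k0', 'v0'), ('k1', 'v1')))
(last_hidden_state, past) = output                # unpacks as exactly two items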
@@ -234,6 +237,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
     def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        past_key_values = list(past_key_values) if past_key_values is not None else None
+
         outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
         logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
         if self.logit_scale is not None:
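The list() coercion matters because MPTModel.forward writes the per-layer cache back in place (past_key_values[b_idx] = past_key_value, see the hunk above), and item assignment fails if the cache arrives as a tuple, for example one produced by the non-dict branch. A minimal standalone illustration (toy values, not the real cache):

# Why the cache is coerced to a list before the transformer call.
cache = (('k0', 'v0'), ('k1', 'v1'))              # cache handed back as a tuple
try:
    cache[0] = ('k0_new', 'v0_new')               # tuples do not support item assignment
except TypeError as err:
    print(err)
cache = list(cache) if cache is not None else None
cache[0] = ('k0_new', 'v0_new')                   # fine once it is a list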
@@ -245,6 +251,10 @@ class MPTForCausalLM(MPTPreTrainedModel):
             labels = torch.roll(labels, shifts=-1)
             labels[:, -1] = -100
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
+
+        if not return_dict:
+            output = (logits,) + (tuple(outputs[1]),)
+            return (loss,) + output if loss is not None else output
         return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
 
     def param_init_fn(self, module):
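With labels, the non-dict return from MPTForCausalLM.forward is (loss, logits, past_key_values); without labels it is (logits, past_key_values), since the conditional expression binds looser than +, i.e. the last added line parses as ((loss,) + output) if loss is not None else output. A toy sketch with placeholder values (illustrative only, not part of the diff):

# How the conditional tuple is assembled by the new branch.
logits = 'LOGITS'
past = (('k0', 'v0'),)
output = (logits,) + (tuple(past),)

loss = None
result = (loss,) + output if loss is not None else output
# result == ('LOGITS', (('k0', 'v0'),))           # no labels -> no loss element

loss = 0.123
result = (loss,) + output if loss is not None else output
# result == (0.123, 'LOGITS', (('k0', 'v0'),))    # labels given -> loss prepended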