AzizBelaweid
commited on
Commit
•
d86f6e9
1
Parent(s):
6f39d18
Update modeling_pharia.py
Browse files- modeling_pharia.py +20 -1
modeling_pharia.py
CHANGED
@@ -764,9 +764,28 @@ class PhariaForCausalLM(PhariaPreTrainedModel):
|
|
764 |
|
765 |
hidden_states = outputs[0]
|
766 |
logits = self.lm_head(hidden_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
767 |
|
768 |
return CausalLMOutputWithPast(
|
769 |
-
loss=
|
770 |
logits=logits,
|
771 |
past_key_values=outputs.past_key_values,
|
772 |
hidden_states=outputs.hidden_states,
|
|
|
764 |
|
765 |
hidden_states = outputs[0]
|
766 |
logits = self.lm_head(hidden_states)
|
767 |
+
loss = 0.0
|
768 |
+
|
769 |
+
if self.training and labels is None:
|
770 |
+
raise ValueError(
|
771 |
+
"You have to specify the `labels` tensor when training the model."
|
772 |
+
)
|
773 |
+
|
774 |
+
if self.training and labels is not None:
|
775 |
+
# Shift logits and labels for causal language modeling
|
776 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
777 |
+
shift_labels = outputs['labels'][..., 1:].contiguous()
|
778 |
+
|
779 |
+
# Flatten the tokens
|
780 |
+
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
781 |
+
shift_labels = shift_labels.view(-1)
|
782 |
+
|
783 |
+
# Compute loss
|
784 |
+
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=1) # Pad token ID for Pharia is 1
|
785 |
+
loss = loss_fct(shift_logits, shift_labels)
|
786 |
|
787 |
return CausalLMOutputWithPast(
|
788 |
+
loss=loss,
|
789 |
logits=logits,
|
790 |
past_key_values=outputs.past_key_values,
|
791 |
hidden_states=outputs.hidden_states,
|