Crystalcareai
commited on
Commit
•
b3900b9
1
Parent(s):
accf604
Update modeling_quiet.py
Browse files- modeling_quiet.py +4 -1
modeling_quiet.py
CHANGED
@@ -1233,7 +1233,10 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1233 |
# )
|
1234 |
# if labels is not None:
|
1235 |
# loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
1236 |
-
|
|
|
|
|
|
|
1237 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1238 |
output_hidden_states = (
|
1239 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
1233 |
# )
|
1234 |
# if labels is not None:
|
1235 |
# loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
1236 |
+
if input_ids.dim() == 1:
|
1237 |
+
input_ids = input_ids.unsqueeze(0)
|
1238 |
+
attention_mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
|
1239 |
+
|
1240 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1241 |
output_hidden_states = (
|
1242 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|