zR committed • Commit 1127073 • Parent(s): d907213

finetune

Files changed: modeling_chatglm.py (+4 −2)
modeling_chatglm.py CHANGED

@@ -884,6 +884,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         batch_size, seq_length = input_ids.shape
 
+        if inputs_embeds is None:
+            inputs_embeds = self.embedding(input_ids)
+
         if self.pre_seq_len is not None:
             if past_key_values is None:
                 past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
@@ -912,9 +915,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
             attention_mask = torch.stack(new_attention_mask, dim=0)
             input_ids = torch.stack(new_input_ids, dim=0)
+            inputs_embeds = self.embedding(input_ids)
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embedding(input_ids)
         full_attention_mask = self.get_masks(inputs_embeds, past_key_values, padding_mask=attention_mask)
 
         # Rotary positional embeddings
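The net effect of the commit: `inputs_embeds` is computed up front in `forward`, and recomputed whenever the fine-tuning path trims padding and re-stacks `input_ids`, so the embedding tensor always matches the current token tensor and attention mask. Below is a minimal, self-contained sketch of that recompute step, assuming a toy vocabulary, hidden size, and pad id of 0 (all hypothetical, not taken from the model's config); it illustrates the pattern, not the model's actual forward code.

```python
# Sketch: why embeddings must be looked up again after trimming + re-stacking.
import torch
import torch.nn as nn

embedding = nn.Embedding(100, 8)                          # toy vocab / hidden size
input_ids = torch.tensor([[5, 6, 7, 0], [1, 2, 0, 0]])    # 0 = padding id (assumed)
attention_mask = (input_ids != 0).long()

inputs_embeds = embedding(input_ids)                      # first lookup, shape (2, 4, 8)

# Trim trailing padding shared by the whole batch, then re-stack,
# mirroring the torch.stack(new_..., dim=0) calls in the diff.
max_len = int(attention_mask.sum(dim=-1).max())
new_input_ids = [ids[:max_len] for ids in input_ids]
new_attention_mask = [m[:max_len] for m in attention_mask]
attention_mask = torch.stack(new_attention_mask, dim=0)
input_ids = torch.stack(new_input_ids, dim=0)             # shape (2, 3)

# Without this second lookup, inputs_embeds would keep its stale (2, 4, 8)
# shape and disagree with the trimmed attention mask downstream.
inputs_embeds = embedding(input_ids)                      # shape (2, 3, 8)
```

Computing `inputs_embeds` before the `pre_seq_len` block additionally guarantees the embeddings exist for the prefix-tuning prompt and mask construction, instead of being created only later, right before `get_masks`, as in the old code.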