goldenteethCN
committed on
Commit 34baeaa
Parent(s): 1d240ba
fix: truncate finished output in stream_generate
[Fix] Fix the issue where, when a batch contains multiple inputs whose answers differ in length, the model did not truncate the finished_sequences.
- modeling_chatglm.py +1 -0
modeling_chatglm.py
CHANGED
@@ -1404,6 +1404,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
             else:
                 next_tokens = torch.argmax(probs, dim=-1)
+            next_tokens = torch.where(unfinished_sequences.bool(), next_tokens, eos_token_id[0])
 
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
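
For reference, a minimal sketch of what the added line does during one batched decode step. The tensor values below are hypothetical and not taken from the repository; only the torch.where call mirrors the commit.

import torch

# One decode step for a hypothetical batch of 3 sequences.
eos_token_id = torch.tensor([2])                # assumed EOS token id
unfinished_sequences = torch.tensor([1, 0, 1])  # row 1 has already finished
next_tokens = torch.tensor([512, 774, 91])      # tokens sampled this step

# The fix: finished rows emit EOS instead of freshly sampled tokens,
# so shorter answers in the batch are effectively truncated.
next_tokens = torch.where(unfinished_sequences.bool(), next_tokens, eos_token_id[0])
print(next_tokens)  # tensor([512,   2,  91])

Without this masking, a row that reached EOS early would keep appending sampled tokens on every step, and the streamed output for that input would contain stray tokens past its real answer.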