Fix batch beam search
modeling_glm.py  +10 -0
CHANGED
@@ -873,6 +873,16 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
         position_ids = position_ids[:, :, :seq_length]
         if attention_mask is not None:
             attention_mask = attention_mask[:, :, :seq_length, :seq_length]
+        if position_ids is not None and input_ids.size(0) > position_ids.size(0):
+            batch_size = position_ids.size(0)
+            num_beams = input_ids.size(0) // batch_size
+            position_ids = position_ids.unsqueeze(1).expand(-1, num_beams, -1, -1)
+            position_ids = position_ids.reshape(batch_size * num_beams, *position_ids.shape[-2:])
+        if attention_mask is not None and input_ids.size(0) > attention_mask.size(0):
+            batch_size = attention_mask.size(0)
+            num_beams = input_ids.size(0) // batch_size
+            attention_mask = attention_mask.unsqueeze(1).expand(-1, num_beams, -1, -1, -1)
+            attention_mask = attention_mask.reshape(batch_size * num_beams, *attention_mask.shape[-3:])
         return {
             "input_ids": input_ids,
             "position_ids": position_ids,
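For context, a minimal sketch of the tiling pattern the added lines apply (all tensor sizes below are hypothetical, chosen only for illustration): during beam search, input_ids arrives at prepare_inputs_for_generation already expanded to batch_size * num_beams rows, while position_ids (GLM's 2-D positions, shape [batch, 2, seq]) can still carry the original batch dimension, so it is repeated per beam before the forward pass.

import torch

# Hypothetical sizes for illustration only.
batch_size, num_beams, seq_length = 2, 4, 5

# input_ids as seen during beam search: already tiled to batch_size * num_beams rows.
input_ids = torch.zeros(batch_size * num_beams, seq_length, dtype=torch.long)

# position_ids still at the original batch size (GLM uses 2-D positions).
position_ids = torch.zeros(batch_size, 2, seq_length, dtype=torch.long)

if input_ids.size(0) > position_ids.size(0):
    beams = input_ids.size(0) // position_ids.size(0)                  # 4
    position_ids = position_ids.unsqueeze(1)                           # [2, 1, 2, 5]
    position_ids = position_ids.expand(-1, beams, -1, -1)              # [2, 4, 2, 5]
    position_ids = position_ids.reshape(-1, *position_ids.shape[-2:])  # [8, 2, 5]

assert position_ids.size(0) == input_ids.size(0)

The attention_mask branch in the patch follows the same expand-then-reshape pattern, with one extra trailing dimension because the mask is 4-D per batch entry.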