zR committed
Commit • 30c2bc6
Parent(s): 72c9149
new
modeling_glm.py +10 -18
modeling_glm.py CHANGED
@@ -763,32 +763,24 @@ class GlmModel(GlmPreTrainedModel):
         assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
         inputs_embeds = self.embed_tokens(input_ids)
         new_input_embeds = []
-
-
-
-
+        boi_token_flags = [True if self.config.boi_token_id in input_id.tolist() else False for input_id in input_ids]
+        if is_empty(images):
+            images = torch.zeros([1, 3, 672, 672]).to(input_ids.device)
+        images_features = self.vision(images).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
         image_count = 0
         for i in range(len(input_ids)):
             input_id = input_ids[i].tolist()
-            if
+            if boi_token_flags[i]:
                 boi_token_pos = input_id.index(self.config.boi_token_id)
                 assert boi_token_pos >= 0, "begin_of_image not found!"
                 num_image_padding_tokens = input_id.count(self.config.boi_token_id)
-                assert
-
-
-
-                torch.cat(
-                    (
-                        inputs_embeds[i, :boi_token_pos],
-                        images_features[image_count].to(inputs_embeds.device),
-                        inputs_embeds[i, boi_token_pos + num_image_padding_tokens :],
-                    )
-                )
-            )
+                assert num_image_padding_tokens == images_features[image_count].shape[0], f"Wrong image padding token number: {num_image_padding_tokens}"
+                new_input_embeds.append(torch.cat(
+                    (inputs_embeds[i, :boi_token_pos], images_features[image_count],
+                     inputs_embeds[i, boi_token_pos + num_image_padding_tokens:])))
                 image_count += 1
             else:
-                new_input_embeds.append(inputs_embeds[i])
+                new_input_embeds.append(inputs_embeds[i] + (0 * images_features[0].sum()))
         inputs_embeds = torch.stack(new_input_embeds, dim=0)
 
         if self.gradient_checkpointing and self.training and use_cache:
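
For readers skimming the diff: the loop replaces the run of boi_token_id image-padding tokens in each sequence with the corresponding vision features, keeping the sequence length unchanged. Below is a minimal self-contained sketch of that splice; the sizes, token ids, and the random tensor standing in for self.vision(images) are all illustrative assumptions, not values from the repo.

import torch

hidden_size = 8                            # illustrative, not the model's size
vocab_size = 16
boi_token_id = 3                           # hypothetical begin_of_image token id

embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)

input_id = [5, 7, 3, 3, 3, 3, 9]           # text, 4 image-padding slots, text
input_ids = torch.tensor([input_id])
inputs_embeds = embed_tokens(input_ids)    # shape (1, 7, hidden_size)

image_features = torch.randn(4, hidden_size)   # stand-in for self.vision(images)[0]

boi_token_pos = input_id.index(boi_token_id)             # first padding slot (2)
num_image_padding_tokens = input_id.count(boi_token_id)  # length of the run (4)
assert num_image_padding_tokens == image_features.shape[0]

# Same concatenation as the diff: text prefix, vision features, text suffix.
spliced = torch.cat((
    inputs_embeds[0, :boi_token_pos],
    image_features,
    inputs_embeds[0, boi_token_pos + num_image_padding_tokens:],
))
assert spliced.shape == inputs_embeds[0].shape  # sequence length preserved

Note that input_id.index finds only the first occurrence, so this relies on the image-padding tokens forming a single contiguous run, which is also why the concatenated row comes out the same length and torch.stack can rebuild the batch afterwards.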
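The two other additions look aimed at distributed training: self.vision now always runs (on a zero dummy image when the batch has no real images), and text-only rows add (0 * images_features[0].sum()), which leaves the values untouched but keeps the vision tower inside the autograd graph. Without this, a rank with a text-only batch would produce no gradients for the vision parameters, a common way to trip DistributedDataParallel's gradient synchronization. A minimal sketch of the mechanism, with a single Linear layer as a hypothetical stand-in for the vision tower:

import torch

vision = torch.nn.Linear(4, 4)            # hypothetical stand-in for self.vision
text_embeds = torch.randn(3, 4, requires_grad=True)

dummy = torch.zeros(1, 4)                 # analogous to torch.zeros([1, 3, 672, 672])
images_features = vision(dummy)

# Numerically a no-op, but the output now depends on the vision parameters.
out = text_embeds + (0 * images_features[0].sum())
out.sum().backward()

assert torch.equal(out, text_embeds)      # values are unchanged
assert vision.weight.grad is not None     # vision still receives (zero) gradients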