zR committed on
Commit
30c2bc6
1 Parent(s): 72c9149
Files changed (1) hide show
  1. modeling_glm.py +10 -18
modeling_glm.py CHANGED
@@ -763,32 +763,24 @@ class GlmModel(GlmPreTrainedModel):
763
  assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
764
  inputs_embeds = self.embed_tokens(input_ids)
765
  new_input_embeds = []
766
- multi_flags = [True if self.config.boi_token_id in input_id.tolist() else False for input_id in input_ids]
767
- images_features = None
768
- if not is_empty(images):
769
- images_features = self.vision(images).to(inputs_embeds.dtype)
770
  image_count = 0
771
  for i in range(len(input_ids)):
772
  input_id = input_ids[i].tolist()
773
- if multi_flags[i]:
774
  boi_token_pos = input_id.index(self.config.boi_token_id)
775
  assert boi_token_pos >= 0, "begin_of_image not found!"
776
  num_image_padding_tokens = input_id.count(self.config.boi_token_id)
777
- assert (
778
- num_image_padding_tokens == images_features[image_count].shape[0]
779
- ), f"Wrong image padding token number: {num_image_padding_tokens}"
780
- new_input_embeds.append(
781
- torch.cat(
782
- (
783
- inputs_embeds[i, :boi_token_pos],
784
- images_features[image_count].to(inputs_embeds.device),
785
- inputs_embeds[i, boi_token_pos + num_image_padding_tokens :],
786
- )
787
- )
788
- )
789
  image_count += 1
790
  else:
791
- new_input_embeds.append(inputs_embeds[i])
792
  inputs_embeds = torch.stack(new_input_embeds, dim=0)
793
 
794
  if self.gradient_checkpointing and self.training and use_cache:
 
763
  assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
764
  inputs_embeds = self.embed_tokens(input_ids)
765
  new_input_embeds = []
766
+ boi_token_flags = [True if self.config.boi_token_id in input_id.tolist() else False for input_id in input_ids]
767
+ if is_empty(images):
768
+ images = torch.zeros([1, 3, 672, 672]).to(input_ids.device)
769
+ images_features = self.vision(images).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
770
  image_count = 0
771
  for i in range(len(input_ids)):
772
  input_id = input_ids[i].tolist()
773
+ if boi_token_flags[i]:
774
  boi_token_pos = input_id.index(self.config.boi_token_id)
775
  assert boi_token_pos >= 0, "begin_of_image not found!"
776
  num_image_padding_tokens = input_id.count(self.config.boi_token_id)
777
+ assert num_image_padding_tokens == images_features[image_count].shape[0], f"Wrong image padding token number: {num_image_padding_tokens}"
778
+ new_input_embeds.append(torch.cat(
779
+ (inputs_embeds[i, :boi_token_pos], images_features[image_count],
780
+ inputs_embeds[i, boi_token_pos + num_image_padding_tokens:])))
 
 
 
 
 
 
 
 
781
  image_count += 1
782
  else:
783
+ new_input_embeds.append(inputs_embeds[i] + (0 * images_features[0].sum()))
784
  inputs_embeds = torch.stack(new_input_embeds, dim=0)
785
 
786
  if self.gradient_checkpointing and self.training and use_cache: