BAAI
/

BoyaWu10 commited on
Commit
5f26851
1 Parent(s): 540de87

Update modeling_bunny_phi3.py

Browse files
Files changed (1) hide show
  1. modeling_bunny_phi3.py +6 -0
modeling_bunny_phi3.py CHANGED
@@ -878,11 +878,17 @@ class BunnyMetaForCausalLM(ABC):
878
  if labels is None:
879
  labels = torch.full_like(input_ids, IGNORE_INDEX)
880
 
 
 
881
  # remove the padding using attention_mask -- TODO: double check
882
  input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
883
  zip(input_ids, attention_mask)]
884
  labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
885
 
 
 
 
 
886
  new_input_embeds = []
887
  new_labels = []
888
  cur_image_idx = 0
 
878
  if labels is None:
879
  labels = torch.full_like(input_ids, IGNORE_INDEX)
880
 
881
+ input_ids_temp = input_ids # points to the actual input_ids tensor
882
+
883
  # remove the padding using attention_mask -- TODO: double check
884
  input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
885
  zip(input_ids, attention_mask)]
886
  labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
887
 
888
+ # -- TODO: better implementation?
889
+ # replace IMAGE_TOKEN_INDEX(-200) with 0 to be compatible with repetition penalty
890
+ input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0
891
+
892
  new_input_embeds = []
893
  new_labels = []
894
  cur_image_idx = 0