Update modeling_bunny_phi3.py
Browse files - modeling_bunny_phi3.py +6 -0
modeling_bunny_phi3.py
CHANGED
@@ -878,11 +878,17 @@ class BunnyMetaForCausalLM(ABC):
|
|
878 |
if labels is None:
|
879 |
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
880 |
|
|
|
|
|
881 |
# remove the padding using attention_mask -- TODO: double check
|
882 |
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
|
883 |
zip(input_ids, attention_mask)]
|
884 |
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
885 |
|
|
|
|
|
|
|
|
|
886 |
new_input_embeds = []
|
887 |
new_labels = []
|
888 |
cur_image_idx = 0
|
|
|
878 |
if labels is None:
|
879 |
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
880 |
|
881 |
+
input_ids_temp = input_ids # points to the actual input_ids tensor
|
882 |
+
|
883 |
# remove the padding using attention_mask -- TODO: double check
|
884 |
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
|
885 |
zip(input_ids, attention_mask)]
|
886 |
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
887 |
|
888 |
+
# -- TODO: better implementation?
|
889 |
+
# replace IMAGE_TOKEN_INDEX (-200) with 0 in place, via the input_ids_temp alias, so that the original input_ids tensor is compatible with the repetition penalty
|
890 |
+
input_ids_temp[input_ids_temp == IMAGE_TOKEN_INDEX] = 0
|
891 |
+
|
892 |
new_input_embeds = []
|
893 |
new_labels = []
|
894 |
cur_image_idx = 0
|