HuanjinYao commited on
Commit
ce3e773
1 Parent(s): 622623f

Upload 10 files

Browse files
config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./ckpt/vicuna-7b-v1.5",
3
+ "architectures": [
4
+ "MGMLlamaForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "bos_token_id": 1,
9
+ "eos_token_id": 2,
10
+ "freeze_mm_mlp_adapter": false,
11
+ "hidden_act": "silu",
12
+ "hidden_size": 4096,
13
+ "image_aspect_ratio": "pad",
14
+ "image_global": false,
15
+ "image_grid": 1,
16
+ "image_grid_pinpoints": null,
17
+ "image_size_aux": 768,
18
+ "initializer_range": 0.02,
19
+ "intermediate_size": 11008,
20
+ "max_position_embeddings": 4096,
21
+ "mm_hidden_size": 3072,
22
+ "mm_hidden_size_aux": 2880,
23
+ "mm_hidden_size_uni": 1024,
24
+ "mm_projector_lr": null,
25
+ "mm_projector_type": "mlp2x_gelu",
26
+ "mm_use_im_patch_token": false,
27
+ "mm_use_im_start_end": false,
28
+ "mm_vision_select_feature": "patch",
29
+ "mm_vision_select_layer": -2,
30
+ "mm_vision_tower": "./ckpt/clip-vit-large-patch14-336",
31
+ "mm_vision_tower_aux": "./cache_dir/openclip-convnext-large-d-320-laion2B-s29B-b131K-ft-soup",
32
+ "model_type": "mgm",
33
+ "num_attention_heads": 32,
34
+ "num_hidden_layers": 32,
35
+ "num_key_value_heads": 32,
36
+ "optimize_vision_tower": false,
37
+ "optimize_vision_tower_aux": false,
38
+ "pad_token_id": 0,
39
+ "pretraining_tp": 1,
40
+ "rms_norm_eps": 1e-05,
41
+ "rope_scaling": null,
42
+ "rope_theta": 10000.0,
43
+ "tie_word_embeddings": false,
44
+ "tokenizer_model_max_length": 2048,
45
+ "tokenizer_padding_side": "right",
46
+ "torch_dtype": "float16",
47
+ "transformers_version": "4.36.2",
48
+ "tune_mm_mlp_adapter": false,
49
+ "use_cache": true,
50
+ "use_mm_proj": true,
51
+ "vocab_size": 32000
52
+ }
mgm_arch.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ------------------------------------------------------------------------
15
+ # Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
16
+ # Copyright 2024 Yanwei Li
17
+ # ------------------------------------------------------------------------
18
+
19
+ from abc import ABC, abstractmethod
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import json
25
+ import os
26
+ import transformers
27
+ import safetensors
28
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
29
+ import deepspeed
30
+
31
+ from .multimodal_encoder.builder import build_vision_tower, build_vision_tower_aux
32
+ from .multimodal_projector.builder import build_vision_projector
33
+
34
+ from mgm.constants import (IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN,
35
+ DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN)
36
+
37
+ IS_NEW_TRANSFORMERS = transformers.__version__ >= "4.34.0"
38
+
39
+ class MGMMetaModel:
40
+
41
+ def __init__(self, config):
42
+ super(MGMMetaModel, self).__init__(config)
43
+
44
+ if hasattr(config, "mm_vision_tower"):
45
+ self.vision_tower = build_vision_tower(config, delay_load=True)
46
+ self.mm_projector = build_vision_projector(config)
47
+
48
+ if hasattr(config, "mm_vision_tower_aux"):
49
+ self.vision_tower_aux = build_vision_tower_aux(config, delay_load=True)
50
+
51
+ def get_vision_tower(self):
52
+ vision_tower = getattr(self, 'vision_tower', None)
53
+ if type(vision_tower) is list:
54
+ vision_tower = vision_tower[0]
55
+ return vision_tower
56
+
57
+ def get_vision_tower_aux(self):
58
+ vision_tower_aux = getattr(self, 'vision_tower_aux', None)
59
+ if type(vision_tower_aux) is list:
60
+ vision_tower_aux = vision_tower_aux[0]
61
+ return vision_tower_aux
62
+
63
+ def initialize_vision_modules(self, model_args, fsdp=None):
64
+ vision_tower = model_args.vision_tower
65
+ vision_tower_aux = model_args.vision_tower_aux
66
+ mm_vision_select_layer = model_args.mm_vision_select_layer
67
+ mm_vision_select_feature = model_args.mm_vision_select_feature
68
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
69
+
70
+ self.config.mm_vision_tower = vision_tower
71
+ self.config.mm_vision_tower_aux = vision_tower_aux
72
+
73
+ if self.get_vision_tower() is None:
74
+ vision_tower = build_vision_tower(model_args)
75
+
76
+ if fsdp is not None and len(fsdp) > 0:
77
+ self.vision_tower = [vision_tower]
78
+ else:
79
+ self.vision_tower = vision_tower
80
+ else:
81
+ if fsdp is not None and len(fsdp) > 0:
82
+ vision_tower = self.vision_tower[0]
83
+ else:
84
+ vision_tower = self.vision_tower
85
+ vision_tower.load_model()
86
+
87
+ if vision_tower_aux is not None:
88
+ if self.get_vision_tower_aux() is None:
89
+ vision_tower_aux = build_vision_tower_aux(model_args)
90
+
91
+ if fsdp is not None and len(fsdp) > 0:
92
+ self.vision_tower_aux = [vision_tower_aux]
93
+ else:
94
+ self.vision_tower_aux = vision_tower_aux
95
+ else:
96
+ if fsdp is not None and len(fsdp) > 0:
97
+ vision_tower_aux = self.vision_tower_aux[0]
98
+ else:
99
+ vision_tower_aux = self.vision_tower_aux
100
+ vision_tower_aux.load_model()
101
+ self.config.mm_hidden_size_aux = vision_tower_aux.hidden_size
102
+
103
+ self.config.use_mm_proj = True
104
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
105
+ # self.config.mm_hidden_size = vision_tower.hidden_size
106
+ self.config.mm_hidden_size = 3072
107
+ self.config.mm_hidden_size_uni = vision_tower.hidden_size
108
+ self.config.mm_vision_select_layer = mm_vision_select_layer
109
+ self.config.mm_vision_select_feature = mm_vision_select_feature
110
+
111
+ if getattr(self, 'mm_projector', None) is None:
112
+ self.mm_projector = build_vision_projector(self.config)
113
+ else:
114
+ # In case it is frozen by LoRA
115
+ for p in self.mm_projector.parameters():
116
+ p.requires_grad = True
117
+
118
+ if pretrain_mm_mlp_adapter is not None:
119
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
120
+ def get_w(weights, keyword):
121
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
122
+
123
+ if 'model' in mm_projector_weights.keys():
124
+ mm_projector_weights = mm_projector_weights['model']
125
+ if is_deepspeed_zero3_enabled():
126
+ if len(mm_projector_weights) > 0:
127
+ with deepspeed.zero.GatheredParameters(mm_projector_weights, modifier_rank=0):
128
+ if torch.distributed.get_rank() == 0:
129
+ self.mm_projector.load_state_dict(mm_projector_weights)
130
+ else:
131
+ status = self.mm_projector.load_state_dict(mm_projector_weights, strict=False)
132
+ print('missing_keys:', status.missing_keys)
133
+ else:
134
+ if is_deepspeed_zero3_enabled():
135
+ named_parameters = get_w(mm_projector_weights, 'mm_projector')
136
+ if len(named_parameters) > 0:
137
+ with deepspeed.zero.GatheredParameters(named_parameters, modifier_rank=0):
138
+ if torch.distributed.get_rank() == 0:
139
+ self.mm_projector.load_state_dict(named_parameters)
140
+ else:
141
+ status = self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)
142
+ print('missing_keys:', status.missing_keys)
143
+ self.mm_projector = self.mm_projector.to(device='cuda')
144
+
145
+ def initialize_uni_modules(self, model_args, for_eval=False):
146
+ pretrain_mm_mlp_adapter = getattr(model_args, "pretrain_mm_mlp_adapter", None)
147
+ self.config.image_size_aux = getattr(model_args, 'image_size_aux', 320)
148
+ self.config.optimize_vision_tower = getattr(model_args, 'optimize_vision_tower', False)
149
+ self.config.optimize_vision_tower_aux = getattr(model_args, 'optimize_vision_tower_aux', False)
150
+
151
+ self.vlm_uni_query_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size_uni),
152
+ nn.Linear(self.config.mm_hidden_size_uni, self.config.mm_hidden_size_uni))
153
+ self.vlm_uni_aux_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size_aux),
154
+ nn.Linear(self.config.mm_hidden_size_aux, self.config.mm_hidden_size_uni))
155
+ self.vlm_uni_val_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size_aux),
156
+ nn.Linear(self.config.mm_hidden_size_aux, self.config.mm_hidden_size_uni))
157
+
158
+ if pretrain_mm_mlp_adapter is not None:
159
+ projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
160
+ else:
161
+ trainable_module = ['vlm_uni', 'vision_fpn', 'vision_stages']
162
+ if hasattr(model_args, 'model_name_or_path'):
163
+ model_save_path = model_args.model_name_or_path
164
+ else:
165
+ model_save_path = model_args.model_path
166
+ model_idx_path = getattr(model_args, 'model_path', model_save_path)
167
+ if IS_NEW_TRANSFORMERS:
168
+ try:
169
+ weight_file = json.load(open(os.path.join(model_idx_path, 'model.safetensors.index.json'), 'r'))['weight_map']
170
+ except:
171
+ weight_file = json.load(open(os.path.join(model_idx_path, 'pytorch_model.bin.index.json'), 'r'))['weight_map']
172
+ else:
173
+ weight_file = json.load(open(os.path.join(model_idx_path, 'pytorch_model.bin.index.json'), 'r'))['weight_map']
174
+ model_path = set([weight_file[_key] for _key in weight_file if any([_module in _key for _module in trainable_module])])
175
+ projector_weights = {}
176
+ for _model in model_path:
177
+ if not IS_NEW_TRANSFORMERS:
178
+ projector_weights.update(torch.load(os.path.join(model_idx_path, _model), map_location='cpu'))
179
+ else:
180
+ with safetensors.safe_open(os.path.join(model_idx_path, _model), framework="pt", device='cpu') as f:
181
+ for _key in f.keys():
182
+ projector_weights.update({_key: f.get_tensor(_key)})
183
+ if len(projector_weights) == 0:
184
+ return
185
+
186
+ def get_w(weights, keyword, main_module, sub_module):
187
+ if getattr(main_module, sub_module, None) is None:
188
+ return
189
+
190
+ pretrain_weight = {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
191
+ if len(pretrain_weight) == 0:
192
+ return
193
+ if is_deepspeed_zero3_enabled():
194
+ named_parameters = [v for k, v in getattr(main_module, sub_module).named_parameters()]
195
+ if len(named_parameters) > 0:
196
+ # because zero3 puts placeholders in model params, this context
197
+ # manager gathers (unpartitions) the params of the current layer, then loads from
198
+ # the state dict and then re-partitions them again
199
+ with deepspeed.zero.GatheredParameters(named_parameters, modifier_rank=0):
200
+ if torch.distributed.get_rank() == 0:
201
+ getattr(main_module, sub_module).load_state_dict(pretrain_weight)
202
+ with deepspeed.zero.GatheredParameters(self.mm_projector[0].weight, modifier_rank=None):
203
+ weight_type = self.mm_projector[0].weight.dtype
204
+ device_type = self.mm_projector[0].weight.device
205
+ else:
206
+ weight_type = self.mm_projector[0].weight.dtype
207
+ device_type = self.mm_projector[0].weight.device
208
+ getattr(main_module, sub_module).load_state_dict(pretrain_weight)
209
+ if weight_type == torch.uint8 or weight_type == torch.int8 or weight_type == torch.int16:
210
+ weight_type = torch.float16
211
+ getattr(main_module, sub_module).to(device=device_type, dtype=weight_type)
212
+ print(f"Loading {sub_module} weights...")
213
+
214
+ # load pretrained weights
215
+ get_w(projector_weights, 'vision_tower.vision_tower', self.vision_tower, 'vision_tower')
216
+
217
+ # load pretrained weights
218
+ if self.config.optimize_vision_tower_aux:
219
+ # not optimize vision stem, just used to check
220
+ get_w(projector_weights, 'vision_tower_aux.vision_stem', self.vision_tower_aux, 'vision_stem')
221
+ get_w(projector_weights, 'vision_tower_aux.vision_stages', self.vision_tower_aux, 'vision_stages')
222
+ get_w(projector_weights, 'vlm_uni_query_projector', self, 'vlm_uni_query_projector')
223
+ get_w(projector_weights, 'vlm_uni_aux_projector', self, 'vlm_uni_aux_projector')
224
+ get_w(projector_weights, 'vlm_uni_val_projector', self, 'vlm_uni_val_projector')
225
+
226
+ class MGMMetaForCausalLM(ABC):
227
+
228
+ @abstractmethod
229
+ def get_model(self):
230
+ pass
231
+
232
+ def get_vision_tower(self):
233
+ return self.get_model().get_vision_tower()
234
+
235
+ def get_vision_tower_aux(self):
236
+ return self.get_model().get_vision_tower_aux()
237
+
238
+ def encode_images(self, images, images_aux=None, is_video=False):
239
+ image_grid = getattr(self.config, 'image_grid', 1)
240
+ image_global = getattr(self.config, 'image_global', False)
241
+ if image_grid > 1:
242
+ batch_size = images.shape[0]
243
+ if image_global:
244
+ global_images = images[:, -1:].flatten(0,1).contiguous()
245
+ grid_images = images[:, :-1].flatten(0,1).contiguous()
246
+ images = torch.cat([grid_images, global_images], dim=0)
247
+ else:
248
+ images = images.flatten(0,1).contiguous()
249
+
250
+ image_features, image_forward_outs = self.get_model().get_vision_tower()(images)
251
+
252
+ if image_global: # false
253
+ image_feat_global = image_features[-len(global_images):]
254
+ image_features = image_features[:len(grid_images)]
255
+
256
+ if images_aux is not None:
257
+ image_aux_features_raw = self.get_model().get_vision_tower_aux()(images_aux).to(
258
+ dtype=image_features.dtype, device=image_features.device)
259
+
260
+ if image_global:
261
+ image_aux_features_global = F.interpolate(image_aux_features_raw.float(),
262
+ scale_factor=1/image_grid,
263
+ mode='bilinear',
264
+ align_corners=False).to(dtype=image_aux_features_raw.dtype)
265
+ image_feat_global, image_aux_feat_global = self.unified_resampler(image_feat_global, image_aux_features_global)
266
+
267
+ if image_grid > 1:
268
+ image_aux_features_raw = image_aux_features_raw.reshape(*image_aux_features_raw.shape[:2],
269
+ image_grid,
270
+ image_aux_features_raw.shape[-2]//image_grid,
271
+ image_grid,
272
+ image_aux_features_raw.shape[-1]//image_grid)
273
+ image_aux_features_raw = image_aux_features_raw.permute(0, 2, 4, 1, 3, 5).flatten(1,2).flatten(0,1).contiguous()
274
+ image_features, image_aux_features = self.unified_resampler(image_features, image_aux_features_raw)
275
+
276
+ if image_grid > 1:
277
+ image_features = image_features.reshape(batch_size, image_grid**2, *image_features.shape[1:])
278
+ image_features = image_features.flatten(1,2).contiguous()
279
+ image_aux_features = image_aux_features.reshape(batch_size, image_grid**2, *image_aux_features.shape[1:])
280
+ image_aux_features = image_aux_features.flatten(1,2).contiguous()
281
+
282
+ # add global features, [global, local]
283
+ if image_global:
284
+ image_features = torch.cat([image_feat_global, image_features], dim=1)
285
+ image_aux_features = torch.cat([image_aux_feat_global, image_aux_features], dim=1)
286
+
287
+ # token generation
288
+ image_features = image_features + image_aux_features
289
+
290
+ # dense connector
291
+ image_features_1 = []
292
+ image_features_2 = []
293
+ for i in range(0, 12):
294
+ image_features_1.append(image_forward_outs.hidden_states[i][:, 1:].to(image_features.dtype))
295
+ image_features_1 = torch.stack(image_features_1, dim=0)
296
+ image_features_1 = torch.sum(image_features_1, dim=0) / 12
297
+ for i in range(12, 24):
298
+ image_features_2.append(image_forward_outs.hidden_states[i][:, 1:].to(image_features.dtype))
299
+ image_features_2 = torch.stack(image_features_2, dim=0)
300
+ image_features_2 = torch.sum(image_features_2, dim=0) / 12
301
+
302
+ image_features = torch.cat([image_features, image_features_1, image_features_2], dim=-1)
303
+ ## dense connector end
304
+
305
+ # process image features after token generation
306
+ image_features = self.get_model().mm_projector(image_features)
307
+
308
+ return image_features
309
+
310
+ def unified_resampler(self, images, images_aux):
311
+ # patchwise with square images
312
+ patch_num = int(images.shape[1]**0.5)
313
+ patch_size = images_aux.shape[-1]//patch_num
314
+ # within patch attention
315
+ images_aux = images_aux.permute(0,2,3,1)
316
+ images_aux = images_aux.reshape(len(images_aux), patch_num, patch_size, patch_num, patch_size, images_aux.shape[-1])
317
+ images_aux = images_aux.permute(0,1,3,2,4,5)
318
+ images_aux = images_aux.reshape(len(images_aux), patch_num**2, patch_size**2, images_aux.shape[-1]).contiguous()
319
+
320
+ # token attention
321
+ embed_query = self.get_model().vlm_uni_query_projector(images)
322
+ embed_aux = self.get_model().vlm_uni_aux_projector(images_aux)
323
+ embed_value = self.get_model().vlm_uni_val_projector(images_aux)
324
+ embed_att = embed_query[:,:,None] @ (embed_aux.transpose(-1,-2) / (embed_aux.shape[-1]**0.5))
325
+ embed_att = embed_att.nan_to_num()
326
+ embed_feat = (embed_att.softmax(-1) @ embed_value).mean(2)
327
+
328
+ return images, embed_feat
329
+
330
+ def prepare_inputs_labels_for_multimodal(
331
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, images=None, images_aux=None,
332
+ ):
333
+ vision_tower = self.get_vision_tower()
334
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
335
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
336
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
337
+ attention_mask = torch.cat((attention_mask, torch.ones(
338
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
339
+ dtype=attention_mask.dtype,
340
+ device=attention_mask.device
341
+ )), dim=1)
342
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
343
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
344
+
345
+ if isinstance(images, list):
346
+ images = torch.stack(images, dim=0)
347
+ if isinstance(images_aux, list):
348
+ images_aux = torch.stack(images_aux, dim=0)
349
+
350
+ image_features = self.encode_images(images, images_aux)
351
+
352
+ # TODO: image start / end is not implemented here to support pretraining.
353
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
354
+ raise NotImplementedError
355
+
356
+ # Let's just add dummy tensors if they do not exist,
357
+ # it is a headache to deal with None all the time.
358
+ # But it is not ideal, and if you have a better idea,
359
+ # please open an issue / submit a PR, thanks.
360
+ _labels = labels
361
+ _position_ids = position_ids
362
+ _attention_mask = attention_mask
363
+ if attention_mask is None:
364
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
365
+ else:
366
+ attention_mask = attention_mask.bool()
367
+ if position_ids is None:
368
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
369
+ if labels is None:
370
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
371
+
372
+ # remove the padding using attention_mask -- TODO: double check
373
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
374
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
375
+
376
+ new_input_embeds = []
377
+ new_labels = []
378
+ cur_image_idx = 0
379
+ for batch_idx, cur_input_ids in enumerate(input_ids):
380
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
381
+ if num_images == 0:
382
+ cur_image_features = image_features[cur_image_idx]
383
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
384
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
385
+ new_input_embeds.append(cur_input_embeds)
386
+ new_labels.append(labels[batch_idx])
387
+ cur_image_idx += 1
388
+ continue
389
+
390
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
391
+ cur_input_ids_noim = []
392
+ cur_labels = labels[batch_idx]
393
+ cur_labels_noim = []
394
+ for i in range(len(image_token_indices) - 1):
395
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
396
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
397
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
398
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
399
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
400
+ cur_new_input_embeds = []
401
+ cur_new_labels = []
402
+
403
+ max_pos_id = 0
404
+ for i in range(num_images + 1):
405
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
406
+ cur_new_labels.append(cur_labels_noim[i])
407
+ max_pos_id += cur_input_embeds_no_im[i].shape[0]
408
+ if i < num_images:
409
+ cur_image_features = image_features[cur_image_idx]
410
+ cur_image_idx += 1
411
+ cur_new_input_embeds.append(cur_image_features)
412
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
413
+ max_pos_id += cur_image_features.shape[0]
414
+
415
+ cur_new_input_embeds = [x.to(device=cur_input_embeds.device) for x in cur_new_input_embeds]
416
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
417
+ cur_new_labels = torch.cat(cur_new_labels)
418
+
419
+ new_input_embeds.append(cur_new_input_embeds)
420
+ new_labels.append(cur_new_labels)
421
+
422
+ # Truncate sequences to max length as image embeddings can make the sequence longer
423
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
424
+ if tokenizer_model_max_length is not None:
425
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
426
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
427
+
428
+ # Combine them
429
+ max_len = max(x.shape[0] for x in new_input_embeds)
430
+ batch_size = len(new_input_embeds)
431
+
432
+ new_input_embeds_padded = []
433
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
434
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
435
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
436
+
437
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
438
+ cur_len = cur_new_embed.shape[0]
439
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
440
+ new_input_embeds_padded.append(torch.cat((
441
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
442
+ cur_new_embed
443
+ ), dim=0))
444
+ if cur_len > 0:
445
+ new_labels_padded[i, -cur_len:] = cur_new_labels
446
+ attention_mask[i, -cur_len:] = True
447
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
448
+ else:
449
+ new_input_embeds_padded.append(torch.cat((
450
+ cur_new_embed,
451
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
452
+ ), dim=0))
453
+ if cur_len > 0:
454
+ new_labels_padded[i, :cur_len] = cur_new_labels
455
+ attention_mask[i, :cur_len] = True
456
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
457
+
458
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
459
+
460
+ if _labels is None:
461
+ new_labels = None
462
+ else:
463
+ new_labels = new_labels_padded
464
+
465
+ if _attention_mask is None:
466
+ attention_mask = None
467
+ else:
468
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
469
+
470
+ if _position_ids is None:
471
+ position_ids = None
472
+
473
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
474
+
475
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
476
+ if model_args.mm_use_im_patch_token:
477
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
478
+ self.resize_token_embeddings(len(tokenizer))
479
+
480
+ if model_args.mm_use_im_start_end:
481
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
482
+ self.resize_token_embeddings(len(tokenizer))
483
+
484
+ if num_new_tokens > 0:
485
+ input_embeddings = self.get_input_embeddings().weight.data
486
+ output_embeddings = self.get_output_embeddings().weight.data
487
+
488
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
489
+ dim=0, keepdim=True)
490
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
491
+ dim=0, keepdim=True)
492
+
493
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
494
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
495
+
496
+ if model_args.tune_mm_mlp_adapter:
497
+ for p in self.get_input_embeddings().parameters():
498
+ p.requires_grad = True
499
+ for p in self.get_output_embeddings().parameters():
500
+ p.requires_grad = False
501
+
502
+ if model_args.pretrain_mm_mlp_adapter:
503
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
504
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
505
+ assert num_new_tokens == 2
506
+ if input_embeddings.shape == embed_tokens_weight.shape:
507
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
508
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
509
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
510
+ else:
511
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
512
+ elif model_args.mm_use_im_patch_token:
513
+ if model_args.tune_mm_mlp_adapter:
514
+ for p in self.get_input_embeddings().parameters():
515
+ p.requires_grad = False
516
+ for p in self.get_output_embeddings().parameters():
517
+ p.requires_grad = False
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c446d6deb83996f07846ea18e9b2035a7f54a163e2c606566b225cad343ff0a
3
+ size 4938985248
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e26d8ca4f66025d3852d0863f4d7ef8f81f53c364de780101f2b0a5a33b48e1
3
+ size 4947390768
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c43ac375c790ae5bab23ba5b4093f22d326f57e86aae3b1608f89e95b8f9f5a3
3
+ size 4662729856
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<unk>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "bos_token": "<s>",
31
+ "clean_up_tokenization_spaces": false,
32
+ "eos_token": "</s>",
33
+ "legacy": false,
34
+ "model_max_length": 2048,
35
+ "pad_token": "<unk>",
36
+ "padding_side": "right",
37
+ "sp_model_kwargs": {},
38
+ "spaces_between_special_tokens": false,
39
+ "tokenizer_class": "LlamaTokenizer",
40
+ "unk_token": "<unk>",
41
+ "use_default_system_prompt": false
42
+ }
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff