AnwenHu commited on
Commit
b733be1
1 Parent(s): 4e6b800

Delete mplug_docowl/model/convert_mplug_docowl_weight_to_hf.py

Browse files
mplug_docowl/model/convert_mplug_docowl_weight_to_hf.py DELETED
@@ -1,319 +0,0 @@
1
- # Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import argparse
15
- import gc
16
- import json
17
- import math
18
- import os
19
- import shutil
20
- import warnings
21
-
22
- import torch
23
-
24
- from transformers import LlamaTokenizer
25
- from .configuration_mplug_docowl import MPLUGDocOwlConfig
26
- from icecream import ic
27
-
28
- try:
29
- from transformers import LlamaTokenizerFast
30
- except ImportError as e:
31
- warnings.warn(e)
32
- warnings.warn(
33
- "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
34
- )
35
- LlamaTokenizerFast = None
36
-
37
- """
38
- Sample usage:
39
-
40
- ```
41
- python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \
42
- --input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B
43
- ```
44
-
45
- Thereafter, models can be loaded via:
46
-
47
- ```py
48
- from transformers import LlamaForCausalLM, LlamaTokenizer
49
-
50
- model = LlamaForCausalLM.from_pretrained("/output/path")
51
- tokenizer = LlamaTokenizer.from_pretrained("/output/path")
52
- ```
53
-
54
- Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
55
- come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
56
- """
57
-
58
- llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
59
- llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
60
- llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
61
- 70: 28672} # should be (2/3)*4*d, but it isn't exaclty that
62
- llama_s2hidden = {7: 4096, 13: 5120, 32: 6656, 65: 8192, 70: 8192}
63
-
64
-
65
- def compute_intermediate_size(n):
66
- return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
67
-
68
-
69
- def read_json(path):
70
- with open(path, "r") as f:
71
- return json.load(f)
72
-
73
-
74
- def write_json(text, path):
75
- with open(path, "w") as f:
76
- json.dump(text, f)
77
-
78
-
79
- def write_model(model_path,
80
- input_base_path,
81
- model_size,
82
- num_input_shards=1,
83
- num_output_shards=2,
84
- skip_permute=True,
85
- norm_eps=1e-05):
86
- # if os.path.exists(model_path):
87
- # shutil.rmtree(model_path)
88
- os.makedirs(model_path, exist_ok=True)
89
- # tmp_model_path = os.path.join(model_path, "tmp")
90
- tmp_model_path = model_path
91
- os.makedirs(tmp_model_path, exist_ok=True)
92
-
93
- num_shards = num_input_shards
94
- n_layers = llama_s2layer[model_size]
95
- n_heads = llama_s2heads[model_size]
96
- n_heads_per_shard = n_heads // num_shards
97
- n_dense = llama_s2dense[model_size]
98
- n_hidden = llama_s2hidden[model_size]
99
- hidden_per_head = n_hidden // n_heads
100
- base = 10000.0
101
- inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
102
-
103
- # permute for sliced rotary
104
- def permute(w, skip_permute=skip_permute):
105
- if skip_permute:
106
- return w
107
- return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)
108
-
109
- print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
110
- # Load weights
111
- if num_shards==1:
112
- # Not sharded
113
- # (The sharded implementation would also work, but this is simpler.)
114
- # /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt
115
- if os.path.exists(os.path.join(input_base_path, 'release')):
116
- filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
117
- elif input_base_path.split('/')[-1].startswith('iter_'):
118
- iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0'))
119
- load_dir = '/'.join(input_base_path.split('/')[:-1])
120
- filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
121
- if not os.path.exists(filename):
122
- filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
123
- else:
124
- tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
125
- with open(tracker_filename, 'r') as f:
126
- metastring = f.read().strip()
127
- iteration = 'iter_{:07d}'.format(int(metastring))
128
- filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
129
- if not os.path.exists(filename):
130
- filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
131
- original_filename = filename
132
- loaded = torch.load(filename, map_location="cpu")['model']['language_model']
133
-
134
- else:
135
- # Sharded
136
- filenames = []
137
- for i in range(num_shards):
138
- if os.path.exists(os.path.join(input_base_path, 'release')):
139
- filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt')
140
- else:
141
- tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
142
- with open(tracker_filename, 'r') as f:
143
- metastring = f.read().strip()
144
- iteration = 'iter_{:07d}'.format(int(metastring))
145
- filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt')
146
- if not os.path.exists(filename):
147
- filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
148
- filenames.append(filename)
149
- loaded = [
150
- torch.load(filenames[i], map_location="cpu")['model']['language_model']
151
- for i in range(num_shards)
152
- ]
153
-
154
- print('Llama-Megatron Loaded!')
155
- param_count = 0
156
- index_dict = {"weight_map": {}}
157
-
158
- print(f'Weighted Converting for {n_layers} layers...')
159
- for layer_i in range(n_layers):
160
- print(layer_i)
161
- filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
162
- if num_shards == 1:
163
- # Unsharded
164
- state_dict = {
165
- f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
166
- f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
167
- f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
168
- f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
169
- f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
170
- f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
171
- f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
172
- f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
173
- f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
174
- f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
175
- f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
176
- f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
177
- f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
178
- }
179
- else:
180
- raise NotImplemented
181
-
182
- state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
183
- for k, v in state_dict.items():
184
- index_dict["weight_map"][k] = filename
185
- param_count += v.numel()
186
- torch.save(state_dict, os.path.join(tmp_model_path, filename))
187
- print(f'Sharded file saved to {filename}')
188
-
189
- filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
190
- if num_shards==1:
191
- # Unsharded
192
- state_dict = {
193
- "model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
194
- "model.norm.weight": loaded['encoder']['norm.weight'],
195
- "lm_head.weight": loaded['encoder']['lm_head.weight'],
196
- }
197
- else:
198
- state_dict = {
199
- "model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'],
200
- "model.norm.weight": loaded[0]['encoder']['norm.weight'],
201
- "lm_head.weight": loaded[0]['encoder']['lm_head.weight'],
202
- }
203
-
204
-
205
- loaded_all = torch.load(original_filename, map_location="cpu")['model']
206
- # Vision Part
207
- state_dict.update({
208
- "model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
209
- "model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
210
- "model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
211
- "model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
212
- "model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
213
- "model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
214
- "model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
215
- })
216
- for v_layer_idx in range(24):
217
- state_dict.update({
218
- f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
219
- f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
220
- f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
221
- f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
222
- f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
223
- f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
224
- f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
225
- f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
226
- f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
227
- f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
228
- f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
229
- f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
230
- })
231
-
232
- # Vision2Text Part: HReducer
233
- state_dict.update({
234
- "model.vision2text.ln_q.weight": loaded_all['hreducer3']['ln_q']['weight'],
235
- "model.vision2text.ln_q.bias": loaded_all['hreducer3']['ln_q']['bias'],
236
- "model.vision2text.visual_fc.bias": loaded_all['hreducer3']['visual_fc']['bias'],
237
- "model.vision2text.visual_fc.weight": loaded_all['hreducer3']['visual_fc']['weight'],
238
- "model.vision2text.vit_eos": loaded_all['hreducer3']['vit_eos'],
239
- })
240
- # reducer_before conv (layer 0) + gleu (layer 1)
241
- state_dict.update({
242
- f"model.vision2text.reducer_before.0.weight": loaded_all['hreducer3']['reducer_before']["0.weight"],
243
- f"model.vision2text.reducer_before.0.bias": loaded_all['hreducer3']['reducer_before']["0.bias"],
244
- })
245
- # reducer conv
246
- state_dict.update({
247
- f"model.vision2text.reducer.weight": loaded_all['hreducer3']['reducer']["weight"],
248
- f"model.vision2text.reducer.bias": loaded_all['hreducer3']['reducer']["bias"],
249
- })
250
-
251
- for k, v in state_dict.items():
252
- # ic(k, v)
253
- index_dict["weight_map"][k] = filename
254
- param_count += v.numel()
255
- torch.save(state_dict, os.path.join(tmp_model_path, filename))
256
-
257
- # Write configs
258
- index_dict["metadata"] = {"total_size": param_count * 2}
259
- write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
260
-
261
- config = MPLUGDocOwlConfig()
262
- config.save_pretrained(tmp_model_path)
263
-
264
- # Make space so we can load the model properly now.
265
- del state_dict
266
- del loaded
267
- del loaded_all
268
- gc.collect()
269
-
270
- def write_tokenizer(tokenizer_path, input_tokenizer_path):
271
- # Initialize the tokenizer based on the `spm` model
272
- tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
273
- print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
274
- tokenizer = tokenizer_class(input_tokenizer_path)
275
- tokenizer.save_pretrained(tokenizer_path)
276
-
277
-
278
- def main():
279
- parser = argparse.ArgumentParser()
280
- parser.add_argument(
281
- "--input_dir",
282
- help="Location of LLaMA_Megatron weights",
283
- )
284
- parser.add_argument(
285
- "--model_size",
286
- type=int,
287
- default=7,
288
- choices=[7, 13, 30, 65, 70],
289
- )
290
- parser.add_argument(
291
- "--num_input_shards",
292
- type=int,
293
- default=1,
294
- )
295
- parser.add_argument(
296
- "--num_output_shards",
297
- type=int,
298
- default=1,
299
- )
300
- parser.add_argument('--skip_permute', action='store_true')
301
-
302
- parser.add_argument(
303
- "--output_dir",
304
- help="Location to write HF model and tokenizer",
305
- )
306
-
307
- args = parser.parse_args()
308
- write_model(
309
- model_path=args.output_dir,
310
- input_base_path=args.input_dir,
311
- model_size=args.model_size,
312
- num_input_shards=args.num_input_shards,
313
- num_output_shards=args.num_output_shards,
314
- skip_permute=args.skip_permute
315
- )
316
-
317
-
318
- if __name__ == "__main__":
319
- main()