tsujuifu committed
Commit
893b461
1 Parent(s): bdbb79e
Files changed (7)
  1. README.md +1 -1
  2. app.py +30 -22
  3. conversation.py +370 -0
  4. llava.py → mgie_llava.py +22 -19
  5. pre-requirements.txt +4 -4
  6. requirements.txt +4 -4
  7. train.py +0 -831
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 👩‍🎨
 colorFrom: blue
 colorTo: gray
 sdk: gradio
-sdk_version: 3.37.0
+sdk_version: 4.12.0
 app_file: app.py
 license: other
 ---
app.py CHANGED
@@ -1,9 +1,9 @@
 
 import os
-# os.system('cp -r ./_ckpt/LLaVA-7B-v1 /data/LLaVA-7B-v1'), os.system('cp -r ./_ckpt/mgie_7b /data/mgie_7b')
-os.system('ls /data'), os.system('df -h /data')
-[os.system('mv llava.py /home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/llava/model/llava.py'),
- os.system('mv train.py /home/user/.pyenv/versions/3.10.13/lib/python3.10/site-packages/llava/train/train.py')]
+
+import huggingface_hub, spaces
+huggingface_hub.snapshot_download(repo_id='tsujuifu/ml-mgie', repo_type='model', local_dir='_ckpt', local_dir_use_symlinks=False)
+os.system('ls _ckpt')
 
 from PIL import Image
 
@@ -11,8 +11,8 @@ import numpy as np
 import torch as T
 import transformers, diffusers
 
-from llava.conversation import conv_templates
-from llava.model import *
+from conversation import conv_templates
+from mgie_llava import *
 
 import gradio as gr
 
@@ -39,7 +39,7 @@ DEFAULT_IMAGE_TOKEN = '<image>'
 DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
 DEFAULT_IM_START_TOKEN = '<im_start>'
 DEFAULT_IM_END_TOKEN = '<im_end>'
-PATH_LLAVA = '/data/LLaVA-7B-v1'
+PATH_LLAVA = '_ckpt/LLaVA-7B-v1'
 
 tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
 model = LlavaLlamaForCausalLM.from_pretrained(PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()
@@ -48,7 +48,7 @@ image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.m
 tokenizer.padding_side = 'left'
 tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
 model.resize_token_embeddings(len(tokenizer))
-ckpt = T.load('/data/mgie_7b/mllm.pt', map_location='cpu')
+ckpt = T.load('_ckpt/mgie_7b/mllm.pt', map_location='cpu')
 model.load_state_dict(ckpt, strict=False)
 
 mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
@@ -65,15 +65,17 @@ if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token
 image_token_len = (vision_config.image_size//vision_config.patch_size)**2
 
 _ = model.eval()
-EMB = ckpt['emb'].cuda()
-with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
 
 pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=T.float16).to('cuda')
 pipe.set_progress_bar_config(disable=True)
-pipe.unet.load_state_dict(T.load('/data/mgie_7b/unet.pt', map_location='cpu'))
+pipe.unet.load_state_dict(T.load('_ckpt/mgie_7b/unet.pt', map_location='cpu'))
 print('--init MGIE--')
 
+@spaces.GPU(enable_queue=True)
 def go_mgie(img, txt, seed, cfg_txt, cfg_img):
+    EMB = ckpt['emb'].cuda()
+    with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
+
     img, seed = crop_resize(Image.fromarray(img).convert('RGB')), int(seed)
     inp = img
 
@@ -87,6 +89,7 @@ def go_mgie(img, txt, seed, cfg_txt, cfg_img):
     txt, mask = T.as_tensor(txt['input_ids']), T.as_tensor(txt['attention_mask'])
 
     with T.inference_mode():
+        _ = model.cuda()
        out = model.generate(txt.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(),
                             do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3,
                             return_dict_in_generate=True, output_hidden_states=True)
@@ -98,6 +101,7 @@ def go_mgie(img, txt, seed, cfg_txt, cfg_img):
     hid = hid[p:p+8]
 
     out = remove_alter(tokenizer.decode(out))
+    _ = model.cuda()
     emb = model.edit_head(hid.unsqueeze(dim=0), EMB)
     res = pipe(image=inp, prompt_embeds=emb, negative_prompt_embeds=NULL,
                generator=T.Generator(device='cuda').manual_seed(seed), guidance_scale=cfg_txt, image_guidance_scale=cfg_img).images[0]
@@ -105,14 +109,14 @@ def go_mgie(img, txt, seed, cfg_txt, cfg_img):
     return res, out
 
 def go_example(seed, cfg_txt, cfg_img):
-    txt = ['make the frame red', 'turn the day into night', 'give him a beard', 'make cottage a mansion',
+    ins = ['make the frame red', 'turn the day into night', 'give him a beard', 'make cottage a mansion',
           'remove yellow object from dogs paws', 'change the hair from red to blue', 'remove the text', 'increase the image contrast',
           'remove the people in the background', 'please make this photo professional looking', 'darken the image, sharpen it', 'photoshop the girl out',
          'make more brightness', 'take away the brown filter form the image', 'add more contrast to simulate more light', 'dark on rgb',
          'make the face happy', 'change view as ocean', 'replace basketball with soccer ball', 'let the floor be made of wood']
-    i = T.randint(len(txt), (1, )).item()
+    i = T.randint(len(ins), (1, )).item()
 
-    return './_input/%d.jpg'%(i), txt[i], seed, cfg_txt, cfg_img
+    return './_input/%d.jpg'%(i), ins[i], seed, cfg_txt, cfg_img
 
 go_mgie(np.array(Image.open('./_input/0.jpg').convert('RGB')), 'make the frame red', 13331, 7.5, 1.5)
 print('--init GO--')
@@ -120,25 +124,29 @@ print('--init GO--')
 with gr.Blocks() as app:
     gr.Markdown(
         """
-        🔔 we will have a maintenance at 3 a.m. (PST)
         # [ICLR\'24] Guiding Instruction-based Image Editing via Multimodal Large Language Models<br>
         🔔 this demo is hosted by [Tsu-Jui Fu](https://github.com/tsujuifu/pytorch_mgie)<br>
         🔔 a black image means that the output did not pass the [safety checker](https://huggingface.co/CompVis/stable-diffusion-safety-checker)<br>
-        🔔 if the queue is full (*this app is too busy*), you can also try it [here](http://128.111.41.13:7122)<br>
+        🔔 if the queue is full (*no GPU available*), you can also try it [here](http://128.111.41.13:7122)<br>
        🔔 if the building process takes too long, please try refreshing the page
         """
     )
     with gr.Row(): inp, res = [gr.Image(height=384, width=384, label='Input Image', interactive=True),
-                               gr.Image(height=384, width=384, label='Goal Image', interactive=False)]
+                               gr.Image(height=384, width=384, label='Goal Image', interactive=True)]
     with gr.Row(): txt, out = [gr.Textbox(label='Instruction', interactive=True),
                                gr.Textbox(label='Expressive Instruction', interactive=False)]
     with gr.Row(): seed, cfg_txt, cfg_img = [gr.Number(value=13331, label='Seed', interactive=True),
                                              gr.Number(value=7.5, label='Text CFG', interactive=True),
                                              gr.Number(value=1.5, label='Image CFG', interactive=True)]
-    with gr.Row(): btn_sub, btn_exp = [gr.Button('Submit'),
-                                       gr.Button('Example')]
-
-    btn_sub.click(fn=go_mgie, inputs=[inp, txt, seed, cfg_txt, cfg_img], outputs=[res, out])
     btn_exp.click(fn=go_example, inputs=[seed, cfg_txt, cfg_img], outputs=[inp, txt, seed, cfg_txt, cfg_img])
+    with gr.Row(): btn_exp, btn_sub = [gr.Button('More Example'), gr.Button('Submit')]
+    btn_sub.click(fn=go_mgie, inputs=[inp, txt, seed, cfg_txt, cfg_img], outputs=[res, out])
+
+    ins = ['make the frame red', 'turn the day into night', 'give him a beard', 'make cottage a mansion',
+           'remove yellow object from dogs paws', 'change the hair from red to blue', 'remove the text', 'increase the image contrast',
+           'remove the people in the background', 'please make this photo professional looking', 'darken the image, sharpen it', 'photoshop the girl out',
+           'make more brightness', 'take away the brown filter form the image', 'add more contrast to simulate more light', 'dark on rgb',
+           'make the face happy', 'change view as ocean', 'replace basketball with soccer ball', 'let the floor be made of wood']
+    gr.Examples(examples=[['./_input/%d.jpg'%(i), ins[i]] for i in [1, 5, 8, 14, 16]], inputs=[inp, txt])
 
-app.queue(concurrency_count=1, max_size=75), app.launch()
+app.launch()
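
The rewritten app.py follows the usual ZeroGPU Space recipe: pull the checkpoints from the Hub once at startup while the process is still on CPU, then decorate only the GPU-bound entry point. A minimal sketch of that recipe (the function body is a placeholder, not the full go_mgie() pipeline):

import huggingface_hub, spaces

# fetch the model weights into a local folder before the UI starts
huggingface_hub.snapshot_download(repo_id='tsujuifu/ml-mgie', repo_type='model', local_dir='_ckpt')

@spaces.GPU(enable_queue=True)  # a GPU is attached only while this function runs
def edit(img, txt):
    # placeholder for the actual MGIE editing pipeline
    return img, txt
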
conversation.py ADDED
@@ -0,0 +1,370 @@
1
+
2
+ # modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/conversation.py
3
+
4
+ import dataclasses
5
+ from enum import auto, Enum
6
+ from typing import List, Tuple
7
+
8
+
9
+ class SeparatorStyle(Enum):
10
+ """Different separator style."""
11
+ SINGLE = auto()
12
+ TWO = auto()
13
+ MPT = auto()
14
+
15
+
16
+ @dataclasses.dataclass
17
+ class Conversation:
18
+ """A class that keeps all conversation history."""
19
+ system: str
20
+ roles: List[str]
21
+ messages: List[List[str]]
22
+ offset: int
23
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
24
+ sep: str = "###"
25
+ sep2: str = None
26
+ version: str = "Unknown"
27
+
28
+ skip_next: bool = False
29
+
30
+ def get_prompt(self):
31
+ if self.sep_style == SeparatorStyle.SINGLE:
32
+ ret = self.system + self.sep
33
+ for role, message in self.messages:
34
+ if message:
35
+ if type(message) is tuple:
36
+ message, _, _ = message
37
+ ret += role + ": " + message + self.sep
38
+ else:
39
+ ret += role + ":"
40
+ return ret
41
+ elif self.sep_style == SeparatorStyle.TWO:
42
+ seps = [self.sep, self.sep2]
43
+ ret = self.system + seps[0]
44
+ for i, (role, message) in enumerate(self.messages):
45
+ if message:
46
+ if type(message) is tuple:
47
+ message, _, _ = message
48
+ ret += role + ": " + message + seps[i % 2]
49
+ else:
50
+ ret += role + ":"
51
+ return ret
52
+ if self.sep_style == SeparatorStyle.MPT:
53
+ ret = self.system + self.sep
54
+ for role, message in self.messages:
55
+ if message:
56
+ if type(message) is tuple:
57
+ message, _, _ = message
58
+ ret += role + message + self.sep
59
+ else:
60
+ ret += role
61
+ return ret
62
+ else:
63
+ raise ValueError(f"Invalid style: {self.sep_style}")
64
+
65
+ def append_message(self, role, message):
66
+ self.messages.append([role, message])
67
+
68
+ def get_images(self, return_pil=False):
69
+ images = []
70
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
71
+ if i % 2 == 0:
72
+ if type(msg) is tuple:
73
+ import base64
74
+ from io import BytesIO
75
+ from PIL import Image
76
+ msg, image, image_process_mode = msg
77
+ if image_process_mode == "Pad":
78
+ def expand2square(pil_img, background_color=(122, 116, 104)):
79
+ width, height = pil_img.size
80
+ if width == height:
81
+ return pil_img
82
+ elif width > height:
83
+ result = Image.new(pil_img.mode, (width, width), background_color)
84
+ result.paste(pil_img, (0, (width - height) // 2))
85
+ return result
86
+ else:
87
+ result = Image.new(pil_img.mode, (height, height), background_color)
88
+ result.paste(pil_img, ((height - width) // 2, 0))
89
+ return result
90
+ image = expand2square(image)
91
+ elif image_process_mode == "Crop":
92
+ pass
93
+ elif image_process_mode == "Resize":
94
+ image = image.resize((224, 224))
95
+ else:
96
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
97
+ max_hw, min_hw = max(image.size), min(image.size)
98
+ aspect_ratio = max_hw / min_hw
99
+ max_len, min_len = 800, 400
100
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
101
+ longest_edge = int(shortest_edge * aspect_ratio)
102
+ W, H = image.size
103
+ if H > W:
104
+ H, W = longest_edge, shortest_edge
105
+ else:
106
+ H, W = shortest_edge, longest_edge
107
+ image = image.resize((W, H))
108
+ if return_pil:
109
+ images.append(image)
110
+ else:
111
+ buffered = BytesIO()
112
+ image.save(buffered, format="JPEG")
113
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
114
+ images.append(img_b64_str)
115
+ return images
116
+
117
+ def to_gradio_chatbot(self):
118
+ ret = []
119
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
120
+ if i % 2 == 0:
121
+ if type(msg) is tuple:
122
+ import base64
123
+ from io import BytesIO
124
+ msg, image, image_process_mode = msg
125
+ max_hw, min_hw = max(image.size), min(image.size)
126
+ aspect_ratio = max_hw / min_hw
127
+ max_len, min_len = 800, 400
128
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
129
+ longest_edge = int(shortest_edge * aspect_ratio)
130
+ W, H = image.size
131
+ if H > W:
132
+ H, W = longest_edge, shortest_edge
133
+ else:
134
+ H, W = shortest_edge, longest_edge
135
+ image = image.resize((W, H))
136
+ # image = image.resize((224, 224))
137
+ buffered = BytesIO()
138
+ image.save(buffered, format="JPEG")
139
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
140
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
141
+ msg = msg.replace('<image>', img_str)
142
+ ret.append([msg, None])
143
+ else:
144
+ ret[-1][-1] = msg
145
+ return ret
146
+
147
+ def copy(self):
148
+ return Conversation(
149
+ system=self.system,
150
+ roles=self.roles,
151
+ messages=[[x, y] for x, y in self.messages],
152
+ offset=self.offset,
153
+ sep_style=self.sep_style,
154
+ sep=self.sep,
155
+ sep2=self.sep2)
156
+
157
+ def dict(self):
158
+ if len(self.get_images()) > 0:
159
+ return {
160
+ "system": self.system,
161
+ "roles": self.roles,
162
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
163
+ "offset": self.offset,
164
+ "sep": self.sep,
165
+ "sep2": self.sep2,
166
+ }
167
+ return {
168
+ "system": self.system,
169
+ "roles": self.roles,
170
+ "messages": self.messages,
171
+ "offset": self.offset,
172
+ "sep": self.sep,
173
+ "sep2": self.sep2,
174
+ }
175
+
176
+
177
+ conv_v1 = Conversation(
178
+ system="A chat between a curious human and an artificial intelligence assistant. "
179
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
180
+ roles=("Human", "Assistant"),
181
+ messages=(
182
+ ("Human", "Give three tips for staying healthy."),
183
+ ("Assistant",
184
+ "Sure, here are three tips for staying healthy:\n"
185
+ "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
186
+ "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
187
+ "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
188
+ "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
189
+ "activities at least two days per week.\n"
190
+ "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
191
+ "vegetables, whole grains, lean proteins, and healthy fats can help support "
192
+ "your overall health. Try to limit your intake of processed and high-sugar foods, "
193
+ "and aim to drink plenty of water throughout the day.\n"
194
+ "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
195
+ "and mental health. Adults should aim for seven to nine hours of sleep per night. "
196
+ "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
197
+ "help improve the quality of your sleep.")
198
+ ),
199
+ offset=2,
200
+ sep_style=SeparatorStyle.SINGLE,
201
+ sep="###",
202
+ )
203
+
204
+ conv_v1_2 = Conversation(
205
+ system="A chat between a curious human and an artificial intelligence assistant. "
206
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
207
+ roles=("Human", "Assistant"),
208
+ messages=(
209
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
210
+ ("Assistant",
211
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
212
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
213
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
214
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
215
+ "renewable and non-renewable energy sources:\n"
216
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
217
+ "energy sources are finite and will eventually run out.\n"
218
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
219
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
220
+ "and other negative effects.\n"
221
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
222
+ "have lower operational costs than non-renewable sources.\n"
223
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
224
+ "locations than non-renewable sources.\n"
225
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
226
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
227
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
228
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
229
+ ),
230
+ offset=2,
231
+ sep_style=SeparatorStyle.SINGLE,
232
+ sep="###",
233
+ )
234
+
235
+ conv_vicuna_v1_1 = Conversation(
236
+ system="A chat between a curious user and an artificial intelligence assistant. "
237
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
238
+ roles=("USER", "ASSISTANT"),
239
+ version="v1",
240
+ messages=(),
241
+ offset=0,
242
+ sep_style=SeparatorStyle.TWO,
243
+ sep=" ",
244
+ sep2="</s>",
245
+ )
246
+
247
+ conv_mpt = Conversation(
248
+ system="""<|im_start|>system
249
+ - You are a helpful language and vision assistant.
250
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
251
+ - You should follow the instructions carefully and explain your answers in detail.""",
252
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
253
+ version="mpt",
254
+ messages=(),
255
+ offset=0,
256
+ sep_style=SeparatorStyle.MPT,
257
+ sep="<|im_end|>",
258
+ )
259
+
260
+ conv_mpt_text = Conversation(
261
+ system="""<|im_start|>system
262
+ - You are a helpful assistant chatbot trained by MosaicML.
263
+ - You answer questions.
264
+ - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
265
+ - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
266
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
267
+ version="mpt",
268
+ messages=(),
269
+ offset=0,
270
+ sep_style=SeparatorStyle.MPT,
271
+ sep="<|im_end|>",
272
+ )
273
+
274
+ conv_bair_v1 = Conversation(
275
+ system="BEGINNING OF CONVERSATION:",
276
+ roles=("USER", "GPT"),
277
+ messages=(),
278
+ offset=0,
279
+ sep_style=SeparatorStyle.TWO,
280
+ sep=" ",
281
+ sep2="</s>",
282
+ )
283
+
284
+ simple_conv = Conversation(
285
+ system="A chat between a curious human and an artificial intelligence assistant. "
286
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
287
+ roles=("Human", "Assistant"),
288
+ messages=(
289
+ ("Human", "Hi!"),
290
+ ("Assistant", "Hi there! How can I help you today?")
291
+ ),
292
+ offset=2,
293
+ sep_style=SeparatorStyle.SINGLE,
294
+ sep="###",
295
+ )
296
+
297
+ simple_conv_multimodal = Conversation(
298
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
299
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
300
+ "Follow the instructions carefully and explain your answers in detail.",
301
+ roles=("Human", "Assistant"),
302
+ messages=(
303
+ ("Human", "Hi!"),
304
+ ("Assistant", "Hi there! How can I help you today?\n")
305
+ ),
306
+ offset=2,
307
+ sep_style=SeparatorStyle.SINGLE,
308
+ sep="###",
309
+ )
310
+
311
+ simple_conv_mpt_multimodal = Conversation(
312
+ system="""<|im_start|>system
313
+ - You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
314
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
315
+ - You should follow the instructions carefully and explain your answers in detail.""",
316
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
317
+ version="mpt",
318
+ messages=(),
319
+ offset=0,
320
+ sep_style=SeparatorStyle.MPT,
321
+ sep="<|im_end|>",
322
+ )
323
+
324
+ simple_conv_legacy = Conversation(
325
+ system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
326
+ "You are designed to assist human with a variety of tasks using natural language."
327
+ "Follow the instructions carefully.",
328
+ roles=("Human", "Assistant"),
329
+ messages=(
330
+ ("Human", "Hi!\n\n### Response:"),
331
+ ("Assistant", "Hi there! How can I help you today?\n")
332
+ ),
333
+ offset=2,
334
+ sep_style=SeparatorStyle.SINGLE,
335
+ sep="###",
336
+ )
337
+
338
+ conv_llava_v1 = Conversation(
339
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
340
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
341
+ "Follow the instructions carefully and explain your answers in detail.",
342
+ roles=("USER", "ASSISTANT"),
343
+ version="v1",
344
+ messages=(),
345
+ offset=0,
346
+ sep_style=SeparatorStyle.TWO,
347
+ sep=" ",
348
+ sep2="</s>",
349
+ )
350
+
351
+ default_conversation = conv_v1_2
352
+ conv_templates = {
353
+ "default": conv_v1_2,
354
+ "simple": simple_conv,
355
+ "simple_legacy": simple_conv_legacy,
356
+ "multimodal": simple_conv_multimodal,
357
+ "mpt_multimodal": simple_conv_mpt_multimodal,
358
+ "llava_v1": conv_llava_v1,
359
+
360
+ # fastchat
361
+ "v1": conv_v1_2,
362
+ "bair_v1": conv_bair_v1,
363
+ "vicuna_v1_1": conv_vicuna_v1_1,
364
+ "mpt": conv_mpt,
365
+ "mpt_text": conv_mpt_text,
366
+ }
367
+
368
+
369
+ if __name__ == "__main__":
370
+ print(default_conversation.get_prompt())
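
As a usage note, the templates above are consumed through conv_templates; a minimal sketch of building a Vicuna-style prompt with them (the question text here is only illustrative):

from conversation import conv_templates

conv = conv_templates['vicuna_v1_1'].copy()   # copy so the shared template is not mutated
conv.append_message(conv.roles[0], 'what will this image be like if the frame is red')  # USER turn
conv.append_message(conv.roles[1], None)      # leave the ASSISTANT turn open for generation
prompt = conv.get_prompt()                    # system text + "USER: ... ASSISTANT:"
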
llava.py → mgie_llava.py RENAMED
@@ -1,4 +1,7 @@
1
-
 
 
 
2
  # modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/model/llava.py
3
 
4
  from typing import List, Optional, Tuple, Union
@@ -184,19 +187,19 @@ class LlavaLlamaModel(LlamaModel):
184
  class EditMapper(nn.Module):
185
  def __init__(self):
186
  super().__init__()
187
-
188
  self.llm2hid = nn.Linear(4096, 512)
189
  self.query = nn.Parameter(torch.randn(1, 77, 512))
190
- self.mapper = nn.Transformer(batch_first=True, norm_first=True,
191
- d_model=512, nhead=4, num_encoder_layers=4, num_decoder_layers=4,
192
  dim_feedforward=2048, dropout=0.0)
193
  self.hid2feat = nn.Linear(512, 768)
194
-
195
  def forward(self, llm, emb):
196
  hid = self.llm2hid(llm+emb)
197
  hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
198
  feat = self.hid2feat(hid)
199
-
200
  return feat
201
 
202
  class LlavaLlamaForCausalLM(LlamaForCausalLM):
@@ -209,9 +212,9 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
209
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
210
 
211
  self.edit_head = EditMapper()
212
-
213
- '''self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
214
- diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
215
  diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')]
216
  self.vae.requires_grad_(False)
217
  self.unet.register_to_config(in_channels=8)
@@ -220,7 +223,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
220
  conv.weight.zero_()
221
  conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
222
  self.unet.conv_in = conv'''
223
-
224
  # Initialize weights and apply final processing
225
  self.post_init()
226
 
@@ -236,7 +239,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
236
  if type(vision_tower) is list:
237
  vision_tower = vision_tower[0]
238
  return vision_tower
239
-
240
  def forward(
241
  self,
242
  input_ids: torch.LongTensor = None,
@@ -248,7 +251,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
248
  output_attentions: Optional[bool] = None,
249
  output_hidden_states: Optional[bool] = None,
250
  images: Optional[torch.FloatTensor] = None,
251
- return_dict: Optional[bool] = None,
252
  p2p_inp=None, p2p_ans=None
253
  ) -> Union[Tuple, CausalLMOutputWithPast]:
254
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -297,13 +300,13 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
297
  hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
298
 
299
  B, DROP = labels.shape[0], 0.05
300
-
301
- hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device),
302
  self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
303
 
304
  with torch.no_grad():
305
  lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
306
- lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
307
  torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
308
 
309
  noise = torch.randn_like(lat_ans)
@@ -317,15 +320,15 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
317
  lat_inp *= mask
318
 
319
  out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
320
-
321
  loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
322
  if int(os.environ['LOCAL_RANK'])==0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
323
  loss = loss_ce+loss_edit*0.5
324
-
325
  if not return_dict:
326
  output = (logits,) + outputs[1:]
327
  return (loss,) + output if loss is not None else output
328
-
329
  return CausalLMOutputWithPast(
330
  loss=loss,
331
  logits=logits,
@@ -371,7 +374,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM):
371
  if num_new_tokens > 0:
372
  input_embeddings = self.get_input_embeddings().weight.data
373
  output_embeddings = self.get_output_embeddings().weight.data
374
-
375
  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
376
  dim=0, keepdim=True)
377
  output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
 
1
+ #
2
+ # For licensing see accompanying LICENSE file.
3
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
4
+ #
5
  # modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/model/llava.py
6
 
7
  from typing import List, Optional, Tuple, Union
 
187
  class EditMapper(nn.Module):
188
  def __init__(self):
189
  super().__init__()
190
+
191
  self.llm2hid = nn.Linear(4096, 512)
192
  self.query = nn.Parameter(torch.randn(1, 77, 512))
193
+ self.mapper = nn.Transformer(batch_first=True, norm_first=True,
194
+ d_model=512, nhead=4, num_encoder_layers=4, num_decoder_layers=4,
195
  dim_feedforward=2048, dropout=0.0)
196
  self.hid2feat = nn.Linear(512, 768)
197
+
198
  def forward(self, llm, emb):
199
  hid = self.llm2hid(llm+emb)
200
  hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
201
  feat = self.hid2feat(hid)
202
+
203
  return feat
204
 
205
  class LlavaLlamaForCausalLM(LlamaForCausalLM):
 
212
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
213
 
214
  self.edit_head = EditMapper()
215
+
216
+ '''self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
217
+ diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
218
  diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')]
219
  self.vae.requires_grad_(False)
220
  self.unet.register_to_config(in_channels=8)
 
223
  conv.weight.zero_()
224
  conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
225
  self.unet.conv_in = conv'''
226
+
227
  # Initialize weights and apply final processing
228
  self.post_init()
229
 
 
239
  if type(vision_tower) is list:
240
  vision_tower = vision_tower[0]
241
  return vision_tower
242
+
243
  def forward(
244
  self,
245
  input_ids: torch.LongTensor = None,
 
251
  output_attentions: Optional[bool] = None,
252
  output_hidden_states: Optional[bool] = None,
253
  images: Optional[torch.FloatTensor] = None,
254
+ return_dict: Optional[bool] = None,
255
  p2p_inp=None, p2p_ans=None
256
  ) -> Union[Tuple, CausalLMOutputWithPast]:
257
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
300
  hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
301
 
302
  B, DROP = labels.shape[0], 0.05
303
+
304
+ hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device),
305
  self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
306
 
307
  with torch.no_grad():
308
  lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
309
+ lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
310
  torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
311
 
312
  noise = torch.randn_like(lat_ans)
 
320
  lat_inp *= mask
321
 
322
  out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
323
+
324
  loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
325
  if int(os.environ['LOCAL_RANK'])==0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
326
  loss = loss_ce+loss_edit*0.5
327
+
328
  if not return_dict:
329
  output = (logits,) + outputs[1:]
330
  return (loss,) + output if loss is not None else output
331
+
332
  return CausalLMOutputWithPast(
333
  loss=loss,
334
  logits=logits,
 
374
  if num_new_tokens > 0:
375
  input_embeddings = self.get_input_embeddings().weight.data
376
  output_embeddings = self.get_output_embeddings().weight.data
377
+
378
  input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
379
  dim=0, keepdim=True)
380
  output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
pre-requirements.txt CHANGED
@@ -1,9 +1,9 @@
 sentencepiece
-transformers
+git+https://github.com/huggingface/transformers.git@cae78c46
 diffusers
-tokenizers
+tokenizers==0.12.1
 datasets
 accelerate
 evaluate
-gradio
-git+https://github.com/haotian-liu/LLaVA@7ace501
+gradio==4.12.0
+gradio_client==0.8.0
requirements.txt CHANGED
@@ -1,4 +1,4 @@
--i https://download.pytorch.org/whl/cu113
-torch==1.12.0
-torchvision==0.13.0
-torchaudio==0.12.0
+-i https://download.pytorch.org/whl/cu118
+torch==2.0
+torchvision==0.15
+torchaudio==2.0
train.py DELETED
@@ -1,831 +0,0 @@
1
-
2
- # modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/train/train.py
3
-
4
- import os
5
- import copy
6
- from dataclasses import dataclass, field
7
- import json
8
- import logging
9
- import pathlib
10
- from typing import Dict, Optional, Sequence, List
11
-
12
- import torch
13
-
14
- import transformers
15
- from torch.utils.data import Dataset
16
- from llava.train.llava_trainer import LLaVATrainer
17
-
18
- from llava import conversation as conversation_lib
19
- from llava.model import *
20
-
21
- from PIL import Image
22
- import torch.nn as nn
23
-
24
- # TODO: import and use code from ../data/dataset.py
25
-
26
- IGNORE_INDEX = -100
27
- DEFAULT_PAD_TOKEN = "[PAD]"
28
- DEFAULT_EOS_TOKEN = "</s>"
29
- DEFAULT_BOS_TOKEN = "<s>"
30
- DEFAULT_UNK_TOKEN = "<unk>"
31
- DEFAULT_IMAGE_TOKEN = "<image>"
32
- DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
33
- DEFAULT_IM_START_TOKEN = "<im_start>"
34
- DEFAULT_IM_END_TOKEN = "<im_end>"
35
-
36
- import io, base64, pickle, random
37
- from tqdm import tqdm
38
- import numpy as np
39
-
40
- def b2f(b): return Image.open(io.BytesIO(base64.b64decode(b))).convert('RGB')
41
- def resize(f):
42
- w, h = f.size
43
- if w>h:
44
- p = (w-h)//2
45
- f = f.crop([p, 0, p+h, h])
46
- elif h>w:
47
- p = (h-w)//2
48
- f = f.crop([0, p, w, p+w])
49
- f = f.resize([512, 512])
50
- return f
51
- def img2npy(f): return (2.0*np.array(f)/255.0-1.0).transpose((2, 0, 1)).astype(np.float32)
52
-
53
- @dataclass
54
- class ModelArguments:
55
- model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
56
- version: Optional[str] = field(default="v0")
57
- freeze_backbone: bool = field(default=False)
58
- tune_mm_mlp_adapter: bool = field(default=False)
59
- vision_tower: Optional[str] = field(default=None)
60
- mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
61
- pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
62
- mm_use_im_start_end: bool = field(default=False)
63
-
64
-
65
- @dataclass
66
- class DataArguments:
67
- data_path: str = field(default=None,
68
- metadata={"help": "Path to the training data."})
69
- lazy_preprocess: bool = False
70
- is_multimodal: bool = False
71
- sep_image_conv_front: bool = False
72
- image_token_len: int = 0
73
- image_folder: Optional[str] = field(default=None)
74
- image_aspect_ratio: str = 'square'
75
-
76
-
77
- @dataclass
78
- class TrainingArguments(transformers.TrainingArguments):
79
- cache_dir: Optional[str] = field(default=None)
80
- optim: str = field(default="adamw_torch")
81
- remove_unused_columns: bool = field(default=False)
82
- freeze_mm_mlp_adapter: bool = field(default=False)
83
- force_fsdp: bool = field(default=False)
84
- model_max_length: int = field(
85
- default=512,
86
- metadata={
87
- "help":
88
- "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
89
- },
90
- )
91
- double_quant: bool = field(
92
- default=True,
93
- metadata={"help": "Compress the quantization statistics through double quantization."}
94
- )
95
- quant_type: str = field(
96
- default="nf4",
97
- metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
98
- )
99
- bits: int = field(
100
- default=16,
101
- metadata={"help": "How many bits to use."}
102
- )
103
- lora_enable: bool = False
104
- lora_r: int = 64
105
- lora_alpha: int = 16
106
- lora_dropout: float = 0.05
107
- lora_weight_path: str = ""
108
- lora_bias: str = "none"
109
-
110
-
111
- def maybe_zero_3(param, ignore_status=False, name=None):
112
- from deepspeed import zero
113
- from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
114
- if hasattr(param, "ds_id"):
115
- if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
116
- if not ignore_status:
117
- logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
118
- with zero.GatheredParameters([param]):
119
- param = param.data.detach().cpu().clone()
120
- else:
121
- param = param.detach().cpu().clone()
122
- return param
123
-
124
-
125
- # Borrowed from peft.utils.get_peft_model_state_dict
126
- def get_peft_state_maybe_zero_3(named_params, bias):
127
- if bias == "none":
128
- to_return = {k: t for k, t in named_params if "lora_" in k}
129
- elif bias == "all":
130
- to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
131
- elif bias == "lora_only":
132
- to_return = {}
133
- maybe_lora_bias = {}
134
- lora_bias_names = set()
135
- for k, t in named_params:
136
- if "lora_" in k:
137
- to_return[k] = t
138
- bias_name = k.split("lora_")[0] + "bias"
139
- lora_bias_names.add(bias_name)
140
- elif "bias" in k:
141
- maybe_lora_bias[k] = t
142
- for k, t in maybe_lora_bias:
143
- if bias_name in lora_bias_names:
144
- to_return[bias_name] = t
145
- else:
146
- raise NotImplementedError
147
- to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
148
- return to_return
149
-
150
-
151
- def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
152
- to_return = {k: t for k, t in named_params if "lora_" not in k}
153
- if require_grad_only:
154
- to_return = {k: t for k, t in to_return.items() if t.requires_grad}
155
- to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
156
- return to_return
157
-
158
-
159
- def find_all_linear_names(model):
160
- cls = torch.nn.Linear
161
- lora_module_names = set()
162
- for name, module in model.named_modules():
163
- if isinstance(module, cls):
164
- names = name.split('.')
165
- lora_module_names.add(names[0] if len(names) == 1 else names[-1])
166
-
167
-
168
- if 'lm_head' in lora_module_names: # needed for 16-bit
169
- lora_module_names.remove('lm_head')
170
- return list(lora_module_names)
171
-
172
-
173
- def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
174
- output_dir: str):
175
- """Collects the state dict and dump to disk."""
176
- if trainer.deepspeed:
177
- torch.cuda.synchronize()
178
- trainer.save_model(output_dir)
179
- return
180
-
181
- state_dict = trainer.model.state_dict()
182
- if trainer.args.should_save:
183
- cpu_state_dict = {
184
- key: value.cpu()
185
- for key, value in state_dict.items()
186
- }
187
- del state_dict
188
- trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
189
-
190
-
191
- def smart_tokenizer_and_embedding_resize(
192
- special_tokens_dict: Dict,
193
- tokenizer: transformers.PreTrainedTokenizer,
194
- model: transformers.PreTrainedModel,
195
- ):
196
- """Resize tokenizer and embedding.
197
-
198
- Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
199
- """
200
- num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
201
- model.resize_token_embeddings(len(tokenizer))
202
-
203
- if num_new_tokens > 0:
204
- input_embeddings = model.get_input_embeddings().weight.data
205
- output_embeddings = model.get_output_embeddings().weight.data
206
-
207
- input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
208
- dim=0, keepdim=True)
209
- output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
210
- dim=0, keepdim=True)
211
-
212
- input_embeddings[-num_new_tokens:] = input_embeddings_avg
213
- output_embeddings[-num_new_tokens:] = output_embeddings_avg
214
-
215
-
216
- def _tokenize_fn(strings: Sequence[str],
217
- tokenizer: transformers.PreTrainedTokenizer) -> Dict:
218
- """Tokenize a list of strings."""
219
- tokenized_list = [
220
- tokenizer(
221
- text,
222
- return_tensors="pt",
223
- padding="longest",
224
- max_length=tokenizer.model_max_length,
225
- truncation=True,
226
- ) for text in strings
227
- ]
228
- input_ids = labels = [
229
- tokenized.input_ids[0] for tokenized in tokenized_list
230
- ]
231
- input_ids_lens = labels_lens = [
232
- tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
233
- for tokenized in tokenized_list
234
- ]
235
- return dict(
236
- input_ids=input_ids,
237
- labels=labels,
238
- input_ids_lens=input_ids_lens,
239
- labels_lens=labels_lens,
240
- )
241
-
242
-
243
- def _mask_targets(target, tokenized_lens, speakers):
244
- # cur_idx = 0
245
- cur_idx = tokenized_lens[0]
246
- tokenized_lens = tokenized_lens[1:]
247
- target[:cur_idx] = IGNORE_INDEX
248
- for tokenized_len, speaker in zip(tokenized_lens, speakers):
249
- if speaker == "human":
250
- target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
251
- cur_idx += tokenized_len
252
-
253
-
254
- def _add_speaker_and_signal(header, source, get_conversation=True):
255
- """Add speaker and start/end signal on each round."""
256
- BEGIN_SIGNAL = "### "
257
- END_SIGNAL = "\n"
258
- conversation = header
259
- for sentence in source:
260
- from_str = sentence["from"]
261
- if from_str.lower() == "human":
262
- from_str = conversation_lib.default_conversation.roles[0]
263
- elif from_str.lower() == "gpt":
264
- from_str = conversation_lib.default_conversation.roles[1]
265
- else:
266
- from_str = 'unknown'
267
- sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
268
- sentence["value"] + END_SIGNAL)
269
- if get_conversation:
270
- conversation += sentence["value"]
271
- conversation += BEGIN_SIGNAL
272
- return conversation
273
-
274
-
275
- def preprocess_multimodal(
276
- sources: Sequence[str],
277
- multimodal_cfg: dict,
278
- cur_token_len: int,
279
- ) -> Dict:
280
- is_multimodal = multimodal_cfg['is_multimodal']
281
- # image_token_len = multimodal_cfg['image_token_len']
282
- image_token_len = cur_token_len
283
- if not is_multimodal:
284
- return sources
285
-
286
- for source in sources:
287
- if multimodal_cfg['sep_image_conv_front']:
288
- assert DEFAULT_IMAGE_TOKEN in source[0]['value']
289
- source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
290
- source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
291
- for sentence in source:
292
- replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
293
- if multimodal_cfg['use_im_start_end']:
294
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
295
- sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
296
-
297
- return sources
298
-
299
-
300
- def preprocess_v1(
301
- sources,
302
- tokenizer: transformers.PreTrainedTokenizer,
303
- ) -> Dict:
304
- conv = conversation_lib.default_conversation.copy()
305
- roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
306
-
307
- # Apply prompt templates
308
- conversations = []
309
- for i, source in enumerate(sources):
310
- if roles[source[0]["from"]] != conv.roles[0]:
311
- # Skip the first one if it is not from human
312
- source = source[1:]
313
-
314
- conv.messages = []
315
- for j, sentence in enumerate(source):
316
- role = roles[sentence["from"]]
317
- assert role == conv.roles[j % 2], f"{i}"
318
- conv.append_message(role, sentence["value"])
319
- conversations.append(conv.get_prompt())
320
-
321
- # Tokenize conversations
322
- input_ids = tokenizer(
323
- conversations,
324
- return_tensors="pt",
325
- padding="longest",
326
- max_length=tokenizer.model_max_length,
327
- truncation=True,
328
- ).input_ids
329
- targets = input_ids.clone()
330
-
331
- assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
332
-
333
- # Mask targets
334
- sep = conv.sep + conv.roles[1] + ": "
335
- for conversation, target in zip(conversations, targets):
336
- total_len = int(target.ne(tokenizer.pad_token_id).sum())
337
-
338
- rounds = conversation.split(conv.sep2)
339
- cur_len = 1
340
- target[:cur_len] = IGNORE_INDEX
341
- for i, rou in enumerate(rounds):
342
- if rou == "":
343
- break
344
-
345
- parts = rou.split(sep)
346
- if len(parts) != 2:
347
- break
348
- parts[0] += sep
349
- round_len = len(tokenizer(rou).input_ids)
350
- instruction_len = len(tokenizer(parts[0]).input_ids) - 2
351
-
352
- target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
353
-
354
- cur_len += round_len
355
- target[cur_len:] = IGNORE_INDEX
356
-
357
- if cur_len < tokenizer.model_max_length:
358
- if cur_len != total_len:
359
- target[:] = IGNORE_INDEX
360
- print(
361
- f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
362
- f" (ignored)"
363
- )
364
-
365
- return dict(
366
- input_ids=input_ids,
367
- labels=targets,
368
- )
369
-
370
- def preprocess_mpt(
371
- sources,
372
- tokenizer: transformers.PreTrainedTokenizer,
373
- ) -> Dict:
374
- conv = conversation_lib.default_conversation.copy()
375
- roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
376
-
377
- # Apply prompt templates
378
- conversations = []
379
- for i, source in enumerate(sources):
380
- if roles[source[0]["from"]] != conv.roles[0]:
381
- # Skip the first one if it is not from human
382
- source = source[1:]
383
-
384
- conv.messages = []
385
- for j, sentence in enumerate(source):
386
- role = roles[sentence["from"]]
387
- assert role == conv.roles[j % 2], f"{i}"
388
- conv.append_message(role, sentence["value"])
389
- conversations.append(conv.get_prompt())
390
-
391
- # Tokenize conversations
392
- input_ids = tokenizer(
393
- conversations,
394
- return_tensors="pt",
395
- padding="longest",
396
- max_length=tokenizer.model_max_length,
397
- truncation=True,
398
- ).input_ids
399
- targets = input_ids.clone()
400
- assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
401
-
402
- # Mask targets
403
- sep = conv.sep + conv.roles[1]
404
- for conversation, target in zip(conversations, targets):
405
- total_len = int(target.ne(tokenizer.pad_token_id).sum())
406
-
407
- rounds = conversation.split(conv.sep)
408
- re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
409
- for conv_idx in range(3, len(rounds), 2):
410
- re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
411
- cur_len = 0
412
- target[:cur_len] = IGNORE_INDEX
413
- for i, rou in enumerate(re_rounds):
414
- if rou == "":
415
- break
416
-
417
- parts = rou.split(sep)
418
- if len(parts) != 2:
419
- break
420
- parts[0] += sep
421
- round_len = len(tokenizer(rou).input_ids) + len(tokenizer(conv.sep).input_ids)
422
- instruction_len = len(tokenizer(parts[0]).input_ids)
423
- target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
424
-
425
- cur_len += round_len
426
- target[cur_len:] = IGNORE_INDEX
427
-
428
- if cur_len < tokenizer.model_max_length:
429
- if cur_len != total_len:
430
- target[:] = IGNORE_INDEX
431
- print(
432
- f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
433
- f" (ignored)"
434
- )
435
-
436
- return dict(
437
- input_ids=input_ids,
438
- labels=targets,
439
- )
440
-
441
-
442
- def preprocess(
443
- sources: Sequence[str],
444
- tokenizer: transformers.PreTrainedTokenizer,
445
- ) -> Dict:
446
- """
447
- Given a list of sources, each is a conversation list. This transform:
448
- 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
449
- 2. Concatenate conversations together;
450
- 3. Tokenize the concatenated conversation;
451
- 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
452
- """
453
- if conversation_lib.default_conversation.version == "v1":
454
- return preprocess_v1(sources, tokenizer)
455
- if conversation_lib.default_conversation.version == "mpt":
456
- return preprocess_mpt(sources, tokenizer)
457
- # add end signal and concatenate together
458
- conversations = []
459
- for source in sources:
460
- header = f"{conversation_lib.default_conversation.system}\n\n"
461
- conversation = _add_speaker_and_signal(header, source)
462
- conversations.append(conversation)
463
- # tokenize conversations
464
- conversations_tokenized = _tokenize_fn(conversations, tokenizer)
465
- input_ids = conversations_tokenized["input_ids"]
466
- targets = copy.deepcopy(input_ids)
467
- for target, source in zip(targets, sources):
468
- tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
469
- tokenizer)["input_ids_lens"]
470
- speakers = [sentence["from"] for sentence in source]
471
- _mask_targets(target, tokenized_lens, speakers)
472
-
473
- return dict(input_ids=input_ids, labels=targets)
474
-
475
-
476
- class SupervisedDataset(Dataset):
477
- """Dataset for supervised fine-tuning."""
478
-
479
- def __init__(self, data_path: str,
480
- tokenizer: transformers.PreTrainedTokenizer):
481
- super(SupervisedDataset, self).__init__()
482
- logging.warning("Loading data...")
483
- list_data_dict = json.load(open(data_path, "r"))
484
-
485
- logging.warning("Formatting inputs...")
486
- sources = [example["conversations"] for example in list_data_dict]
487
- data_dict = preprocess(sources, tokenizer)
488
-
489
- self.input_ids = data_dict["input_ids"]
490
- self.labels = data_dict["labels"]
491
-
492
- def __len__(self):
493
- return len(self.input_ids)
494
-
495
- def __getitem__(self, i) -> Dict[str, torch.Tensor]:
496
- return dict(input_ids=self.input_ids[i], labels=self.labels[i])
497
-
498
-
499
- class LazySupervisedDataset(Dataset):
500
-
501
- def __init__(self, data_path: str,
502
- tokenizer: transformers.PreTrainedTokenizer,
503
- multimodal_cfg: dict):
504
- super(LazySupervisedDataset, self).__init__()
505
-
506
- self.tokenizer, self.multimodal_cfg = tokenizer, multimodal_cfg
507
-
508
- self.pkl, self.prompt = pickle.load(open('./_data/ipr2pr.pkl', 'rb'))['task'], json.load(open('./_data/ipr2pr_expressive.json', 'r'))
509
- random.shuffle(self.pkl)
510
- print('--pkl: %d--'%(len(self.pkl)))
511
-
512
- def __len__(self):
513
- return len(self.pkl)
514
-
515
- def __getitem__(self, i) -> Dict[str, torch.Tensor]:
516
- item = self.pkl[i][0]
517
-
518
- tsv = open('./_data/ipr2pr.tsv', 'r')
519
- tsv.seek(item['lineidx'])
520
- b = tsv.readline().strip().split('\t')
521
- image = resize(b2f(b[0]))
522
-
523
- processor = self.multimodal_cfg['image_processor']
524
- image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
525
-
526
- cur_token_len = (image.shape[1]//14)*(image.shape[2]//14)
527
- query = "what will this image be like if '%s'\n%s"%(item['instruction'], DEFAULT_IMAGE_TOKEN)
528
- ans = '%s [IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]'%(self.prompt[item['input']]['expressive'])
529
- sources = preprocess_multimodal(copy.deepcopy([[{'from': 'human', 'value': query}, {'from': 'gpt', 'value': ans}]]),
530
- self.multimodal_cfg, cur_token_len)
531
-
532
- data_dict = preprocess(sources, self.tokenizer)
533
- if isinstance(i, int): data_dict = dict(input_ids=data_dict['input_ids'][0],
534
- labels=data_dict['labels'][0])
535
- data_dict['image'] = image
536
-
537
- p2p_inp, p2p_ans = img2npy(resize(b2f(b[0])).resize([256, 256])), img2npy(resize(b2f(b[1])).resize([256, 256]))
538
- data_dict['p2p_inp'], data_dict['p2p_ans'] = p2p_inp, p2p_ans
539
-
540
- return data_dict
541
-
542
-
543
- @dataclass
544
- class DataCollatorForSupervisedDataset(object):
545
- """Collate examples for supervised fine-tuning."""
546
-
547
- tokenizer: transformers.PreTrainedTokenizer
548
-
549
- def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
550
- input_ids, labels = tuple([instance[key] for instance in instances]
551
- for key in ("input_ids", "labels"))
552
- input_ids = torch.nn.utils.rnn.pad_sequence(
553
- input_ids,
554
- batch_first=True,
555
- padding_value=self.tokenizer.pad_token_id)
556
- labels = torch.nn.utils.rnn.pad_sequence(labels,
557
- batch_first=True,
558
- padding_value=IGNORE_INDEX)
559
- batch = dict(
560
- input_ids=input_ids,
561
- labels=labels,
562
- attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
563
- )
564
-
565
- if 'image' in instances[0]:
566
- images = [instance['image'] for instance in instances]
567
- if all(x is not None and x.shape == images[0].shape for x in images):
568
- batch['images'] = torch.stack(images)
569
- else:
570
- batch['images'] = images
571
-
572
- batch['p2p_inp'], batch['p2p_ans'] = [torch.cat([torch.from_numpy(d['p2p_inp']).unsqueeze(dim=0) for d in instances], dim=0),
573
- torch.cat([torch.from_numpy(d['p2p_ans']).unsqueeze(dim=0) for d in instances], dim=0)]
574
-
575
- return batch
576
-
577
-
578
- def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
-                                 data_args) -> Dict:
-     """Make dataset and collator for supervised fine-tuning."""
-     dataset_cls = (LazySupervisedDataset
-                    if data_args.lazy_preprocess else SupervisedDataset)
-     train_dataset = dataset_cls(tokenizer=tokenizer,
-                                 data_path=data_args.data_path,
-                                 multimodal_cfg=dict(
-                                     is_multimodal=data_args.is_multimodal,
-                                     sep_image_conv_front=data_args.sep_image_conv_front,
-                                     image_token_len=data_args.image_token_len,
-                                     image_folder=data_args.image_folder,
-                                     image_aspect_ratio=data_args.image_aspect_ratio,
-                                     use_im_start_end=getattr(data_args, 'mm_use_im_start_end', False),
-                                     image_processor=getattr(data_args, 'image_processor', None)))
-     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
-     return dict(train_dataset=train_dataset,
-                 eval_dataset=None,
-                 data_collator=data_collator)
-
-
- def train():
-     parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
-     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
-     compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
-
-     bnb_model_from_pretrained_args = {}
-     if training_args.bits in [4, 8]:
-         from transformers import BitsAndBytesConfig
-         from peft import prepare_model_for_int8_training
-         bnb_model_from_pretrained_args.update(dict(
-             device_map={"": training_args.device},
-             load_in_4bit=training_args.bits == 4,
-             load_in_8bit=training_args.bits == 8,
-             quantization_config=BitsAndBytesConfig(
-                 load_in_4bit=training_args.bits == 4,
-                 load_in_8bit=training_args.bits == 8,
-                 llm_int8_threshold=6.0,
-                 llm_int8_has_fp16_weight=False,
-                 bnb_4bit_compute_dtype=compute_dtype,
-                 bnb_4bit_use_double_quant=training_args.double_quant,
-                 bnb_4bit_quant_type=training_args.quant_type  # {'fp4', 'nf4'}
-             )
-         ))
-
-     if model_args.vision_tower is not None:
-         if 'mpt' in model_args.model_name_or_path:
-             model = LlavaMPTForCausalLM.from_pretrained(
-                 model_args.model_name_or_path,
-                 cache_dir=training_args.cache_dir,
-                 **bnb_model_from_pretrained_args
-             )
-         else:
-             model = LlavaLlamaForCausalLM.from_pretrained(
-                 model_args.model_name_or_path,
-                 cache_dir=training_args.cache_dir,
-                 **bnb_model_from_pretrained_args
-             )
-     else:
-         model = transformers.LlamaForCausalLM.from_pretrained(
-             model_args.model_name_or_path,
-             cache_dir=training_args.cache_dir,
-             **bnb_model_from_pretrained_args
-         )
-     model.config.use_cache = False
-
-     if model_args.freeze_backbone:
-         model.model.requires_grad_(False)
-
-     if training_args.bits in [4, 8]:
-         model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
-         model = prepare_model_for_int8_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
-
-     if training_args.gradient_checkpointing and model_args.vision_tower is None:
-         if hasattr(model, "enable_input_require_grads"):
-             model.enable_input_require_grads()
-         else:
-             def make_inputs_require_grad(module, input, output):
-                 output.requires_grad_(True)
-             model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
-     if training_args.lora_enable:
-         from peft import LoraConfig, get_peft_model
-         lora_config = LoraConfig(
-             r=training_args.lora_r,
-             lora_alpha=training_args.lora_alpha,
-             target_modules=find_all_linear_names(model),
-             lora_dropout=training_args.lora_dropout,
-             bias=training_args.lora_bias,
-             task_type="CAUSAL_LM",
-         )
-         if training_args.bits == 16:
-             if training_args.bf16:
-                 model.to(torch.bfloat16)
-             if training_args.fp16:
-                 model.to(torch.float16)
-         logging.warning("Adding LoRA adapters...")
-         model = get_peft_model(model, lora_config)
-
-     if 'mpt' in model_args.model_name_or_path:
-         tokenizer = transformers.AutoTokenizer.from_pretrained(
-             model_args.model_name_or_path,
-             cache_dir=training_args.cache_dir,
-             model_max_length=training_args.model_max_length,
-             padding_side="right"
-         )
-     else:
-         tokenizer = transformers.AutoTokenizer.from_pretrained(
-             model_args.model_name_or_path,
-             cache_dir=training_args.cache_dir,
-             model_max_length=training_args.model_max_length,
-             padding_side="right",
-             use_fast=False,
-         )
-
-     if model_args.version == "v0":
-         if tokenizer.pad_token is None:
-             smart_tokenizer_and_embedding_resize(
-                 special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
-                 tokenizer=tokenizer,
-                 model=model,
-             )
-         if "llama" in model_args.model_name_or_path:
-             tokenizer.add_special_tokens({
-                 "eos_token": DEFAULT_EOS_TOKEN,
-                 "bos_token": DEFAULT_BOS_TOKEN,
-                 "unk_token": DEFAULT_UNK_TOKEN,
-             })
-     else:
-         tokenizer.pad_token = tokenizer.unk_token
-         if "mpt" in model_args.model_name_or_path:
-             conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
-         else:
-             conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"]
-
-     if model_args.vision_tower is not None:
-         model_vision_dict = model.get_model().initialize_vision_modules(
-             vision_tower=model_args.vision_tower,
-             mm_vision_select_layer=model_args.mm_vision_select_layer,
-             pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
-             fsdp=training_args.fsdp
-         )
-         model.get_vision_tower().to(dtype=torch.float16, device=training_args.device)
-         vision_config = model_vision_dict['vision_config']
-
-         data_args.image_token_len = model_vision_dict['image_token_len']
-         data_args.image_processor = model_vision_dict['image_processor']
-         data_args.is_multimodal = True
-
-         model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
-         if model_args.tune_mm_mlp_adapter:
-             model.requires_grad_(False)
-             for p in model.get_model().mm_projector.parameters():
-                 p.requires_grad = True
-
-         model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
-         if training_args.freeze_mm_mlp_adapter:
-             for p in model.get_model().mm_projector.parameters():
-                 p.requires_grad = False
-
-         if training_args.bits in [4, 8]:
-             model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
-
-         model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
-         vision_config.use_im_start_end = training_args.use_im_start_end = model_args.mm_use_im_start_end
-         model.config.sep_image_conv_front = data_args.sep_image_conv_front
-         model.initialize_vision_tokenizer(mm_use_im_start_end=model_args.mm_use_im_start_end, tokenizer=tokenizer, device=training_args.device,
-                                           tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter, pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter)
-
-     params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
-     if len(params_no_grad) > 0:
-         if training_args.fsdp is not None and len(training_args.fsdp) > 0:
-             if len(params_no_grad) < 10:
-                 print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(len(params_no_grad), params_no_grad))
-             else:
-                 print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'.format(len(params_no_grad), ', '.join(params_no_grad[:10])))
-             print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.")
-             print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
-
-             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
-             def patch_FSDP_use_orig_params(func):
-                 def wrap_func(*args, **kwargs):
-                     use_orig_params = kwargs.pop('use_orig_params', True)
-                     return func(*args, **kwargs, use_orig_params=use_orig_params)
-                 return wrap_func
-
-             FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
-
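The FSDP patch above is a small monkey-patch: it wraps FSDP.__init__ so that use_orig_params defaults to True (which FSDP needs when only some parameters require gradients) while still letting callers override it. A self-contained toy version of the same wrapper pattern, using a hypothetical Widget class in place of FSDP:

def default_kwarg(func, name, value):
    # wrap func so the keyword `name` defaults to `value` unless the caller sets it
    def wrapped(*args, **kwargs):
        kwargs.setdefault(name, value)
        return func(*args, **kwargs)
    return wrapped

class Widget:
    def __init__(self, use_orig_params=False):
        self.use_orig_params = use_orig_params

Widget.__init__ = default_kwarg(Widget.__init__, 'use_orig_params', True)
print(Widget().use_orig_params)                        # True (patched default)
print(Widget(use_orig_params=False).use_orig_params)   # False (caller override still wins)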
-     if training_args.bits in [4, 8]:
-         from peft.tuners.lora import LoraLayer
-         for name, module in model.named_modules():
-             if isinstance(module, LoraLayer):
-                 if training_args.bf16:
-                     module = module.to(torch.bfloat16)
-             if 'norm' in name:
-                 module = module.to(torch.float32)
-             if 'lm_head' in name or 'embed_tokens' in name:
-                 if hasattr(module, 'weight'):
-                     if training_args.bf16 and module.weight.dtype == torch.float32:
-                         module = module.to(torch.bfloat16)
-
-     # start for MGIE
-     os.makedirs('_log', exist_ok=True)
-
-     pt = {}
-     for i in tqdm(range(2)): pt.update(torch.load('./_ckpt/LLaVA-7B-v1/pytorch_model-0000%d-of-00002.bin'%(i+1), map_location='cpu'))
-     miss, unexp = model.load_state_dict(pt, strict=False)
-     print('miss:', miss), print('unexp:', unexp)
-
-     tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
-     model.resize_token_embeddings(len(tokenizer))
-     print(tokenizer), json.dump(tokenizer.get_vocab(), open('_log/vocabs.json', 'w'), indent=2)
-
-     for n, p in model.named_parameters():
-         if 'embed_tokens' in n or 'lm_head' in n or 'edit_head' in n or 'unet' in n: p.requires_grad = True
-         else: p.requires_grad = False
-     with open('_log/parameters.txt', 'w') as F:
-         for n, p in model.named_parameters(): F.write('%s %s %s\n'%(n, str(p.shape), str(p.requires_grad)))
-
-     with open('_log/args_train.txt', 'w') as F:
-         for key in vars(training_args): F.write('%s: %s\n'%(str(key), str(vars(training_args)[key])))
-     # end for MGIE
-
-     data_module = make_supervised_data_module(tokenizer=tokenizer,
-                                               data_args=data_args)
-     trainer = LLaVATrainer(model=model,
-                            tokenizer=tokenizer,
-                            args=training_args,
-                            **data_module)
-
-     if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
-         trainer.train(resume_from_checkpoint=True)
-     else:
-         trainer.train()
-     trainer.save_state()
-
-     if training_args.lora_enable:
-         state_dict = get_peft_state_maybe_zero_3(
-             model.named_parameters(), training_args.lora_bias
-         )
-         non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
-             model.named_parameters()
-         )
-         if training_args.local_rank == 0 or training_args.local_rank == -1:
-             model.config.save_pretrained(training_args.output_dir)
-             model.save_pretrained(training_args.output_dir, state_dict=state_dict)
-             torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
-     else:
-         safe_save_model_for_hf_trainer(trainer=trainer,
-                                        output_dir=training_args.output_dir)
-
-
- if __name__ == "__main__":
-     train()
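The MGIE-specific block in train() above freezes the model by name matching: only parameters whose names contain embed_tokens, lm_head, edit_head, or unet stay trainable, everything else is frozen. A self-contained toy illustration of that pattern (the module names here are stand-ins, not the real MGIE layout):

import torch

model = torch.nn.ModuleDict({
    'embed_tokens': torch.nn.Embedding(10, 4),
    'backbone': torch.nn.Linear(4, 4),
    'lm_head': torch.nn.Linear(4, 10),
})
TRAINABLE = ('embed_tokens', 'lm_head', 'edit_head', 'unet')
for n, p in model.named_parameters():
    # trainable only if the parameter name matches one of the whitelisted substrings
    p.requires_grad = any(key in n for key in TRAINABLE)

for n, p in model.named_parameters():
    print(n, p.requires_grad)
# embed_tokens.weight True / backbone.{weight,bias} False / lm_head.{weight,bias} True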