Spaces:
Running
on
Zero
Running
on
Zero
update space
Browse files- oryx/__init__.py +1 -0
- oryx/constants.py +12 -0
- oryx/conversation.py +559 -0
- oryx/mm_utils.py +235 -0
- oryx/model/__init__.py +14 -0
- oryx/model/builder.py +89 -0
- oryx/model/language_model/oryx_llama.py +202 -0
- oryx/model/language_model/oryx_qwen.py +182 -0
- oryx/model/multimodal_encoder/builder.py +12 -0
- oryx/model/multimodal_encoder/oryx_vit.py +844 -0
- oryx/model/multimodal_projector/builder.py +155 -0
- oryx/model/multimodal_resampler/builder.py +36 -0
- oryx/model/multimodal_resampler/masked_drop.py +82 -0
- oryx/model/multimodal_resampler/perceiver.py +70 -0
- oryx/model/multimodal_resampler/qformer.py +1287 -0
- oryx/model/multimodal_resampler/spatial_pool.py +42 -0
- oryx/model/multimodal_resampler/vlm_attention.py +337 -0
- oryx/model/oryx_arch.py +338 -0
- oryx/train/llama_flash_attn_monkey_patch.py +116 -0
- oryx/train/oryx_trainer.py +435 -0
- oryx/train/train.py +1686 -0
- oryx/train/train_mem.py +5 -0
- oryx/utils.py +134 -0
oryx/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import OryxLlamaForCausalLM
|
oryx/constants.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
10 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
11 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
12 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
oryx/conversation.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Any, Dict, Union, Tuple
|
4 |
+
import re
|
5 |
+
import base64
|
6 |
+
from io import BytesIO
|
7 |
+
from PIL import Image
|
8 |
+
from transformers import AutoTokenizer
|
9 |
+
|
10 |
+
import os
|
11 |
+
if 'EVALUATION' in os.environ:
|
12 |
+
# highresxpatch
|
13 |
+
EVALUATION = True
|
14 |
+
print(f"EVALUATION is set")
|
15 |
+
else:
|
16 |
+
EVALUATION = False
|
17 |
+
|
18 |
+
class SeparatorStyle(Enum):
|
19 |
+
"""Different separator style."""
|
20 |
+
|
21 |
+
SINGLE = auto()
|
22 |
+
TWO = auto()
|
23 |
+
MPT = auto()
|
24 |
+
PLAIN = auto()
|
25 |
+
CHATML = auto()
|
26 |
+
LLAMA_2 = auto()
|
27 |
+
LLAMA_3 = auto()
|
28 |
+
QWEN2 = auto()
|
29 |
+
QWEN = auto()
|
30 |
+
|
31 |
+
|
32 |
+
@dataclasses.dataclass
|
33 |
+
class Conversation:
|
34 |
+
"""A class that keeps all conversation history."""
|
35 |
+
|
36 |
+
system: str
|
37 |
+
roles: List[str]
|
38 |
+
messages: List[List[str]]
|
39 |
+
offset: int
|
40 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
41 |
+
sep: str = "###"
|
42 |
+
sep2: str = None
|
43 |
+
version: str = "Unknown"
|
44 |
+
|
45 |
+
tokenizer_id: str = ""
|
46 |
+
tokenizer: Any = None
|
47 |
+
# Stop criteria (the default one is EOS token)
|
48 |
+
stop_str: Union[str, List[str]] = None
|
49 |
+
# Stops generation if meeting any token in this list
|
50 |
+
stop_token_ids: List[int] = None
|
51 |
+
|
52 |
+
skip_next: bool = False
|
53 |
+
|
54 |
+
def get_prompt(self):
|
55 |
+
messages = self.messages
|
56 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
57 |
+
messages = self.messages.copy()
|
58 |
+
init_role, init_msg = messages[0].copy()
|
59 |
+
init_msg = init_msg[0]
|
60 |
+
if "mmtag" in self.version:
|
61 |
+
init_msg = init_msg.replace("<image>", "").strip()
|
62 |
+
messages[0] = (init_role, init_msg)
|
63 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
64 |
+
messages.insert(1, (self.roles[1], "Received."))
|
65 |
+
elif not init_msg.startswith("<image>"):
|
66 |
+
init_msg = init_msg.replace("<image>", "").strip()
|
67 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
68 |
+
else:
|
69 |
+
messages[0] = (init_role, init_msg)
|
70 |
+
|
71 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
72 |
+
ret = self.system + self.sep
|
73 |
+
for role, message in messages:
|
74 |
+
if message:
|
75 |
+
if type(message) is tuple:
|
76 |
+
message, _, _ = message
|
77 |
+
ret += role + ": " + message + self.sep
|
78 |
+
else:
|
79 |
+
ret += role + ":"
|
80 |
+
|
81 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
82 |
+
seps = [self.sep, self.sep2]
|
83 |
+
ret = self.system + seps[0]
|
84 |
+
for i, (role, message) in enumerate(messages):
|
85 |
+
if message:
|
86 |
+
if type(message) is tuple:
|
87 |
+
message, _, _ = message
|
88 |
+
ret += role + ": " + message + seps[i % 2]
|
89 |
+
else:
|
90 |
+
ret += role + ":"
|
91 |
+
|
92 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
93 |
+
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
94 |
+
for role, message in messages:
|
95 |
+
if message:
|
96 |
+
if type(message) is tuple:
|
97 |
+
message, images = message
|
98 |
+
message = "<image>" * len(images) + message
|
99 |
+
ret += role + "\n" + message + self.sep + "\n"
|
100 |
+
else:
|
101 |
+
ret += role + "\n"
|
102 |
+
return ret
|
103 |
+
|
104 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
105 |
+
chat_template_messages = [{"role": "system", "content": self.system}]
|
106 |
+
for role, message in messages:
|
107 |
+
if message:
|
108 |
+
if type(message) is tuple:
|
109 |
+
message, images = message
|
110 |
+
message = "<image>" * len(images) + message
|
111 |
+
chat_template_messages.append({"role": role, "content": message})
|
112 |
+
if EVALUATION:
|
113 |
+
return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
|
114 |
+
else:
|
115 |
+
return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=False)
|
116 |
+
|
117 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
118 |
+
ret = self.system + self.sep
|
119 |
+
for role, message in messages:
|
120 |
+
if message:
|
121 |
+
if type(message) is tuple:
|
122 |
+
message, _, _ = message
|
123 |
+
ret += role + message + self.sep
|
124 |
+
else:
|
125 |
+
ret += role
|
126 |
+
|
127 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
128 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
129 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
130 |
+
ret = ""
|
131 |
+
|
132 |
+
for i, (role, message) in enumerate(messages):
|
133 |
+
if i == 0:
|
134 |
+
assert message, "first message should not be none"
|
135 |
+
assert role == self.roles[0], "first message should come from user"
|
136 |
+
if message:
|
137 |
+
if type(message) is tuple:
|
138 |
+
message, _, _ = message
|
139 |
+
if i == 0:
|
140 |
+
message = wrap_sys(self.system) + message
|
141 |
+
if i % 2 == 0:
|
142 |
+
message = wrap_inst(message)
|
143 |
+
ret += self.sep + message
|
144 |
+
else:
|
145 |
+
ret += " " + message + " " + self.sep2
|
146 |
+
else:
|
147 |
+
ret += ""
|
148 |
+
ret = ret.lstrip(self.sep)
|
149 |
+
|
150 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
151 |
+
seps = [self.sep, self.sep2]
|
152 |
+
ret = self.system
|
153 |
+
for i, (role, message) in enumerate(messages):
|
154 |
+
if message:
|
155 |
+
if type(message) is tuple:
|
156 |
+
message, _, _ = message
|
157 |
+
ret += message + seps[i % 2]
|
158 |
+
else:
|
159 |
+
ret += ""
|
160 |
+
elif self.sep_style == SeparatorStyle.QWEN2:
|
161 |
+
start = '<|im_start|>'
|
162 |
+
end = '<|im_end|>\n'
|
163 |
+
ret = start + 'system\n' + self.system + end
|
164 |
+
for i, (role, message) in enumerate(messages):
|
165 |
+
if message:
|
166 |
+
if type(message) is tuple:
|
167 |
+
message, _, _ = message
|
168 |
+
|
169 |
+
if message.endswith('<|endoftext|>'):
|
170 |
+
message = message.replace('<|endoftext|>', '')
|
171 |
+
ret += start + role + "\n" + message + end + '<|endoftext|>'
|
172 |
+
else:
|
173 |
+
assert not '<|endoftext|>' in message, f"Invalid message: {message}"
|
174 |
+
ret += start + role + "\n" + message + end
|
175 |
+
else:
|
176 |
+
ret += start + role + "\n"
|
177 |
+
else:
|
178 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
179 |
+
|
180 |
+
return ret
|
181 |
+
|
182 |
+
def append_message(self, role, message):
|
183 |
+
self.messages.append([role, message])
|
184 |
+
|
185 |
+
def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
|
186 |
+
if image_process_mode == "Pad":
|
187 |
+
|
188 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
189 |
+
width, height = pil_img.size
|
190 |
+
if width == height:
|
191 |
+
return pil_img
|
192 |
+
elif width > height:
|
193 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
194 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
195 |
+
return result
|
196 |
+
else:
|
197 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
198 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
199 |
+
return result
|
200 |
+
|
201 |
+
image = expand2square(image)
|
202 |
+
elif image_process_mode in ["Default", "Crop"]:
|
203 |
+
pass
|
204 |
+
elif image_process_mode == "Resize":
|
205 |
+
image = image.resize((336, 336))
|
206 |
+
else:
|
207 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
208 |
+
|
209 |
+
if type(image) is not Image.Image:
|
210 |
+
image = Image.open(image).convert("RGB")
|
211 |
+
|
212 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
213 |
+
aspect_ratio = max_hw / min_hw
|
214 |
+
max_len, min_len = 672, 448
|
215 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
216 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
217 |
+
W, H = image.size
|
218 |
+
if H > W:
|
219 |
+
H, W = longest_edge, shortest_edge
|
220 |
+
else:
|
221 |
+
H, W = shortest_edge, longest_edge
|
222 |
+
image = image.resize((W, H))
|
223 |
+
if return_pil:
|
224 |
+
return image
|
225 |
+
else:
|
226 |
+
buffered = BytesIO()
|
227 |
+
image.save(buffered, format=image_format)
|
228 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
229 |
+
return img_b64_str
|
230 |
+
|
231 |
+
def get_images(self, return_pil=False, return_path=False):
|
232 |
+
images = []
|
233 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
234 |
+
if i % 2 == 0:
|
235 |
+
if type(msg) is tuple:
|
236 |
+
msg, image, image_process_mode = msg
|
237 |
+
if type(image) != list:
|
238 |
+
image = [image]
|
239 |
+
for img in image:
|
240 |
+
if not return_path:
|
241 |
+
img = self.process_image(img, image_process_mode, return_pil=return_pil)
|
242 |
+
else:
|
243 |
+
images.append(img)
|
244 |
+
return images
|
245 |
+
|
246 |
+
def to_gradio_chatbot(self):
|
247 |
+
ret = []
|
248 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
249 |
+
if i % 2 == 0:
|
250 |
+
if type(msg) is tuple:
|
251 |
+
msg, image, image_process_mode = msg
|
252 |
+
if type(image) != list:
|
253 |
+
image = [image]
|
254 |
+
if len(image) == 1:
|
255 |
+
msg = "<image>\n" + msg.replace("<image>", "").strip()
|
256 |
+
else:
|
257 |
+
msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
|
258 |
+
for img in image:
|
259 |
+
img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
|
260 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}"/>'
|
261 |
+
msg = msg.replace("<image>", img_str, 1).strip()
|
262 |
+
if len(msg) > 0:
|
263 |
+
ret.append([msg, None])
|
264 |
+
else:
|
265 |
+
ret.append([msg, None])
|
266 |
+
else:
|
267 |
+
ret[-1][-1] = msg
|
268 |
+
return ret
|
269 |
+
|
270 |
+
def copy(self):
|
271 |
+
return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
|
272 |
+
|
273 |
+
def dict(self):
|
274 |
+
if len(self.get_images()) > 0:
|
275 |
+
return {
|
276 |
+
"system": self.system,
|
277 |
+
"roles": self.roles,
|
278 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
279 |
+
"offset": self.offset,
|
280 |
+
"sep": self.sep,
|
281 |
+
"sep2": self.sep2,
|
282 |
+
}
|
283 |
+
return {
|
284 |
+
"system": self.system,
|
285 |
+
"roles": self.roles,
|
286 |
+
"messages": self.messages,
|
287 |
+
"offset": self.offset,
|
288 |
+
"sep": self.sep,
|
289 |
+
"sep2": self.sep2,
|
290 |
+
}
|
291 |
+
|
292 |
+
|
293 |
+
conv_vicuna_v0 = Conversation(
|
294 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
295 |
+
roles=("Human", "Assistant"),
|
296 |
+
messages=[
|
297 |
+
["Human", "What are the key differences between renewable and non-renewable energy sources?"],
|
298 |
+
[
|
299 |
+
"Assistant",
|
300 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
301 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
302 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
303 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
304 |
+
"renewable and non-renewable energy sources:\n"
|
305 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
306 |
+
"energy sources are finite and will eventually run out.\n"
|
307 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
308 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
309 |
+
"and other negative effects.\n"
|
310 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
311 |
+
"have lower operational costs than non-renewable sources.\n"
|
312 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
313 |
+
"locations than non-renewable sources.\n"
|
314 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
315 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
316 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
317 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
|
318 |
+
],
|
319 |
+
],
|
320 |
+
offset=2,
|
321 |
+
sep_style=SeparatorStyle.SINGLE,
|
322 |
+
sep="###",
|
323 |
+
)
|
324 |
+
|
325 |
+
conv_vicuna_v1 = Conversation(
|
326 |
+
system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
327 |
+
roles=("USER", "ASSISTANT"),
|
328 |
+
version="v1",
|
329 |
+
messages=[],
|
330 |
+
offset=0,
|
331 |
+
sep_style=SeparatorStyle.TWO,
|
332 |
+
sep=" ",
|
333 |
+
sep2="</s>",
|
334 |
+
)
|
335 |
+
|
336 |
+
conv_qwen_v1 = Conversation(
|
337 |
+
system="You are a helpful assistant.",
|
338 |
+
roles=("user", "assistant"),
|
339 |
+
version="v1",
|
340 |
+
messages=(),
|
341 |
+
offset=0,
|
342 |
+
sep_style=SeparatorStyle.QWEN2,
|
343 |
+
)
|
344 |
+
|
345 |
+
conv_llama_2 = Conversation(
|
346 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
347 |
+
|
348 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
349 |
+
roles=("USER", "ASSISTANT"),
|
350 |
+
version="llama_v2",
|
351 |
+
messages=[],
|
352 |
+
offset=0,
|
353 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
354 |
+
sep="<s>",
|
355 |
+
sep2="</s>",
|
356 |
+
)
|
357 |
+
|
358 |
+
conv_llava_llama_2 = Conversation(
|
359 |
+
system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
360 |
+
roles=("USER", "ASSISTANT"),
|
361 |
+
version="llama_v2",
|
362 |
+
messages=[],
|
363 |
+
offset=0,
|
364 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
365 |
+
sep="<s>",
|
366 |
+
sep2="</s>",
|
367 |
+
)
|
368 |
+
|
369 |
+
conv_llava_llama_3 = Conversation(
|
370 |
+
system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
371 |
+
roles=("user", "assistant"),
|
372 |
+
version="llama_v3",
|
373 |
+
messages=[],
|
374 |
+
offset=0,
|
375 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
376 |
+
tokenizer=AutoTokenizer.from_pretrained("/apdcephfs_jn/share_302244400/peterrao/nj3/models/Llama-3-8B-Instruct"),
|
377 |
+
stop_token_ids=[128009],
|
378 |
+
)
|
379 |
+
|
380 |
+
conv_mistral_instruct = Conversation(
|
381 |
+
system="",
|
382 |
+
roles=("USER", "ASSISTANT"),
|
383 |
+
version="llama_v2",
|
384 |
+
messages=[],
|
385 |
+
offset=0,
|
386 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
387 |
+
sep="",
|
388 |
+
sep2="</s>",
|
389 |
+
)
|
390 |
+
|
391 |
+
conv_llava_llama_2_simple = Conversation(
|
392 |
+
system="Answer the questions about the visual content that the user provides.",
|
393 |
+
roles=("USER", "ASSISTANT"),
|
394 |
+
version="llama_v2",
|
395 |
+
messages=[],
|
396 |
+
offset=0,
|
397 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
398 |
+
sep="<s>",
|
399 |
+
sep2="</s>",
|
400 |
+
)
|
401 |
+
|
402 |
+
conv_llava_llama_2_mmtag = Conversation(
|
403 |
+
system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
|
404 |
+
roles=("USER", "ASSISTANT"),
|
405 |
+
version="llama_v2_mmtag",
|
406 |
+
messages=[],
|
407 |
+
offset=0,
|
408 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
409 |
+
sep="<s>",
|
410 |
+
sep2="</s>",
|
411 |
+
)
|
412 |
+
|
413 |
+
conv_mpt = Conversation(
|
414 |
+
system="""<|im_start|>system
|
415 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
416 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
417 |
+
version="mpt",
|
418 |
+
messages=[],
|
419 |
+
offset=0,
|
420 |
+
sep_style=SeparatorStyle.MPT,
|
421 |
+
sep="<|im_end|>",
|
422 |
+
)
|
423 |
+
|
424 |
+
conv_qwen = Conversation(
|
425 |
+
system="""<|im_start|>system
|
426 |
+
You are a helpful assistant.""",
|
427 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
428 |
+
version="qwen",
|
429 |
+
messages=[],
|
430 |
+
offset=0,
|
431 |
+
sep_style=SeparatorStyle.CHATML,
|
432 |
+
sep="<|im_end|>",
|
433 |
+
)
|
434 |
+
|
435 |
+
conv_llava_plain = Conversation(
|
436 |
+
system="",
|
437 |
+
roles=("", ""),
|
438 |
+
messages=[],
|
439 |
+
offset=0,
|
440 |
+
sep_style=SeparatorStyle.PLAIN,
|
441 |
+
sep="\n",
|
442 |
+
)
|
443 |
+
|
444 |
+
conv_llava_v0 = Conversation(
|
445 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
446 |
+
roles=("Human", "Assistant"),
|
447 |
+
messages=[],
|
448 |
+
offset=0,
|
449 |
+
sep_style=SeparatorStyle.SINGLE,
|
450 |
+
sep="###",
|
451 |
+
)
|
452 |
+
|
453 |
+
conv_llava_v0_mmtag = Conversation(
|
454 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
455 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
456 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
457 |
+
roles=("Human", "Assistant"),
|
458 |
+
messages=[],
|
459 |
+
offset=0,
|
460 |
+
sep_style=SeparatorStyle.SINGLE,
|
461 |
+
sep="###",
|
462 |
+
version="v0_mmtag",
|
463 |
+
)
|
464 |
+
|
465 |
+
conv_llava_v1 = Conversation(
|
466 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
467 |
+
roles=("USER", "ASSISTANT"),
|
468 |
+
version="v1",
|
469 |
+
messages=[],
|
470 |
+
offset=0,
|
471 |
+
sep_style=SeparatorStyle.TWO,
|
472 |
+
sep=" ",
|
473 |
+
sep2="</s>",
|
474 |
+
)
|
475 |
+
|
476 |
+
conv_llava_v1_mmtag = Conversation(
|
477 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
478 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
479 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
480 |
+
roles=("USER", "ASSISTANT"),
|
481 |
+
messages=[],
|
482 |
+
offset=0,
|
483 |
+
sep_style=SeparatorStyle.TWO,
|
484 |
+
sep=" ",
|
485 |
+
sep2="</s>",
|
486 |
+
version="v1_mmtag",
|
487 |
+
)
|
488 |
+
|
489 |
+
conv_mistral_orca = Conversation(
|
490 |
+
system="""<|im_start|>system
|
491 |
+
You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
|
492 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
493 |
+
version="mpt",
|
494 |
+
messages=[],
|
495 |
+
offset=0,
|
496 |
+
sep_style=SeparatorStyle.MPT,
|
497 |
+
sep="<|im_end|>",
|
498 |
+
)
|
499 |
+
|
500 |
+
conv_mistral_zephyr = Conversation(
|
501 |
+
system="""<|system|>
|
502 |
+
You are a helpful AI assistant.""",
|
503 |
+
roles=("<|user|>\n", "<|assistant|>\n"),
|
504 |
+
version="mpt",
|
505 |
+
messages=[],
|
506 |
+
offset=0,
|
507 |
+
sep_style=SeparatorStyle.MPT,
|
508 |
+
sep="</s>",
|
509 |
+
)
|
510 |
+
|
511 |
+
conv_mistral_direct = Conversation(
|
512 |
+
system="""<|im_start|>system
|
513 |
+
Answer the questions.""",
|
514 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
515 |
+
version="mpt",
|
516 |
+
messages=[],
|
517 |
+
offset=0,
|
518 |
+
sep_style=SeparatorStyle.MPT,
|
519 |
+
sep="<|im_end|>",
|
520 |
+
)
|
521 |
+
|
522 |
+
conv_chatml_direct = Conversation(
|
523 |
+
system="""<|im_start|>system
|
524 |
+
Answer the questions.""",
|
525 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
526 |
+
version="mpt",
|
527 |
+
messages=[],
|
528 |
+
offset=0,
|
529 |
+
sep_style=SeparatorStyle.MPT,
|
530 |
+
sep="<|im_end|>",
|
531 |
+
)
|
532 |
+
|
533 |
+
default_conversation = conv_vicuna_v0
|
534 |
+
conv_templates = {
|
535 |
+
"default": conv_vicuna_v0,
|
536 |
+
"v0": conv_vicuna_v0,
|
537 |
+
"v1": conv_vicuna_v1,
|
538 |
+
"vicuna_v1": conv_vicuna_v1,
|
539 |
+
'v1_qwen2': conv_qwen_v1,
|
540 |
+
"llama_2": conv_llama_2,
|
541 |
+
"mistral_instruct": conv_mistral_instruct,
|
542 |
+
"mistral_orca": conv_mistral_orca,
|
543 |
+
"mistral_zephyr": conv_mistral_zephyr,
|
544 |
+
"mistral_direct": conv_mistral_direct,
|
545 |
+
"plain": conv_llava_plain,
|
546 |
+
"v0_plain": conv_llava_plain,
|
547 |
+
"chatml_direct": conv_chatml_direct,
|
548 |
+
"llava_v0": conv_llava_v0,
|
549 |
+
"llava_v0_mmtag": conv_llava_v0_mmtag,
|
550 |
+
"llava_v1": conv_llava_v1,
|
551 |
+
"llava_v1_mmtag": conv_llava_v1_mmtag,
|
552 |
+
"llava_llama_2": conv_llava_llama_2,
|
553 |
+
"llava_llama_3": conv_llava_llama_3,
|
554 |
+
"llava_llama_2_simple": conv_llava_llama_2_simple,
|
555 |
+
"llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
|
556 |
+
"llava_mistral_instruct": conv_mistral_instruct,
|
557 |
+
"mpt": conv_mpt,
|
558 |
+
"qwen_1_5": conv_qwen,
|
559 |
+
}
|
oryx/mm_utils.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
import math
|
5 |
+
import ast
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from transformers import StoppingCriteria
|
9 |
+
from oryx.constants import IMAGE_TOKEN_INDEX
|
10 |
+
import os
|
11 |
+
|
12 |
+
video_base = 0
|
13 |
+
video_ps = 64
|
14 |
+
highres_base = 0
|
15 |
+
highres_ps = 32
|
16 |
+
MAXRES = 1536
|
17 |
+
MINRES = 0
|
18 |
+
VIDEO_MAXRES = 480
|
19 |
+
VIDEO_MINRES = 288
|
20 |
+
LOWRES_RESIZE = (384,32)
|
21 |
+
PAD2STRIDE=False
|
22 |
+
|
23 |
+
def pad_image(image, target_resolution, value=0):
|
24 |
+
"""
|
25 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
image (PIL.Image.Image): The input image.
|
29 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
PIL.Image.Image: The resized and padded image.
|
33 |
+
"""
|
34 |
+
original_width, original_height = image.size
|
35 |
+
target_width, target_height = target_resolution
|
36 |
+
# Create a new image with the target size and paste the resized image onto it
|
37 |
+
new_image = Image.new('RGB', (target_width, target_height), (value, value, value))
|
38 |
+
paste_x = (target_width - original_width) // 2
|
39 |
+
paste_y = (target_height - original_height) // 2
|
40 |
+
new_image.paste(image, (paste_x, paste_y))
|
41 |
+
return new_image
|
42 |
+
|
43 |
+
def resize_images(image, patch_size=14, base_size=896):
|
44 |
+
h, w = image.size
|
45 |
+
if base_size == 0:
|
46 |
+
if h * w > MAXRES * MAXRES:
|
47 |
+
# print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
|
48 |
+
scale = MAXRES * MAXRES / (h * w)
|
49 |
+
scale = math.sqrt(scale)
|
50 |
+
elif h * w < MINRES * MINRES:
|
51 |
+
# print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
|
52 |
+
scale = MINRES * MINRES / (h * w)
|
53 |
+
scale = math.sqrt(scale)
|
54 |
+
else:
|
55 |
+
scale = None
|
56 |
+
else:
|
57 |
+
scale = base_size * base_size / (h * w)
|
58 |
+
scale = math.sqrt(scale)
|
59 |
+
|
60 |
+
|
61 |
+
if scale is not None:
|
62 |
+
new_h = int(h * scale / patch_size) * patch_size
|
63 |
+
new_w = int(w * scale / patch_size) * patch_size
|
64 |
+
image = image.resize((new_h, new_w))
|
65 |
+
elif PAD2STRIDE:
|
66 |
+
if h % patch_size == 0:
|
67 |
+
new_h = h
|
68 |
+
else:
|
69 |
+
new_h = (h // patch_size + 1) * patch_size
|
70 |
+
|
71 |
+
if w % patch_size == 0:
|
72 |
+
new_w = w
|
73 |
+
else:
|
74 |
+
new_w = (w // patch_size + 1) * patch_size
|
75 |
+
image = pad_image(image, (new_h, new_w), value=127)
|
76 |
+
else:
|
77 |
+
scale = 1.0
|
78 |
+
new_h = int(h * scale / patch_size) * patch_size
|
79 |
+
new_w = int(w * scale / patch_size) * patch_size
|
80 |
+
image = image.resize((new_h, new_w))
|
81 |
+
|
82 |
+
return image
|
83 |
+
|
84 |
+
def resize_video(image, patch_size=14, base_size=896):
|
85 |
+
h, w = image.size
|
86 |
+
if base_size == 0:
|
87 |
+
if h * w > VIDEO_MAXRES * VIDEO_MAXRES:
|
88 |
+
# print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
|
89 |
+
scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w)
|
90 |
+
scale = math.sqrt(scale)
|
91 |
+
elif h * w < VIDEO_MINRES * VIDEO_MINRES:
|
92 |
+
# print(f'{h}x{w} smaller than max size {MINRES}, resize to {MINRES}')
|
93 |
+
scale = VIDEO_MINRES * VIDEO_MINRES / (h * w)
|
94 |
+
scale = math.sqrt(scale)
|
95 |
+
else:
|
96 |
+
scale = None
|
97 |
+
else:
|
98 |
+
scale = base_size * base_size / (h * w)
|
99 |
+
scale = math.sqrt(scale)
|
100 |
+
|
101 |
+
if scale is not None:
|
102 |
+
new_h = int(h * scale / patch_size) * patch_size
|
103 |
+
new_w = int(w * scale / patch_size) * patch_size
|
104 |
+
image = image.resize((new_h, new_w))
|
105 |
+
elif PAD2STRIDE:
|
106 |
+
if h % patch_size == 0:
|
107 |
+
new_h = h
|
108 |
+
else:
|
109 |
+
new_h = (h // patch_size + 1) * patch_size
|
110 |
+
|
111 |
+
if w % patch_size == 0:
|
112 |
+
new_w = w
|
113 |
+
else:
|
114 |
+
new_w = (w // patch_size + 1) * patch_size
|
115 |
+
image = pad_image(image, (new_h, new_w), value=127)
|
116 |
+
else:
|
117 |
+
scale = 1.0
|
118 |
+
new_h = int(h * scale / patch_size) * patch_size
|
119 |
+
new_w = int(w * scale / patch_size) * patch_size
|
120 |
+
image = image.resize((new_h, new_w))
|
121 |
+
|
122 |
+
return image
|
123 |
+
|
124 |
+
def process_anyres_video_genli(image, processor):
|
125 |
+
image = resize_video(image, patch_size=video_ps, base_size=video_base)
|
126 |
+
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
127 |
+
return image.unsqueeze(0)
|
128 |
+
|
129 |
+
def process_anyres_video_genli_long(image, processor):
|
130 |
+
image = resize_video(image, patch_size=video_ps * 2, base_size=video_base)
|
131 |
+
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
132 |
+
return image.unsqueeze(0)
|
133 |
+
|
134 |
+
def load_image_from_base64(image):
|
135 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
136 |
+
|
137 |
+
def process_anyres_highres_image_genli(image, processor):
|
138 |
+
h, w = image.size
|
139 |
+
if h < 32 and w < 32:
|
140 |
+
min_size = min(h, w)
|
141 |
+
ratio = 64 / min_size
|
142 |
+
image = image.resize((int(h * ratio), int(w * ratio)))
|
143 |
+
elif h < 32:
|
144 |
+
ratio = 64 / h
|
145 |
+
image = image.resize((int(h * ratio), int(w * ratio)))
|
146 |
+
elif w < 32:
|
147 |
+
ratio = 64 / w
|
148 |
+
image = image.resize((int(h * ratio), int(w * ratio)))
|
149 |
+
|
150 |
+
image = resize_images(image, patch_size=highres_ps, base_size=highres_base)
|
151 |
+
|
152 |
+
image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0])
|
153 |
+
|
154 |
+
# image_patches = [image_original_resize] + [image_original_resize]
|
155 |
+
# image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
|
156 |
+
# for image_patch in image_patches]
|
157 |
+
image_patches = processor.preprocess(image_original_resize, return_tensors='pt')['pixel_values'][0]
|
158 |
+
image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
159 |
+
# return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0)
|
160 |
+
return image_patches.unsqueeze(0), image_padded.unsqueeze(0)
|
161 |
+
|
162 |
+
|
163 |
+
def read_image_patch(patch_info):
|
164 |
+
if 'img_path' in patch_info.keys():
|
165 |
+
image = Image.open(patch_info['img_path']).convert('RGB')
|
166 |
+
else:
|
167 |
+
if 'image_encoing' in patch_info.keys():
|
168 |
+
patch_info['image_encoding'] = patch_info['image_encoing']
|
169 |
+
image_file_name = patch_info['patch']
|
170 |
+
start_bytes = int(patch_info['start_num'])
|
171 |
+
file_size = int(patch_info['size'])
|
172 |
+
|
173 |
+
with open(image_file_name, 'rb') as f:
|
174 |
+
f.seek(start_bytes)
|
175 |
+
if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
|
176 |
+
image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
|
177 |
+
else:
|
178 |
+
image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB")
|
179 |
+
return image
|
180 |
+
|
181 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
182 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
183 |
+
|
184 |
+
def insert_separator(X, sep):
|
185 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
186 |
+
|
187 |
+
input_ids = []
|
188 |
+
offset = 0
|
189 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
190 |
+
offset = 1
|
191 |
+
input_ids.append(prompt_chunks[0][0])
|
192 |
+
|
193 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
194 |
+
input_ids.extend(x[offset:])
|
195 |
+
|
196 |
+
if return_tensors is not None:
|
197 |
+
if return_tensors == 'pt':
|
198 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
199 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
200 |
+
return input_ids
|
201 |
+
|
202 |
+
|
203 |
+
def get_model_name_from_path(model_path):
|
204 |
+
model_path = model_path.strip("/")
|
205 |
+
model_paths = model_path.split("/")
|
206 |
+
if model_paths[-1].startswith('checkpoint-'):
|
207 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
208 |
+
else:
|
209 |
+
return model_paths[-1]
|
210 |
+
|
211 |
+
|
212 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
213 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
214 |
+
self.keywords = keywords
|
215 |
+
self.keyword_ids = []
|
216 |
+
for keyword in keywords:
|
217 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
218 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
219 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
220 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
221 |
+
self.tokenizer = tokenizer
|
222 |
+
self.start_len = input_ids.shape[1]
|
223 |
+
|
224 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
225 |
+
assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
|
226 |
+
offset = min(output_ids.shape[1] - self.start_len, 3)
|
227 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
228 |
+
for keyword_id in self.keyword_ids:
|
229 |
+
if output_ids[0, -keyword_id.shape[0]:] == keyword_id:
|
230 |
+
return True
|
231 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
232 |
+
for keyword in self.keywords:
|
233 |
+
if keyword in outputs:
|
234 |
+
return True
|
235 |
+
return False
|
oryx/model/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
AVAILABLE_MODELS = {
|
4 |
+
"oryx_llama": "OryxLlamaForCausalLM, OryxConfig",
|
5 |
+
"oryx_qwen": "OryxQwenForCausalLM, OryxQwenConfig",
|
6 |
+
# Add other models as needed
|
7 |
+
}
|
8 |
+
|
9 |
+
for model_name, model_classes in AVAILABLE_MODELS.items():
|
10 |
+
try:
|
11 |
+
exec(f"from .language_model.{model_name} import {model_classes}")
|
12 |
+
except Exception as e:
|
13 |
+
raise e
|
14 |
+
print(f"Failed to import {model_name} from llava.language_model.{model_name}")
|
oryx/model/builder.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
6 |
+
import torch
|
7 |
+
from oryx.model import *
|
8 |
+
from oryx.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
9 |
+
|
10 |
+
|
11 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", overwrite_config=None):
|
12 |
+
kwargs = {"device_map": device_map}
|
13 |
+
|
14 |
+
if load_8bit:
|
15 |
+
kwargs["load_in_8bit"] = True
|
16 |
+
elif load_4bit:
|
17 |
+
kwargs["load_in_4bit"] = True
|
18 |
+
kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
|
19 |
+
else:
|
20 |
+
kwargs["torch_dtype"] = torch.bfloat16
|
21 |
+
|
22 |
+
if "oryx" in model_name.lower():
|
23 |
+
# Load Oryx model
|
24 |
+
if "7b" in model_name.lower():
|
25 |
+
from oryx.model.language_model.oryx_qwen import OryxQwenConfig
|
26 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
27 |
+
if overwrite_config is not None:
|
28 |
+
cfg_pretrained = OryxQwenConfig.from_pretrained(model_path)
|
29 |
+
print(f"Overwriting config with {overwrite_config}")
|
30 |
+
for k, v in overwrite_config.items():
|
31 |
+
setattr(cfg_pretrained, k, v)
|
32 |
+
model = OryxQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
33 |
+
else:
|
34 |
+
model = OryxQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
35 |
+
else:
|
36 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
37 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_path)
|
38 |
+
if overwrite_config is not None:
|
39 |
+
print(f"Overwriting config with {overwrite_config}")
|
40 |
+
for k, v in overwrite_config.items():
|
41 |
+
setattr(cfg_pretrained, k, v)
|
42 |
+
model = OryxLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
|
43 |
+
|
44 |
+
else:
|
45 |
+
# Load language model
|
46 |
+
if model_base is not None:
|
47 |
+
# PEFT model
|
48 |
+
from peft import PeftModel
|
49 |
+
|
50 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
51 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
|
52 |
+
print(f"Loading LoRA weights from {model_path}")
|
53 |
+
model = PeftModel.from_pretrained(model, model_path)
|
54 |
+
print(f"Merging weights")
|
55 |
+
model = model.merge_and_unload()
|
56 |
+
print("Convert to FP16...")
|
57 |
+
model.to(torch.bfloat16)
|
58 |
+
else:
|
59 |
+
use_fast = False
|
60 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
61 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
62 |
+
|
63 |
+
image_processor = None
|
64 |
+
|
65 |
+
assert "oryx" in model_name.lower(), "Only Oryx models are supported for video chatbot."
|
66 |
+
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
67 |
+
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
68 |
+
if mm_use_im_patch_token:
|
69 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
70 |
+
if mm_use_im_start_end:
|
71 |
+
tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
72 |
+
model.resize_token_embeddings(len(tokenizer))
|
73 |
+
|
74 |
+
vision_tower = model.get_vision_tower()
|
75 |
+
print("Loading vision tower...")
|
76 |
+
if not vision_tower.is_loaded:
|
77 |
+
vision_tower.load_model(device_map=device_map)
|
78 |
+
if device_map != "auto":
|
79 |
+
vision_tower.to(device="cuda", dtype=torch.bfloat16)
|
80 |
+
else:
|
81 |
+
vision_tower.to(device="cuda:0", dtype=torch.bfloat16)
|
82 |
+
image_processor = vision_tower.image_processor
|
83 |
+
print("Loading vision tower succeeded.")
|
84 |
+
if hasattr(model.config, "max_sequence_length"):
|
85 |
+
context_len = model.config.max_sequence_length
|
86 |
+
else:
|
87 |
+
context_len = 2048
|
88 |
+
|
89 |
+
return tokenizer, model, image_processor, context_len
|
oryx/model/language_model/oryx_llama.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from transformers import AutoConfig, AutoModelForCausalLM, \
|
7 |
+
LlamaConfig, LlamaModel, LlamaForCausalLM
|
8 |
+
|
9 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
10 |
+
from transformers.generation.utils import GenerateOutput
|
11 |
+
|
12 |
+
from oryx.model.oryx_arch import OryxMetaModel, OryxMetaForCausalLM
|
13 |
+
|
14 |
+
|
15 |
+
class OryxConfig(LlamaConfig):
|
16 |
+
model_type = "oryx_llama"
|
17 |
+
|
18 |
+
|
19 |
+
class OryxLlamaModel(OryxMetaModel, LlamaModel):
|
20 |
+
config_class = OryxConfig
|
21 |
+
|
22 |
+
def __init__(self, config: LlamaConfig):
|
23 |
+
super(OryxLlamaModel, self).__init__(config)
|
24 |
+
|
25 |
+
|
26 |
+
class OryxLlamaForCausalLM(LlamaForCausalLM, OryxMetaForCausalLM):
|
27 |
+
config_class = OryxConfig
|
28 |
+
|
29 |
+
def __init__(self, config):
|
30 |
+
LlamaForCausalLM.__init__(self, config)
|
31 |
+
self.model = OryxLlamaModel(config)
|
32 |
+
|
33 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
34 |
+
|
35 |
+
# Initialize weights and apply final processing
|
36 |
+
self.post_init()
|
37 |
+
|
38 |
+
def get_model(self):
|
39 |
+
return self.model
|
40 |
+
|
41 |
+
def forward(
|
42 |
+
self,
|
43 |
+
input_ids: torch.LongTensor = None,
|
44 |
+
attention_mask: Optional[torch.Tensor] = None,
|
45 |
+
position_ids: Optional[torch.LongTensor] = None,
|
46 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
47 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
48 |
+
labels: Optional[torch.LongTensor] = None,
|
49 |
+
use_cache: Optional[bool] = None,
|
50 |
+
output_attentions: Optional[bool] = None,
|
51 |
+
output_hidden_states: Optional[bool] = None,
|
52 |
+
images: Optional[torch.FloatTensor] = None,
|
53 |
+
images_highres: Optional[List[torch.FloatTensor]] = None,
|
54 |
+
image_sizes: Optional[List[List[int]]] = None,
|
55 |
+
return_dict: Optional[bool] = None,
|
56 |
+
modalities: Optional[List[str]] = ["image"],
|
57 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
58 |
+
|
59 |
+
|
60 |
+
if inputs_embeds is None:
|
61 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images,
|
62 |
+
modalities, image_sizes, images_highres)
|
63 |
+
|
64 |
+
if labels is None:
|
65 |
+
return super().forward(
|
66 |
+
input_ids=input_ids,
|
67 |
+
attention_mask=attention_mask,
|
68 |
+
position_ids=position_ids,
|
69 |
+
past_key_values=past_key_values,
|
70 |
+
inputs_embeds=inputs_embeds,
|
71 |
+
use_cache=use_cache,
|
72 |
+
output_attentions=output_attentions,
|
73 |
+
output_hidden_states=output_hidden_states,
|
74 |
+
return_dict=return_dict
|
75 |
+
)
|
76 |
+
else:
|
77 |
+
return self.forward_llm_efficient(
|
78 |
+
input_ids=input_ids,
|
79 |
+
attention_mask=attention_mask,
|
80 |
+
position_ids=position_ids,
|
81 |
+
past_key_values=past_key_values,
|
82 |
+
inputs_embeds=inputs_embeds,
|
83 |
+
labels=labels,
|
84 |
+
use_cache=use_cache,
|
85 |
+
output_attentions=output_attentions,
|
86 |
+
output_hidden_states=output_hidden_states,
|
87 |
+
return_dict=return_dict
|
88 |
+
)
|
89 |
+
|
90 |
+
def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
|
91 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
92 |
+
output_hidden_states = (
|
93 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
94 |
+
)
|
95 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
96 |
+
|
97 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
98 |
+
outputs = self.model(
|
99 |
+
input_ids=input_ids,
|
100 |
+
attention_mask=attention_mask,
|
101 |
+
position_ids=position_ids,
|
102 |
+
past_key_values=past_key_values,
|
103 |
+
inputs_embeds=inputs_embeds,
|
104 |
+
use_cache=use_cache,
|
105 |
+
output_attentions=output_attentions,
|
106 |
+
output_hidden_states=output_hidden_states,
|
107 |
+
return_dict=return_dict,
|
108 |
+
)
|
109 |
+
|
110 |
+
hidden_states = outputs[0]
|
111 |
+
hidden_dim = hidden_states.size(-1)
|
112 |
+
shift_labels = labels[..., 1:].contiguous().reshape(-1)
|
113 |
+
shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
|
114 |
+
assert shift_labels.size(0) == shift_hidden_states.size(0)
|
115 |
+
mask = shift_labels > -1
|
116 |
+
seen_tokens = mask.float().sum().item()
|
117 |
+
if not seen_tokens > 0:
|
118 |
+
logits = self.lm_head(shift_hidden_states[0:2])
|
119 |
+
loss = logits.sum() * 0
|
120 |
+
print("No tokens seen")
|
121 |
+
print(shift_labels)
|
122 |
+
else:
|
123 |
+
shift_labels = shift_labels[mask]
|
124 |
+
shift_hidden_states = shift_hidden_states[mask, :]
|
125 |
+
logits = self.lm_head(shift_hidden_states)
|
126 |
+
logits = logits.float()
|
127 |
+
loss_fct = nn.CrossEntropyLoss()
|
128 |
+
loss = loss_fct(logits, shift_labels)
|
129 |
+
|
130 |
+
|
131 |
+
if not return_dict:
|
132 |
+
output = (logits,) + outputs[1:]
|
133 |
+
return (loss,) + output if loss is not None else output
|
134 |
+
|
135 |
+
return CausalLMOutputWithPast(
|
136 |
+
loss=loss,
|
137 |
+
logits=logits,
|
138 |
+
past_key_values=outputs.past_key_values,
|
139 |
+
hidden_states=outputs.hidden_states,
|
140 |
+
attentions=outputs.attentions,
|
141 |
+
)
|
142 |
+
|
143 |
+
@torch.no_grad()
|
144 |
+
def generate(
|
145 |
+
self,
|
146 |
+
inputs: Optional[torch.Tensor] = None,
|
147 |
+
images: Optional[torch.Tensor] = None,
|
148 |
+
image_sizes: Optional[torch.Tensor] = None,
|
149 |
+
**kwargs,
|
150 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
151 |
+
modalities = kwargs.pop("modalities", None)
|
152 |
+
position_ids = kwargs.pop("position_ids", None)
|
153 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
154 |
+
if "inputs_embeds" in kwargs:
|
155 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
156 |
+
|
157 |
+
if images is not None:
|
158 |
+
(
|
159 |
+
inputs,
|
160 |
+
position_ids,
|
161 |
+
attention_mask,
|
162 |
+
_,
|
163 |
+
inputs_embeds,
|
164 |
+
_
|
165 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
166 |
+
inputs,
|
167 |
+
position_ids,
|
168 |
+
attention_mask,
|
169 |
+
None,
|
170 |
+
None,
|
171 |
+
images,
|
172 |
+
modalities,
|
173 |
+
image_sizes=image_sizes
|
174 |
+
)
|
175 |
+
else:
|
176 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
177 |
+
|
178 |
+
return super().generate(
|
179 |
+
position_ids=position_ids,
|
180 |
+
attention_mask=attention_mask,
|
181 |
+
inputs_embeds=inputs_embeds,
|
182 |
+
**kwargs
|
183 |
+
)
|
184 |
+
|
185 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
|
186 |
+
inputs_embeds=None, **kwargs):
|
187 |
+
images = kwargs.pop("images", None)
|
188 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
189 |
+
inputs = super().prepare_inputs_for_generation(
|
190 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
191 |
+
)
|
192 |
+
if images is not None:
|
193 |
+
inputs['images'] = images
|
194 |
+
if image_sizes is not None:
|
195 |
+
inputs['image_sizes'] = image_sizes
|
196 |
+
return inputs
|
197 |
+
|
198 |
+
if OryxConfig.model_type == "oryx":
|
199 |
+
OryxConfig.model_type = "oryx_llama" # directly set to Oryx_dev to avoid conflict with HF's Oryx
|
200 |
+
|
201 |
+
AutoConfig.register("oryx_llama", OryxConfig)
|
202 |
+
AutoModelForCausalLM.register(OryxConfig, OryxLlamaForCausalLM)
|
oryx/model/language_model/oryx_qwen.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import List, Optional, Tuple, Union, Dict
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
import transformers
|
8 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
9 |
+
|
10 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
11 |
+
from transformers.generation.utils import GenerateOutput
|
12 |
+
|
13 |
+
from oryx.model.oryx_arch import OryxMetaModel, OryxMetaForCausalLM
|
14 |
+
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
|
15 |
+
|
16 |
+
class OryxQwenConfig(Qwen2Config):
|
17 |
+
model_type = "oryx_qwen"
|
18 |
+
|
19 |
+
|
20 |
+
class OryxQwenModel(OryxMetaModel, Qwen2Model):
|
21 |
+
config_class = OryxQwenConfig
|
22 |
+
|
23 |
+
def __init__(self, config: Qwen2Config):
|
24 |
+
super(OryxQwenModel, self).__init__(config)
|
25 |
+
|
26 |
+
|
27 |
+
class OryxQwenForCausalLM(Qwen2ForCausalLM, OryxMetaForCausalLM):
|
28 |
+
config_class = OryxQwenConfig
|
29 |
+
|
30 |
+
def __init__(self, config):
|
31 |
+
# super(Qwen2ForCausalLM, self).__init__(config)
|
32 |
+
Qwen2ForCausalLM.__init__(self, config)
|
33 |
+
config.model_type = "oryx_qwen"
|
34 |
+
config.rope_scaling = None
|
35 |
+
|
36 |
+
self.model = OryxQwenModel(config)
|
37 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
38 |
+
# Initialize weights and apply final processing
|
39 |
+
self.post_init()
|
40 |
+
|
41 |
+
def get_model(self):
|
42 |
+
return self.model
|
43 |
+
|
44 |
+
def forward(
|
45 |
+
self,
|
46 |
+
input_ids: torch.LongTensor = None,
|
47 |
+
attention_mask: Optional[torch.Tensor] = None,
|
48 |
+
position_ids: Optional[torch.LongTensor] = None,
|
49 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
50 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
51 |
+
labels: Optional[torch.LongTensor] = None,
|
52 |
+
use_cache: Optional[bool] = None,
|
53 |
+
output_attentions: Optional[bool] = None,
|
54 |
+
output_hidden_states: Optional[bool] = None,
|
55 |
+
images: Optional[torch.FloatTensor] = None,
|
56 |
+
images_highres: Optional[List[torch.FloatTensor]] = None,
|
57 |
+
image_sizes: Optional[List[List[int]]] = None,
|
58 |
+
return_dict: Optional[bool] = None,
|
59 |
+
modalities: Optional[List[str]] = ["image"],
|
60 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
61 |
+
|
62 |
+
if inputs_embeds is None:
|
63 |
+
(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images,
|
64 |
+
modalities, image_sizes, images_highres)
|
65 |
+
if labels is None:
|
66 |
+
return super().forward(
|
67 |
+
input_ids=input_ids,
|
68 |
+
attention_mask=attention_mask,
|
69 |
+
position_ids=position_ids,
|
70 |
+
past_key_values=past_key_values,
|
71 |
+
inputs_embeds=inputs_embeds,
|
72 |
+
use_cache=use_cache,
|
73 |
+
output_attentions=output_attentions,
|
74 |
+
output_hidden_states=output_hidden_states,
|
75 |
+
return_dict=return_dict
|
76 |
+
)
|
77 |
+
else:
|
78 |
+
return self.forward_llm_efficient(
|
79 |
+
input_ids=input_ids,
|
80 |
+
attention_mask=attention_mask,
|
81 |
+
position_ids=position_ids,
|
82 |
+
past_key_values=past_key_values,
|
83 |
+
inputs_embeds=inputs_embeds,
|
84 |
+
labels=labels,
|
85 |
+
use_cache=use_cache,
|
86 |
+
output_attentions=output_attentions,
|
87 |
+
output_hidden_states=output_hidden_states,
|
88 |
+
return_dict=return_dict
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
|
92 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
93 |
+
output_hidden_states = (
|
94 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
95 |
+
)
|
96 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
97 |
+
|
98 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
99 |
+
outputs = self.model(
|
100 |
+
input_ids=input_ids,
|
101 |
+
attention_mask=attention_mask,
|
102 |
+
position_ids=position_ids,
|
103 |
+
past_key_values=past_key_values,
|
104 |
+
inputs_embeds=inputs_embeds,
|
105 |
+
use_cache=use_cache,
|
106 |
+
output_attentions=output_attentions,
|
107 |
+
output_hidden_states=output_hidden_states,
|
108 |
+
return_dict=return_dict,
|
109 |
+
)
|
110 |
+
|
111 |
+
hidden_states = outputs[0]
|
112 |
+
hidden_dim = hidden_states.size(-1)
|
113 |
+
shift_labels = labels[..., 1:].contiguous().reshape(-1)
|
114 |
+
shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
|
115 |
+
assert shift_labels.size(0) == shift_hidden_states.size(0)
|
116 |
+
mask = shift_labels > -1
|
117 |
+
assert mask.float().sum() > 0
|
118 |
+
shift_labels = shift_labels[mask]
|
119 |
+
shift_hidden_states = shift_hidden_states[mask, :]
|
120 |
+
logits = self.lm_head(shift_hidden_states)
|
121 |
+
logits = logits.float()
|
122 |
+
loss_fct = nn.CrossEntropyLoss()
|
123 |
+
loss = loss_fct(logits, shift_labels)
|
124 |
+
|
125 |
+
|
126 |
+
if not return_dict:
|
127 |
+
output = (logits,) + outputs[1:]
|
128 |
+
return (loss,) + output if loss is not None else output
|
129 |
+
|
130 |
+
|
131 |
+
return CausalLMOutputWithPast(
|
132 |
+
loss=loss,
|
133 |
+
logits=logits,
|
134 |
+
past_key_values=outputs.past_key_values,
|
135 |
+
hidden_states=outputs.hidden_states,
|
136 |
+
attentions=outputs.attentions,
|
137 |
+
)
|
138 |
+
|
139 |
+
@torch.no_grad()
|
140 |
+
def generate(
|
141 |
+
self,
|
142 |
+
inputs: Optional[torch.Tensor] = None,
|
143 |
+
images: Optional[torch.Tensor] = None,
|
144 |
+
images_highres: Optional[List[torch.FloatTensor]] = None,
|
145 |
+
image_sizes: Optional[torch.Tensor] = None,
|
146 |
+
modalities: Optional[List[str]] = ["image"],
|
147 |
+
**kwargs,
|
148 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
149 |
+
position_ids = kwargs.pop("position_ids", None)
|
150 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
151 |
+
if "inputs_embeds" in kwargs:
|
152 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
153 |
+
|
154 |
+
if images is not None:
|
155 |
+
(inputs,
|
156 |
+
position_ids,
|
157 |
+
attention_mask,
|
158 |
+
_,
|
159 |
+
inputs_embeds,
|
160 |
+
_) = self.prepare_inputs_labels_for_multimodal(inputs,
|
161 |
+
position_ids,
|
162 |
+
attention_mask,
|
163 |
+
None, None,
|
164 |
+
images, modalities, image_sizes=image_sizes, images_highres=images_highres)
|
165 |
+
else:
|
166 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
167 |
+
|
168 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
169 |
+
|
170 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
171 |
+
images = kwargs.pop("images", None)
|
172 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
173 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
174 |
+
if images is not None:
|
175 |
+
inputs["images"] = images
|
176 |
+
if image_sizes is not None:
|
177 |
+
inputs["image_sizes"] = image_sizes
|
178 |
+
return inputs
|
179 |
+
|
180 |
+
|
181 |
+
AutoConfig.register("oryx_qwen", OryxQwenConfig)
|
182 |
+
AutoModelForCausalLM.register(OryxQwenConfig, OryxQwenForCausalLM)
|
oryx/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .oryx_vit import OryxViTWrapper
|
3 |
+
|
4 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
5 |
+
vision_tower = getattr(vision_tower_cfg, 'vision_tower', getattr(vision_tower_cfg, 'mm_vision_tower', None))
|
6 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
7 |
+
if "oryx_vit" in vision_tower:
|
8 |
+
print(f"Buiding OryxViTWrapper from {vision_tower}...")
|
9 |
+
path = vision_tower.split(":")[1]
|
10 |
+
return OryxViTWrapper(vision_tower, path=path, args=vision_tower_cfg, **kwargs)
|
11 |
+
else:
|
12 |
+
raise ValueError(f'Unknown vision tower: {vision_tower}')
|
oryx/model/multimodal_encoder/oryx_vit.py
ADDED
@@ -0,0 +1,844 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from functools import partial
|
5 |
+
from typing import (
|
6 |
+
Callable,
|
7 |
+
Dict,
|
8 |
+
Final,
|
9 |
+
List,
|
10 |
+
Literal,
|
11 |
+
Optional,
|
12 |
+
Sequence,
|
13 |
+
Set,
|
14 |
+
Tuple,
|
15 |
+
Type,
|
16 |
+
Union,
|
17 |
+
)
|
18 |
+
|
19 |
+
from torch.utils.checkpoint import checkpoint
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
try:
|
24 |
+
from timm.layers import (
|
25 |
+
AttentionPoolLatent,
|
26 |
+
DropPath,
|
27 |
+
LayerType,
|
28 |
+
Mlp,
|
29 |
+
PatchDropout,
|
30 |
+
PatchEmbed,
|
31 |
+
resample_abs_pos_embed,
|
32 |
+
)
|
33 |
+
from timm.models._manipulate import checkpoint_seq, named_apply
|
34 |
+
except:
|
35 |
+
print('Wrong timm version')
|
36 |
+
|
37 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
38 |
+
|
39 |
+
from typing import Optional
|
40 |
+
|
41 |
+
import logging
|
42 |
+
import torch
|
43 |
+
import torch.nn as nn
|
44 |
+
import torch.nn.functional as F
|
45 |
+
|
46 |
+
import os
|
47 |
+
|
48 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
49 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
50 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
51 |
+
def norm_cdf(x):
|
52 |
+
# Computes standard normal cumulative distribution function
|
53 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
54 |
+
|
55 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
56 |
+
warnings.warn(
|
57 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
58 |
+
"The distribution of values may be incorrect.",
|
59 |
+
stacklevel=2,
|
60 |
+
)
|
61 |
+
|
62 |
+
with torch.no_grad():
|
63 |
+
# Values are generated by using a truncated uniform distribution and
|
64 |
+
# then using the inverse CDF for the normal distribution.
|
65 |
+
# Get upper and lower cdf values
|
66 |
+
l = norm_cdf((a - mean) / std) # noqa: E741
|
67 |
+
u = norm_cdf((b - mean) / std)
|
68 |
+
|
69 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
70 |
+
# [2l-1, 2u-1].
|
71 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
72 |
+
|
73 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
74 |
+
# standard normal
|
75 |
+
tensor.erfinv_()
|
76 |
+
|
77 |
+
# Transform to proper mean, std
|
78 |
+
tensor.mul_(std * math.sqrt(2.0))
|
79 |
+
tensor.add_(mean)
|
80 |
+
|
81 |
+
# Clamp to ensure it's in the proper range
|
82 |
+
tensor.clamp_(min=a, max=b)
|
83 |
+
return tensor
|
84 |
+
|
85 |
+
|
86 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
87 |
+
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
|
88 |
+
r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
|
89 |
+
convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype.
|
90 |
+
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
|
91 |
+
from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
92 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
93 |
+
the bounds. The method used for generating the random values works
|
94 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
95 |
+
Args:
|
96 |
+
tensor: an n-dimensional `torch.Tensor`
|
97 |
+
mean: the mean of the normal distribution
|
98 |
+
std: the standard deviation of the normal distribution
|
99 |
+
a: the minimum cutoff value
|
100 |
+
b: the maximum cutoff value
|
101 |
+
Examples:
|
102 |
+
>>> w = torch.empty(3, 5)
|
103 |
+
>>> nn.init.trunc_normal_(w)
|
104 |
+
"""
|
105 |
+
|
106 |
+
with torch.no_grad():
|
107 |
+
dtype = tensor.dtype
|
108 |
+
tensor_fp32 = tensor.float()
|
109 |
+
tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
|
110 |
+
tensor_dtype = tensor_fp32.to(dtype=dtype)
|
111 |
+
tensor.copy_(tensor_dtype)
|
112 |
+
|
113 |
+
|
114 |
+
def init_weights(self):
|
115 |
+
if self.pos_embed is not None:
|
116 |
+
trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
|
117 |
+
trunc_normal_(self.latent, std=self.latent_dim**-0.5)
|
118 |
+
|
119 |
+
|
120 |
+
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
|
121 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
122 |
+
if isinstance(module, nn.Linear):
|
123 |
+
trunc_normal_(module.weight, std=0.02)
|
124 |
+
if module.bias is not None:
|
125 |
+
nn.init.zeros_(module.bias)
|
126 |
+
elif hasattr(module, "init_weights"):
|
127 |
+
module.init_weights()
|
128 |
+
|
129 |
+
|
130 |
+
class Attention(nn.Module):
|
131 |
+
fused_attn: Final[bool]
|
132 |
+
|
133 |
+
def __init__(
|
134 |
+
self,
|
135 |
+
dim: int,
|
136 |
+
num_heads: int = 8,
|
137 |
+
qkv_bias: bool = False,
|
138 |
+
qk_norm: bool = False,
|
139 |
+
attn_drop: float = 0.0,
|
140 |
+
proj_drop: float = 0.0,
|
141 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
142 |
+
) -> None:
|
143 |
+
super().__init__()
|
144 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
145 |
+
self.num_heads = num_heads
|
146 |
+
self.head_dim = dim // num_heads
|
147 |
+
self.scale = self.head_dim**-0.5
|
148 |
+
# self.fused_attn = use_fused_attn()
|
149 |
+
self.fused_attn = True
|
150 |
+
|
151 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
152 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
153 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
154 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
155 |
+
self.proj = nn.Linear(dim, dim)
|
156 |
+
self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
|
157 |
+
|
158 |
+
def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor:
|
159 |
+
B, N, C = x.shape
|
160 |
+
qkv = (
|
161 |
+
self.qkv(x)
|
162 |
+
.reshape(B, N, 3, self.num_heads, self.head_dim)
|
163 |
+
.permute(2, 0, 3, 1, 4)
|
164 |
+
)
|
165 |
+
q, k, v = qkv.unbind(0)
|
166 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
167 |
+
|
168 |
+
if cu_slens is not None:
|
169 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
170 |
+
k = k.permute(0, 2, 1, 3)
|
171 |
+
v = v.permute(0, 2, 1, 3)
|
172 |
+
max_seqlen = torch.max(cu_slens[1:] - cu_slens[:-1]).item()
|
173 |
+
x = flash_attn_varlen_func(
|
174 |
+
q.squeeze(0),
|
175 |
+
k.squeeze(0),
|
176 |
+
v.squeeze(0),
|
177 |
+
cu_seqlens_q=cu_slens,
|
178 |
+
cu_seqlens_k=cu_slens,
|
179 |
+
max_seqlen_q=max_seqlen,
|
180 |
+
max_seqlen_k=max_seqlen,
|
181 |
+
softmax_scale=self.scale,
|
182 |
+
causal=False,
|
183 |
+
)
|
184 |
+
|
185 |
+
x = x.reshape(B, N, -1)
|
186 |
+
x = self.proj(x)
|
187 |
+
x = self.proj_drop(x)
|
188 |
+
|
189 |
+
else:
|
190 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
191 |
+
k = k.permute(0, 2, 1, 3)
|
192 |
+
v = v.permute(0, 2, 1, 3)
|
193 |
+
x = flash_attn_func(q, k, v, softmax_scale=self.scale) # -> b, n, h, c
|
194 |
+
|
195 |
+
x = x.reshape(B, N, -1)
|
196 |
+
x = self.proj(x)
|
197 |
+
x = self.proj_drop(x)
|
198 |
+
return x
|
199 |
+
|
200 |
+
|
201 |
+
class LayerScale(nn.Module):
|
202 |
+
def __init__(
|
203 |
+
self,
|
204 |
+
dim: int,
|
205 |
+
init_values: float = 1e-5,
|
206 |
+
inplace: bool = False,
|
207 |
+
) -> None:
|
208 |
+
super().__init__()
|
209 |
+
self.inplace = inplace
|
210 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
211 |
+
|
212 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
213 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
214 |
+
|
215 |
+
|
216 |
+
class Block(nn.Module):
|
217 |
+
def __init__(
|
218 |
+
self,
|
219 |
+
dim: int,
|
220 |
+
num_heads: int,
|
221 |
+
mlp_ratio: float = 4.0,
|
222 |
+
qkv_bias: bool = False,
|
223 |
+
qk_norm: bool = False,
|
224 |
+
proj_drop: float = 0.0,
|
225 |
+
attn_drop: float = 0.0,
|
226 |
+
init_values: Optional[float] = None,
|
227 |
+
drop_path: float = 0.0,
|
228 |
+
act_layer: nn.Module = nn.GELU,
|
229 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
230 |
+
mlp_layer: nn.Module = Mlp,
|
231 |
+
) -> None:
|
232 |
+
super().__init__()
|
233 |
+
self.norm1 = norm_layer(dim)
|
234 |
+
self.attn = Attention(
|
235 |
+
dim,
|
236 |
+
num_heads=num_heads,
|
237 |
+
qkv_bias=qkv_bias,
|
238 |
+
qk_norm=qk_norm,
|
239 |
+
attn_drop=attn_drop,
|
240 |
+
proj_drop=proj_drop,
|
241 |
+
norm_layer=norm_layer,
|
242 |
+
)
|
243 |
+
self.ls1 = (
|
244 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
245 |
+
)
|
246 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
247 |
+
|
248 |
+
self.norm2 = norm_layer(dim)
|
249 |
+
self.mlp = mlp_layer(
|
250 |
+
in_features=dim,
|
251 |
+
hidden_features=int(dim * mlp_ratio),
|
252 |
+
act_layer=act_layer,
|
253 |
+
drop=proj_drop,
|
254 |
+
)
|
255 |
+
self.ls2 = (
|
256 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
257 |
+
)
|
258 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
259 |
+
|
260 |
+
def forward(self, x: torch.Tensor, cu_slens=None) -> torch.Tensor:
|
261 |
+
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_slens=cu_slens)))
|
262 |
+
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
263 |
+
return x
|
264 |
+
|
265 |
+
|
266 |
+
class VisionTransformer(nn.Module):
|
267 |
+
"""Vision Transformer
|
268 |
+
|
269 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
270 |
+
- https://arxiv.org/abs/2010.11929
|
271 |
+
"""
|
272 |
+
|
273 |
+
dynamic_img_size: Final[bool]
|
274 |
+
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
278 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
279 |
+
in_chans: int = 3,
|
280 |
+
num_classes: int = 1000,
|
281 |
+
global_pool: Literal["", "avg", "token", "map"] = "token",
|
282 |
+
embed_dim: int = 768,
|
283 |
+
depth: int = 12,
|
284 |
+
num_heads: int = 12,
|
285 |
+
mlp_ratio: float = 4.0,
|
286 |
+
qkv_bias: bool = True,
|
287 |
+
qk_norm: bool = False,
|
288 |
+
init_values: Optional[float] = None,
|
289 |
+
class_token: bool = True,
|
290 |
+
no_embed_class: bool = False,
|
291 |
+
reg_tokens: int = 0,
|
292 |
+
pre_norm: bool = False,
|
293 |
+
fc_norm: Optional[bool] = None,
|
294 |
+
dynamic_img_size: bool = False,
|
295 |
+
dynamic_img_pad: bool = False,
|
296 |
+
drop_rate: float = 0.0,
|
297 |
+
pos_drop_rate: float = 0.0,
|
298 |
+
patch_drop_rate: float = 0.0,
|
299 |
+
proj_drop_rate: float = 0.0,
|
300 |
+
attn_drop_rate: float = 0.0,
|
301 |
+
drop_path_rate: float = 0.0,
|
302 |
+
weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
|
303 |
+
embed_layer: Callable = PatchEmbed,
|
304 |
+
norm_layer: Optional[LayerType] = None,
|
305 |
+
act_layer: Optional[LayerType] = None,
|
306 |
+
strict_img_size: bool = False,
|
307 |
+
block_fn: Type[nn.Module] = Block,
|
308 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
309 |
+
ignore_head: bool = False,
|
310 |
+
) -> None:
|
311 |
+
"""
|
312 |
+
Args:
|
313 |
+
img_size: Input image size.
|
314 |
+
patch_size: Patch size.
|
315 |
+
in_chans: Number of image input channels.
|
316 |
+
num_classes: Mumber of classes for classification head.
|
317 |
+
global_pool: Type of global pooling for final sequence (default: 'token').
|
318 |
+
embed_dim: Transformer embedding dimension.
|
319 |
+
depth: Depth of transformer.
|
320 |
+
num_heads: Number of attention heads.
|
321 |
+
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
322 |
+
qkv_bias: Enable bias for qkv projections if True.
|
323 |
+
init_values: Layer-scale init values (layer-scale enabled if not None).
|
324 |
+
class_token: Use class token.
|
325 |
+
no_embed_class: Don't include position embeddings for class (or reg) tokens.
|
326 |
+
reg_tokens: Number of register tokens.
|
327 |
+
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
|
328 |
+
drop_rate: Head dropout rate.
|
329 |
+
pos_drop_rate: Position embedding dropout rate.
|
330 |
+
attn_drop_rate: Attention dropout rate.
|
331 |
+
drop_path_rate: Stochastic depth rate.
|
332 |
+
weight_init: Weight initialization scheme.
|
333 |
+
embed_layer: Patch embedding layer.
|
334 |
+
norm_layer: Normalization layer.
|
335 |
+
act_layer: MLP activation layer.
|
336 |
+
block_fn: Transformer block layer.
|
337 |
+
"""
|
338 |
+
super().__init__()
|
339 |
+
assert global_pool in ("", "avg", "token", "map")
|
340 |
+
assert class_token or global_pool != "token"
|
341 |
+
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
|
342 |
+
# norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
|
343 |
+
# act_layer = get_act_layer(act_layer) or nn.GELU
|
344 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
345 |
+
act_layer = nn.GELU
|
346 |
+
|
347 |
+
self.num_classes = num_classes
|
348 |
+
self.global_pool = global_pool
|
349 |
+
self.num_features = self.embed_dim = (
|
350 |
+
embed_dim # num_features for consistency with other models
|
351 |
+
)
|
352 |
+
self.num_prefix_tokens = 1 if class_token else 0
|
353 |
+
self.num_prefix_tokens += reg_tokens
|
354 |
+
self.num_reg_tokens = reg_tokens
|
355 |
+
self.has_class_token = class_token
|
356 |
+
self.no_embed_class = (
|
357 |
+
no_embed_class # don't embed prefix positions (includes reg)
|
358 |
+
)
|
359 |
+
self.dynamic_img_size = dynamic_img_size
|
360 |
+
self.grad_checkpointing = False
|
361 |
+
self.ignore_head = ignore_head
|
362 |
+
|
363 |
+
embed_args = {}
|
364 |
+
if dynamic_img_size:
|
365 |
+
# flatten deferred until after pos embed
|
366 |
+
embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
|
367 |
+
self.patch_embed = embed_layer(
|
368 |
+
img_size=img_size,
|
369 |
+
patch_size=patch_size,
|
370 |
+
in_chans=in_chans,
|
371 |
+
embed_dim=embed_dim,
|
372 |
+
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
373 |
+
dynamic_img_pad=dynamic_img_pad,
|
374 |
+
strict_img_size=strict_img_size,
|
375 |
+
**embed_args,
|
376 |
+
)
|
377 |
+
num_patches = self.patch_embed.num_patches
|
378 |
+
|
379 |
+
self.cls_token = (
|
380 |
+
nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
381 |
+
)
|
382 |
+
self.reg_token = (
|
383 |
+
nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
|
384 |
+
)
|
385 |
+
embed_len = (
|
386 |
+
num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
387 |
+
)
|
388 |
+
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
|
389 |
+
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
390 |
+
if patch_drop_rate > 0:
|
391 |
+
self.patch_drop = PatchDropout(
|
392 |
+
patch_drop_rate,
|
393 |
+
num_prefix_tokens=self.num_prefix_tokens,
|
394 |
+
)
|
395 |
+
else:
|
396 |
+
self.patch_drop = nn.Identity()
|
397 |
+
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
398 |
+
|
399 |
+
dpr = [
|
400 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
401 |
+
] # stochastic depth decay rule
|
402 |
+
self.blocks = nn.Sequential(
|
403 |
+
*[
|
404 |
+
block_fn(
|
405 |
+
dim=embed_dim,
|
406 |
+
num_heads=num_heads,
|
407 |
+
mlp_ratio=mlp_ratio,
|
408 |
+
qkv_bias=qkv_bias,
|
409 |
+
qk_norm=qk_norm,
|
410 |
+
init_values=init_values,
|
411 |
+
proj_drop=proj_drop_rate,
|
412 |
+
attn_drop=attn_drop_rate,
|
413 |
+
drop_path=dpr[i],
|
414 |
+
norm_layer=norm_layer,
|
415 |
+
act_layer=act_layer,
|
416 |
+
mlp_layer=mlp_layer,
|
417 |
+
)
|
418 |
+
for i in range(depth)
|
419 |
+
]
|
420 |
+
)
|
421 |
+
|
422 |
+
def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
|
423 |
+
assert mode in ("jax", "jax_nlhb", "moco", "")
|
424 |
+
# head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
|
425 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
426 |
+
if self.cls_token is not None:
|
427 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
428 |
+
named_apply(init_weights_vit_timm, self)
|
429 |
+
|
430 |
+
@torch.jit.ignore
|
431 |
+
def no_weight_decay(self) -> Set:
|
432 |
+
return {"pos_embed", "cls_token", "dist_token"}
|
433 |
+
|
434 |
+
@torch.jit.ignore
|
435 |
+
def group_matcher(self, coarse: bool = False) -> Dict:
|
436 |
+
return dict(
|
437 |
+
stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
|
438 |
+
blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
|
439 |
+
)
|
440 |
+
|
441 |
+
@torch.jit.ignore
|
442 |
+
def set_grad_checkpointing(self, enable: bool = True) -> None:
|
443 |
+
self.grad_checkpointing = enable
|
444 |
+
|
445 |
+
@torch.jit.ignore
|
446 |
+
def get_classifier(self) -> nn.Module:
|
447 |
+
return self.head
|
448 |
+
|
449 |
+
def reset_classifier(self, num_classes: int, global_pool=None) -> None:
|
450 |
+
self.num_classes = num_classes
|
451 |
+
if global_pool is not None:
|
452 |
+
assert global_pool in ("", "avg", "token", "map")
|
453 |
+
if global_pool == "map" and self.attn_pool is None:
|
454 |
+
assert (
|
455 |
+
False
|
456 |
+
), "Cannot currently add attention pooling in reset_classifier()."
|
457 |
+
elif global_pool != "map " and self.attn_pool is not None:
|
458 |
+
self.attn_pool = None # remove attention pooling
|
459 |
+
self.global_pool = global_pool
|
460 |
+
self.head = (
|
461 |
+
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
462 |
+
)
|
463 |
+
|
464 |
+
def rescale_positional_embedding(self, out_size):
|
465 |
+
h, w = out_size
|
466 |
+
pos_embed_shape = int((self.pos_embed.shape[1]) ** 0.5)
|
467 |
+
if (h, w) == (pos_embed_shape, pos_embed_shape):
|
468 |
+
return self.pos_embed
|
469 |
+
rescaled_positional_embedding = \
|
470 |
+
self.pos_embed.new_zeros(1, h*w, self.pos_embed.shape[2])
|
471 |
+
pe_2d = self.pos_embed[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
|
472 |
+
pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
|
473 |
+
rescaled_positional_embedding[0] = pe_2d.T.contiguous()
|
474 |
+
return rescaled_positional_embedding
|
475 |
+
|
476 |
+
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
|
477 |
+
if self.dynamic_img_size:
|
478 |
+
B, H, W, C = x.shape
|
479 |
+
pos_embed = resample_abs_pos_embed(
|
480 |
+
self.pos_embed,
|
481 |
+
(H, W),
|
482 |
+
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
483 |
+
)
|
484 |
+
x = x.view(B, -1, C)
|
485 |
+
else:
|
486 |
+
pos_embed = self.pos_embed
|
487 |
+
|
488 |
+
to_cat = []
|
489 |
+
if self.cls_token is not None:
|
490 |
+
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
|
491 |
+
if self.reg_token is not None:
|
492 |
+
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
|
493 |
+
|
494 |
+
if self.no_embed_class:
|
495 |
+
# deit-3, updated JAX (big vision)
|
496 |
+
# position embedding does not overlap with class token, add then concat
|
497 |
+
x = x + pos_embed
|
498 |
+
if to_cat:
|
499 |
+
x = torch.cat(to_cat + [x], dim=1)
|
500 |
+
else:
|
501 |
+
# original timm, JAX, and deit vit impl
|
502 |
+
# pos_embed has entry for class token, concat then add
|
503 |
+
if to_cat:
|
504 |
+
x = torch.cat(to_cat + [x], dim=1)
|
505 |
+
x = x + pos_embed
|
506 |
+
|
507 |
+
return self.pos_drop(x)
|
508 |
+
|
509 |
+
def _intermediate_layers(
|
510 |
+
self,
|
511 |
+
x: torch.Tensor,
|
512 |
+
n: Union[int, Sequence] = 1,
|
513 |
+
) -> List[torch.Tensor]:
|
514 |
+
outputs, num_blocks = [], len(self.blocks)
|
515 |
+
take_indices = set(
|
516 |
+
range(num_blocks - n, num_blocks) if isinstance(n, int) else n
|
517 |
+
)
|
518 |
+
|
519 |
+
# forward pass
|
520 |
+
x = self.patch_embed(x)
|
521 |
+
x = self._pos_embed(x)
|
522 |
+
x = self.patch_drop(x)
|
523 |
+
x = self.norm_pre(x)
|
524 |
+
for i, blk in enumerate(self.blocks):
|
525 |
+
x = blk(x)
|
526 |
+
if i in take_indices:
|
527 |
+
outputs.append(x)
|
528 |
+
|
529 |
+
return outputs
|
530 |
+
|
531 |
+
def get_intermediate_layers(
|
532 |
+
self,
|
533 |
+
x: torch.Tensor,
|
534 |
+
n: Union[int, Sequence] = 1,
|
535 |
+
reshape: bool = False,
|
536 |
+
return_prefix_tokens: bool = False,
|
537 |
+
norm: bool = False,
|
538 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
539 |
+
"""Intermediate layer accessor (NOTE: This is a WIP experiment).
|
540 |
+
Inspired by DINO / DINOv2 interface
|
541 |
+
"""
|
542 |
+
# take last n blocks if n is an int, if in is a sequence, select by matching indices
|
543 |
+
outputs = self._intermediate_layers(x, n)
|
544 |
+
if norm:
|
545 |
+
outputs = [self.norm(out) for out in outputs]
|
546 |
+
prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
|
547 |
+
outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
|
548 |
+
|
549 |
+
if reshape:
|
550 |
+
grid_size = self.patch_embed.grid_size
|
551 |
+
outputs = [
|
552 |
+
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
|
553 |
+
.permute(0, 3, 1, 2)
|
554 |
+
.contiguous()
|
555 |
+
for out in outputs
|
556 |
+
]
|
557 |
+
|
558 |
+
if return_prefix_tokens:
|
559 |
+
return tuple(zip(outputs, prefix_tokens))
|
560 |
+
return tuple(outputs)
|
561 |
+
|
562 |
+
def forward_features_list(self, x_list):
|
563 |
+
x_all = []
|
564 |
+
image_sizes = []
|
565 |
+
for x in x_list:
|
566 |
+
bs, _, h, w = x.shape
|
567 |
+
|
568 |
+
# fix patch size=14 in datasets
|
569 |
+
pad_h = (self.patch_embed.patch_size[0] - h % self.patch_embed.patch_size[0]) % self.patch_embed.patch_size[0]
|
570 |
+
pad_w = (self.patch_embed.patch_size[1] - w % self.patch_embed.patch_size[1]) % self.patch_embed.patch_size[1]
|
571 |
+
x = F.pad(x, (0, pad_w, 0, pad_h))
|
572 |
+
|
573 |
+
bs, _, h, w = x.shape
|
574 |
+
|
575 |
+
h = h // self.patch_embed.patch_size[0]
|
576 |
+
w = w // self.patch_embed.patch_size[1]
|
577 |
+
|
578 |
+
x = self.patch_embed(x)
|
579 |
+
x = x + self.rescale_positional_embedding(out_size=(h, w))
|
580 |
+
x = self.patch_drop(x)
|
581 |
+
x = self.norm_pre(x)
|
582 |
+
x_all.append(x)
|
583 |
+
image_sizes.append((h, w))
|
584 |
+
|
585 |
+
slen = [xi.size(1) for xi in x_all]
|
586 |
+
x = torch.cat(x_all, dim=1)
|
587 |
+
|
588 |
+
cu_indices = [0, ]
|
589 |
+
for i in slen:
|
590 |
+
cu_indices.append(cu_indices[-1] + i)
|
591 |
+
|
592 |
+
cu_slens = torch.tensor(cu_indices, dtype=torch.int32).to(x.device)
|
593 |
+
for idx, blk in enumerate(self.blocks):
|
594 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
595 |
+
x = checkpoint(blk, x, cu_slens, use_reentrant=True)
|
596 |
+
else:
|
597 |
+
x = blk(x, cu_slens=cu_slens)
|
598 |
+
feats = x.split(slen, dim=1) #[(1, slen, c)]
|
599 |
+
return feats, image_sizes
|
600 |
+
|
601 |
+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
602 |
+
bs, _, h, w = x.shape
|
603 |
+
h = h // self.patch_embed.patch_size[0]
|
604 |
+
w = w // self.patch_embed.patch_size[1]
|
605 |
+
|
606 |
+
x = self.patch_embed(x)
|
607 |
+
# x = self._pos_embed(x)
|
608 |
+
x = x + self.rescale_positional_embedding(out_size=(h, w))
|
609 |
+
x = self.patch_drop(x)
|
610 |
+
x = self.norm_pre(x)
|
611 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
612 |
+
x = checkpoint_seq(self.blocks, x)
|
613 |
+
else:
|
614 |
+
x = self.blocks(x)
|
615 |
+
return x, (h, w)
|
616 |
+
|
617 |
+
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
618 |
+
x = self.norm(x)
|
619 |
+
if self.attn_pool is not None:
|
620 |
+
x = self.attn_pool(x)
|
621 |
+
elif self.global_pool == "avg":
|
622 |
+
x = x[:, self.num_prefix_tokens :].mean(dim=1)
|
623 |
+
elif self.global_pool:
|
624 |
+
x = x[:, 0] # class token
|
625 |
+
x = self.fc_norm(x)
|
626 |
+
x = self.head_drop(x)
|
627 |
+
return x if pre_logits else self.head(x)
|
628 |
+
|
629 |
+
def forward(self, x, cal_attn_pool=False):
|
630 |
+
if type(x) is list:
|
631 |
+
x, image_sizes = self.forward_features_list(x)
|
632 |
+
return x, image_sizes, None
|
633 |
+
else:
|
634 |
+
x, image_sizes = self.forward_features(x)
|
635 |
+
return x, image_sizes, None
|
636 |
+
|
637 |
+
@dataclass
|
638 |
+
class SigLIPVisionCfg:
|
639 |
+
width: int = 1152
|
640 |
+
layers: Union[Tuple[int, int, int, int], int] = 27
|
641 |
+
heads: int = 16
|
642 |
+
patch_size: int = 14
|
643 |
+
image_size: Union[Tuple[int, int], int] = 336
|
644 |
+
global_pool: str = "map"
|
645 |
+
mlp_ratio: float = 3.7362
|
646 |
+
class_token: bool = False
|
647 |
+
num_classes: int = 0
|
648 |
+
use_checkpoint: bool = False
|
649 |
+
|
650 |
+
|
651 |
+
SigLIP_MODEL_CONFIG = {
|
652 |
+
"siglip_so400m_patch14_384": {
|
653 |
+
"image_size": 384,
|
654 |
+
"patch_size": 14,
|
655 |
+
"width": 1152,
|
656 |
+
"layers": 27,
|
657 |
+
"heads": 16,
|
658 |
+
"mlp_ratio": 3.7362,
|
659 |
+
"global_pool": "map",
|
660 |
+
"use_checkpoint": False,
|
661 |
+
},
|
662 |
+
"siglip_so400m_patch16_384": {
|
663 |
+
"image_size": 384,
|
664 |
+
"patch_size": 16,
|
665 |
+
"width": 1152,
|
666 |
+
"layers": 27,
|
667 |
+
"heads": 16,
|
668 |
+
"mlp_ratio": 3.7362,
|
669 |
+
"global_pool": "map",
|
670 |
+
"use_checkpoint": False,
|
671 |
+
},
|
672 |
+
"siglip_so400m_patch14_224": {
|
673 |
+
"image_size": 224,
|
674 |
+
"patch_size": 14,
|
675 |
+
"width": 1152,
|
676 |
+
"layers": 27,
|
677 |
+
"heads": 16,
|
678 |
+
"mlp_ratio": 3.7362,
|
679 |
+
"global_pool": "map",
|
680 |
+
"use_checkpoint": False,
|
681 |
+
},
|
682 |
+
"siglip_large_patch16_384": {
|
683 |
+
"image_size": 384,
|
684 |
+
"patch_size": 16,
|
685 |
+
"width": 1024,
|
686 |
+
"layers": 24,
|
687 |
+
"heads": 16,
|
688 |
+
"mlp_ratio": 4,
|
689 |
+
"global_pool": "map",
|
690 |
+
"use_checkpoint": False,
|
691 |
+
},
|
692 |
+
}
|
693 |
+
|
694 |
+
def resize_evaclip_pos_embed(model: VisionTransformer, interpolation: str = 'bicubic'):
|
695 |
+
# interpolate position embedding
|
696 |
+
orig_size = 24
|
697 |
+
new_size = 128
|
698 |
+
pos_tokens = model.pos_embed
|
699 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, model.embed_dim).permute(0, 3, 1, 2)
|
700 |
+
pos_tokens = torch.nn.functional.interpolate(
|
701 |
+
pos_tokens, size=(new_size, new_size), mode=interpolation, align_corners=False)
|
702 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
703 |
+
model.pos_embed = nn.Parameter(pos_tokens, requires_grad=True)
|
704 |
+
return model
|
705 |
+
|
706 |
+
def create_siglip_vit(
|
707 |
+
model_name: str = "siglip_so400m_patch14_384",
|
708 |
+
image_size: int = 384,
|
709 |
+
select_layer: int = -1,
|
710 |
+
path: str = "",
|
711 |
+
gradient_checkpointing: bool = False,
|
712 |
+
**kwargs,
|
713 |
+
):
|
714 |
+
assert (
|
715 |
+
model_name in SigLIP_MODEL_CONFIG.keys()
|
716 |
+
), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
|
717 |
+
|
718 |
+
vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
|
719 |
+
|
720 |
+
if select_layer <= 0:
|
721 |
+
layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
|
722 |
+
else:
|
723 |
+
layers = min(vision_cfg.layers, select_layer)
|
724 |
+
|
725 |
+
model = VisionTransformer(
|
726 |
+
img_size=2048,
|
727 |
+
patch_size=16,
|
728 |
+
embed_dim=vision_cfg.width,
|
729 |
+
depth=layers,
|
730 |
+
num_heads=vision_cfg.heads,
|
731 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
732 |
+
class_token=vision_cfg.class_token,
|
733 |
+
global_pool=vision_cfg.global_pool,
|
734 |
+
dynamic_img_pad=False,
|
735 |
+
strict_img_size=False,
|
736 |
+
ignore_head=kwargs.get("ignore_head", False),
|
737 |
+
weight_init=kwargs.get("weight_init", "skip"),
|
738 |
+
num_classes=0
|
739 |
+
)
|
740 |
+
|
741 |
+
if path is not None and os.path.exists(path):
|
742 |
+
ckpt = path
|
743 |
+
else:
|
744 |
+
raise ValueError(f"Model checkpoint not found at {path}")
|
745 |
+
# state_dict = torch.load(ckpt, map_location="cpu")
|
746 |
+
# print('loading vision backbone from', path)
|
747 |
+
|
748 |
+
# msg = model.load_state_dict(state_dict, strict=False)
|
749 |
+
# print(msg)
|
750 |
+
|
751 |
+
if gradient_checkpointing:
|
752 |
+
model.set_grad_checkpointing(True)
|
753 |
+
return model
|
754 |
+
|
755 |
+
import os
|
756 |
+
|
757 |
+
from transformers import CLIPImageProcessor
|
758 |
+
import torch.distributed as dist
|
759 |
+
|
760 |
+
class OryxViTWrapper(nn.Module):
|
761 |
+
def __init__(self, vision_tower, path, args, delay_load=False):
|
762 |
+
super().__init__()
|
763 |
+
|
764 |
+
self.is_loaded = False
|
765 |
+
|
766 |
+
self.vision_tower_name = vision_tower
|
767 |
+
self.args = args
|
768 |
+
self.path = path
|
769 |
+
|
770 |
+
self.select_layer = -1
|
771 |
+
if self.select_layer < -1: self.select_layer += 1
|
772 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
773 |
+
|
774 |
+
self.output_dim = 1152
|
775 |
+
self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384',
|
776 |
+
gradient_checkpointing=False)
|
777 |
+
if not delay_load:
|
778 |
+
self.load_model()
|
779 |
+
elif getattr(args, "unfreeze_mm_vision_tower", False):
|
780 |
+
# TODO: better detector is needed.
|
781 |
+
print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
|
782 |
+
self.load_model()
|
783 |
+
|
784 |
+
def load_model(self, device_map=None):
|
785 |
+
if self.is_loaded:
|
786 |
+
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
|
787 |
+
return
|
788 |
+
|
789 |
+
self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
790 |
+
self.image_processor.image_mean = [0.5, 0.5, 0.5]
|
791 |
+
self.image_processor.image_std = [0.5, 0.5, 0.5]
|
792 |
+
print("Loading vision model...")
|
793 |
+
|
794 |
+
# self.vision_tower = create_siglip_vit(path=self.path, model_name='siglip_so400m_patch16_384',
|
795 |
+
# gradient_checkpointing=False)
|
796 |
+
for p in self.vision_tower.parameters():
|
797 |
+
p.requires_grad = False
|
798 |
+
self.vision_tower.eval()
|
799 |
+
self.is_loaded = True
|
800 |
+
|
801 |
+
def train(self, mode = True):
|
802 |
+
self.training = mode
|
803 |
+
|
804 |
+
if self.is_loaded:
|
805 |
+
self.vision_tower.eval()
|
806 |
+
|
807 |
+
def forward_func(self, images, force_fix_size=False, cal_attn_pool=False):
|
808 |
+
if type(images) is list:
|
809 |
+
xs = [x.to(self.dtype) for x in images]
|
810 |
+
image_features, img_size, cls_token = self.vision_tower(xs, cal_attn_pool=cal_attn_pool)
|
811 |
+
image_features = [x.to(images[0].dtype) for x in image_features]
|
812 |
+
|
813 |
+
else:
|
814 |
+
image_forward_outs, img_size, cls_token = self.vision_tower(images.to(self.dtype), cal_attn_pool=cal_attn_pool)
|
815 |
+
image_features = image_forward_outs.to(images.dtype)
|
816 |
+
|
817 |
+
return image_features, img_size, cls_token
|
818 |
+
|
819 |
+
def forward(self, images, cal_attn_pool=False):
|
820 |
+
with torch.no_grad():
|
821 |
+
image_features, img_size, cls_token = self.forward_func(images, cal_attn_pool=cal_attn_pool)
|
822 |
+
return image_features, img_size
|
823 |
+
|
824 |
+
@property
|
825 |
+
def dummy_feature(self):
|
826 |
+
return torch.zeros(1, 1152, device=self.device, dtype=self.dtype)
|
827 |
+
|
828 |
+
@property
|
829 |
+
def dtype(self):
|
830 |
+
return self.vision_tower.pos_embed.dtype
|
831 |
+
|
832 |
+
@property
|
833 |
+
def device(self):
|
834 |
+
return self.vision_tower.pos_embed.device
|
835 |
+
|
836 |
+
@property
|
837 |
+
def hidden_size(self):
|
838 |
+
return self.output_dim
|
839 |
+
|
840 |
+
@property
|
841 |
+
def config(self):
|
842 |
+
return type('OryxConfigWrapper', (), {
|
843 |
+
'patch_size': 16,
|
844 |
+
})()
|
oryx/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import re
|
4 |
+
|
5 |
+
import math
|
6 |
+
|
7 |
+
class IdentityMap(nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
def forward(self, x, *args, **kwargs):
|
12 |
+
return x
|
13 |
+
|
14 |
+
@property
|
15 |
+
def config(self):
|
16 |
+
return {"mm_projector_type": 'identity'}
|
17 |
+
|
18 |
+
|
19 |
+
class SimpleResBlock(nn.Module):
|
20 |
+
def __init__(self, channels):
|
21 |
+
super().__init__()
|
22 |
+
self.pre_norm = nn.LayerNorm(channels)
|
23 |
+
|
24 |
+
self.proj = nn.Sequential(
|
25 |
+
nn.Linear(channels, channels),
|
26 |
+
nn.GELU(),
|
27 |
+
nn.Linear(channels, channels)
|
28 |
+
)
|
29 |
+
def forward(self, x):
|
30 |
+
x = self.pre_norm(x)
|
31 |
+
return x + self.proj(x)
|
32 |
+
|
33 |
+
class SimpleMlp(nn.Module):
|
34 |
+
def __init__(self, in_channels, out_channels, twoview=False):
|
35 |
+
super().__init__()
|
36 |
+
self.proj = nn.Sequential(
|
37 |
+
nn.Linear(in_channels, out_channels),
|
38 |
+
nn.GELU(),
|
39 |
+
nn.Linear(out_channels, out_channels)
|
40 |
+
)
|
41 |
+
|
42 |
+
embed_std = 1 / math.sqrt(out_channels)
|
43 |
+
self.image_newline = nn.Parameter(
|
44 |
+
torch.randn(out_channels) * embed_std
|
45 |
+
)
|
46 |
+
self.image_begin = nn.Parameter(
|
47 |
+
torch.randn(out_channels) * embed_std
|
48 |
+
)
|
49 |
+
self.image_end = nn.Parameter(
|
50 |
+
torch.randn(out_channels) * embed_std
|
51 |
+
)
|
52 |
+
|
53 |
+
if twoview:
|
54 |
+
self.image_sep = nn.Parameter(
|
55 |
+
torch.randn(out_channels) * embed_std
|
56 |
+
)
|
57 |
+
|
58 |
+
def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'):
|
59 |
+
|
60 |
+
if modalities in ['image', 'text']:
|
61 |
+
h, w = size
|
62 |
+
dtype = x.dtype
|
63 |
+
x = x.reshape(x.shape[0], h, w, -1)
|
64 |
+
x = self.proj(x) #b,h,w, c
|
65 |
+
b, h, w, c = x.shape
|
66 |
+
x = torch.cat([
|
67 |
+
x,
|
68 |
+
self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
|
69 |
+
], dim=2)
|
70 |
+
x = x.reshape(b, -1, c)
|
71 |
+
|
72 |
+
if x2 is not None:
|
73 |
+
h2, w2 = size2
|
74 |
+
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
|
75 |
+
x2 = self.proj(x2) #b,h,w, c
|
76 |
+
b2, h2, w2, c2 = x2.shape
|
77 |
+
x2 = torch.cat([
|
78 |
+
x2,
|
79 |
+
self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
|
80 |
+
], dim=2)
|
81 |
+
x2 = x2.reshape(b, -1, c)
|
82 |
+
sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
|
83 |
+
x = torch.cat([x, sep, x2], dim=1)
|
84 |
+
|
85 |
+
assert b == 1
|
86 |
+
assert b2 == 1 # only support batch size 1
|
87 |
+
|
88 |
+
begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
|
89 |
+
end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
|
90 |
+
x = torch.cat([begin, x, end], dim=1)
|
91 |
+
return x
|
92 |
+
elif modalities in ['video', 'video_long']:
|
93 |
+
# x2 is the true feature, ignore x
|
94 |
+
h, w = size
|
95 |
+
dtype = x.dtype
|
96 |
+
x = x.reshape(x.shape[0], h, w, -1)
|
97 |
+
x = self.proj(x).mean() * 0.0
|
98 |
+
|
99 |
+
h2, w2 = size2
|
100 |
+
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
|
101 |
+
x2 = self.proj(x2) + x #b, h, w, c
|
102 |
+
|
103 |
+
b2, h2, w2, c = x2.shape
|
104 |
+
x2 = torch.cat([
|
105 |
+
x2,
|
106 |
+
self.image_newline.reshape(1, 1, 1, c).expand(b2, h2, 1, c).to(dtype)
|
107 |
+
], dim=2)
|
108 |
+
|
109 |
+
x2 = x2.reshape(b2, -1, c)
|
110 |
+
|
111 |
+
sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype)
|
112 |
+
x2 = torch.cat([x2, sep], dim=1)
|
113 |
+
|
114 |
+
x2 = x2.flatten(0, 1)
|
115 |
+
|
116 |
+
begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype)
|
117 |
+
end = self.image_end.reshape(1, -1).expand(1, c).to(dtype)
|
118 |
+
x2 = torch.cat([begin, x2, end], dim=0)
|
119 |
+
x2 = x2.unsqueeze(0)
|
120 |
+
return x2
|
121 |
+
|
122 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
123 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
124 |
+
|
125 |
+
if projector_type == 'linear':
|
126 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
127 |
+
|
128 |
+
elif projector_type == 'simple_mlp_twoview':
|
129 |
+
return SimpleMlp(config.mm_hidden_size, config.hidden_size, twoview=True)
|
130 |
+
|
131 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
132 |
+
if mlp_gelu_match:
|
133 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
134 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
135 |
+
for _ in range(1, mlp_depth):
|
136 |
+
modules.append(nn.GELU())
|
137 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
138 |
+
return nn.Sequential(*modules)
|
139 |
+
|
140 |
+
mlp_gelu_resnet_match = re.match(r'^mlp(\d+)x_res(\d+)x_gelu$', projector_type)
|
141 |
+
if mlp_gelu_resnet_match:
|
142 |
+
mlp_depth = int(mlp_gelu_resnet_match.group(1))
|
143 |
+
res_depth = int(mlp_gelu_resnet_match.group(2))
|
144 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
145 |
+
for _ in range(1, mlp_depth):
|
146 |
+
modules.append(nn.GELU())
|
147 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
148 |
+
for _ in range(res_depth):
|
149 |
+
modules.append(SimpleResBlock(config.hidden_size))
|
150 |
+
return nn.Sequential(*modules)
|
151 |
+
|
152 |
+
if projector_type == 'identity':
|
153 |
+
return IdentityMap()
|
154 |
+
|
155 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
oryx/model/multimodal_resampler/builder.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from .masked_drop import MaskedDrop
|
4 |
+
from .spatial_pool import SpatialPool
|
5 |
+
from .qformer import Qformer
|
6 |
+
from .vlm_attention import VlmAttention
|
7 |
+
from .perceiver import DynamicCompressor
|
8 |
+
|
9 |
+
class IdentityMap(torch.nn.Module):
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
def forward(self, x, *args, **kwargs):
|
14 |
+
return x
|
15 |
+
|
16 |
+
@property
|
17 |
+
def config(self):
|
18 |
+
return {"mm_resampler_type": None}
|
19 |
+
|
20 |
+
def build_vision_resampler(model_args, delay_load=False, **kwargs):
|
21 |
+
# import pdb;pdb.set_trace()
|
22 |
+
resampler_type = getattr(model_args, 'mm_resampler_type', None)
|
23 |
+
if resampler_type == 'masked_drop':
|
24 |
+
return MaskedDrop(model_args)
|
25 |
+
elif resampler_type == 'spatial_pool':
|
26 |
+
return SpatialPool(model_args, **kwargs)
|
27 |
+
elif resampler_type == 'qformer':
|
28 |
+
return Qformer(model_args, **kwargs)
|
29 |
+
elif resampler_type == 'vlm_attention':
|
30 |
+
return VlmAttention(model_args,**kwargs)
|
31 |
+
elif resampler_type == 'dynamic_compressor':
|
32 |
+
return DynamicCompressor(model_args, **kwargs)
|
33 |
+
elif resampler_type is None:
|
34 |
+
return IdentityMap()
|
35 |
+
else:
|
36 |
+
raise ValueError(f'Unknown resampler type: {resampler_type}')
|
oryx/model/multimodal_resampler/masked_drop.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
import random
|
5 |
+
|
6 |
+
|
7 |
+
class MaskedDrop(nn.Module):
|
8 |
+
def __init__(self, model_args):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.mode = model_args.mm_mask_drop_mode
|
12 |
+
self.skip_percentage = model_args.mm_mask_drop_skip_percentage
|
13 |
+
self.ratio = model_args.mm_mask_drop_ratio
|
14 |
+
self.ratio_upper = model_args.mm_mask_drop_ratio_upper
|
15 |
+
self.ratio_lower = model_args.mm_mask_drop_ratio_lower
|
16 |
+
|
17 |
+
def forward(self, image_features, *args, **kwargs):
|
18 |
+
|
19 |
+
if not self.training:
|
20 |
+
return image_features
|
21 |
+
|
22 |
+
if self.skip_percentage > random.random():
|
23 |
+
return image_features
|
24 |
+
|
25 |
+
masked_features = []
|
26 |
+
|
27 |
+
for image_feature in image_features:
|
28 |
+
num_tokens = image_feature.shape[0]
|
29 |
+
if self.mode == 'fixed':
|
30 |
+
num_keep = int(num_tokens * self.ratio)
|
31 |
+
masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0])
|
32 |
+
elif self.mode == 'range':
|
33 |
+
num_keep = int(num_tokens * random.uniform(self.ratio_lower, self.ratio_upper))
|
34 |
+
masked_features.append(self.random_masking(image_feature.unsqueeze(0), num_keep)[0])
|
35 |
+
elif self.mode == 'cls_only':
|
36 |
+
masked_features.append(image_feature[0:1])
|
37 |
+
else:
|
38 |
+
raise ValueError(f'Unexpected masked drop mode: {self.mode}')
|
39 |
+
|
40 |
+
if self.mode not in ['range'] and \
|
41 |
+
(type(image_features) is not list or self.mode in ['cls_only']):
|
42 |
+
masked_features = torch.stack(masked_features, dim=0)
|
43 |
+
|
44 |
+
return masked_features
|
45 |
+
|
46 |
+
@property
|
47 |
+
def config(self):
|
48 |
+
return {
|
49 |
+
'mm_resampler_type': 'masked_drop',
|
50 |
+
'mm_mask_drop_mode': self.mode,
|
51 |
+
'mm_mask_drop_skip_percentage': self.skip_percentage,
|
52 |
+
'mm_mask_drop_ratio': self.ratio,
|
53 |
+
'mm_mask_drop_ratio_upper': self.ratio_upper,
|
54 |
+
'mm_mask_drop_ratio_lower': self.ratio_lower,
|
55 |
+
}
|
56 |
+
|
57 |
+
def random_masking(self, x, len_keep):
|
58 |
+
"""
|
59 |
+
Perform per-sample random masking by per-sample shuffling.
|
60 |
+
Per-sample shuffling is done by argsort random noise.
|
61 |
+
x: [N, L, D], sequence
|
62 |
+
"""
|
63 |
+
N, L, D = x.shape # batch, length, dim
|
64 |
+
|
65 |
+
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
|
66 |
+
|
67 |
+
# sort noise for each sample
|
68 |
+
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
|
69 |
+
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
70 |
+
|
71 |
+
# keep the first subset
|
72 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
73 |
+
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
74 |
+
|
75 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
76 |
+
mask = torch.ones([N, L], device=x.device)
|
77 |
+
mask[:, :len_keep] = 0
|
78 |
+
# unshuffle to get the binary mask
|
79 |
+
mask = torch.gather(mask, dim=1, index=ids_restore)
|
80 |
+
|
81 |
+
return x_masked, mask, ids_restore
|
82 |
+
|
oryx/model/multimodal_resampler/perceiver.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
class DynamicCompressor(nn.Module):
|
7 |
+
def __init__(self, model_args, vision_tower):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
self.out_channels = vision_tower.hidden_size
|
11 |
+
self.mid_channel = 256
|
12 |
+
|
13 |
+
self.vlm_query_projector = nn.Linear(self.out_channels, self.mid_channel)
|
14 |
+
self.vlm_key_projector = nn.Linear(self.out_channels, self.mid_channel)
|
15 |
+
|
16 |
+
def downsample(self, x):
|
17 |
+
return F.avg_pool2d(x, 2, 2)
|
18 |
+
|
19 |
+
def downsample_4(self, x):
|
20 |
+
return F.avg_pool2d(x, 4, 4)
|
21 |
+
|
22 |
+
def forward(self, image_features, forward_type, image_size=None):
|
23 |
+
if image_size is None:
|
24 |
+
ori_W = int(math.sqrt(image_features.shape[1]))
|
25 |
+
ori_H = int(ori_W)
|
26 |
+
else:
|
27 |
+
ori_H, ori_W = image_size
|
28 |
+
T, N, C = image_features.shape
|
29 |
+
image_features = image_features.view(T, ori_H, ori_W, C).permute(0, 3, 1, 2) # T, C, H, W
|
30 |
+
|
31 |
+
if forward_type == 'video':
|
32 |
+
image_features_pool = self.downsample(image_features)
|
33 |
+
image_feature_attn = image_features.reshape(T, C, ori_H // 2, 2, ori_W // 2, 2).permute(0, 2, 4, 3, 5, 1).reshape(T, ori_H // 2 * ori_W // 2, 4, C)
|
34 |
+
new_image_size = (ori_H // 2, ori_W // 2)
|
35 |
+
elif forward_type == 'image' or forward_type == 'text':
|
36 |
+
image_features_pool = image_features
|
37 |
+
image_feature_attn = image_features.reshape(T, C, ori_H, 1, ori_W, 1).permute(0, 2, 4, 3, 5, 1).reshape(T, ori_H * ori_W, 1, C)
|
38 |
+
new_image_size = (ori_H, ori_W)
|
39 |
+
elif forward_type == 'video_long':
|
40 |
+
image_features_pool = self.downsample_4(image_features)
|
41 |
+
image_feature_attn = image_features.reshape(T, C, ori_H // 4, 4, ori_W // 4, 4).permute(0, 2, 4, 3, 5, 1).reshape(T, ori_H // 4 * ori_W // 4, 16, C)
|
42 |
+
new_image_size = (ori_H // 4, ori_W // 4)
|
43 |
+
else:
|
44 |
+
raise NotImplementedError
|
45 |
+
|
46 |
+
image_features_pool = image_features_pool.flatten(2).permute(0, 2, 1) # T, H*W, C
|
47 |
+
new_t, new_p, _ = image_features_pool.shape
|
48 |
+
|
49 |
+
image_query = self.vlm_query_projector(image_features_pool).reshape(new_t*new_p, self.mid_channel)
|
50 |
+
image_key = self.vlm_key_projector(image_feature_attn).reshape(new_t*new_p, -1, self.mid_channel)
|
51 |
+
|
52 |
+
image_value = image_feature_attn.reshape(new_t*new_p, -1, self.out_channels)
|
53 |
+
image_attn = image_query[:,None] @ (image_key.transpose(-1,-2) / (image_key.shape[-1]**0.5))
|
54 |
+
image_attn = image_attn.nan_to_num()
|
55 |
+
attn_feat = (image_attn.softmax(-1) @ image_value).mean(1).reshape(new_t, new_p, C)
|
56 |
+
|
57 |
+
image_features_pool = image_features_pool + attn_feat
|
58 |
+
|
59 |
+
return image_features_pool, new_image_size
|
60 |
+
|
61 |
+
@property
|
62 |
+
def config(self):
|
63 |
+
return {
|
64 |
+
'mm_resampler_type': 'dynamic_compressor',
|
65 |
+
'mm_out_channels': self.out_channels,
|
66 |
+
}
|
67 |
+
|
68 |
+
@property
|
69 |
+
def hidden_size(self):
|
70 |
+
return self.out_channels
|
oryx/model/multimodal_resampler/qformer.py
ADDED
@@ -0,0 +1,1287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
* Copyright (c) 2023, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
* Based on huggingface code base
|
8 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
9 |
+
"""
|
10 |
+
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import warnings
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple, Dict, Any
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import Tensor, device, dtype, nn
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
import torch.nn.functional as F
|
23 |
+
|
24 |
+
from transformers.activations import ACT2FN
|
25 |
+
from transformers.file_utils import (
|
26 |
+
ModelOutput,
|
27 |
+
)
|
28 |
+
from transformers.modeling_outputs import (
|
29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
31 |
+
CausalLMOutputWithCrossAttentions,
|
32 |
+
MaskedLMOutput,
|
33 |
+
MultipleChoiceModelOutput,
|
34 |
+
NextSentencePredictorOutput,
|
35 |
+
QuestionAnsweringModelOutput,
|
36 |
+
SequenceClassifierOutput,
|
37 |
+
TokenClassifierOutput,
|
38 |
+
)
|
39 |
+
from transformers.modeling_utils import (
|
40 |
+
PreTrainedModel,
|
41 |
+
apply_chunking_to_forward,
|
42 |
+
find_pruneable_heads_and_indices,
|
43 |
+
prune_linear_layer,
|
44 |
+
)
|
45 |
+
from transformers.utils import logging
|
46 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__)
|
49 |
+
|
50 |
+
|
51 |
+
def disabled_train(self, mode=True):
|
52 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
53 |
+
does not change anymore."""
|
54 |
+
return self
|
55 |
+
|
56 |
+
|
57 |
+
class BertEmbeddings(nn.Module):
|
58 |
+
"""Construct the embeddings from word and position embeddings."""
|
59 |
+
|
60 |
+
def __init__(self, config):
|
61 |
+
super().__init__()
|
62 |
+
self.word_embeddings = nn.Embedding(
|
63 |
+
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
|
64 |
+
)
|
65 |
+
self.position_embeddings = nn.Embedding(
|
66 |
+
config.max_position_embeddings, config.hidden_size
|
67 |
+
)
|
68 |
+
|
69 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
70 |
+
# any TensorFlow checkpoint file
|
71 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
72 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
73 |
+
|
74 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
75 |
+
self.register_buffer(
|
76 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
|
77 |
+
)
|
78 |
+
self.position_embedding_type = getattr(
|
79 |
+
config, "position_embedding_type", "absolute"
|
80 |
+
)
|
81 |
+
|
82 |
+
self.config = config
|
83 |
+
|
84 |
+
def forward(
|
85 |
+
self,
|
86 |
+
input_ids=None,
|
87 |
+
position_ids=None,
|
88 |
+
query_embeds=None,
|
89 |
+
past_key_values_length=0,
|
90 |
+
):
|
91 |
+
if input_ids is not None:
|
92 |
+
seq_length = input_ids.size()[1]
|
93 |
+
else:
|
94 |
+
seq_length = 0
|
95 |
+
|
96 |
+
if position_ids is None:
|
97 |
+
position_ids = self.position_ids[
|
98 |
+
:, past_key_values_length : seq_length + past_key_values_length
|
99 |
+
].clone()
|
100 |
+
|
101 |
+
if input_ids is not None:
|
102 |
+
embeddings = self.word_embeddings(input_ids)
|
103 |
+
if self.position_embedding_type == "absolute":
|
104 |
+
position_embeddings = self.position_embeddings(position_ids)
|
105 |
+
embeddings = embeddings + position_embeddings
|
106 |
+
|
107 |
+
if query_embeds is not None:
|
108 |
+
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
109 |
+
else:
|
110 |
+
embeddings = query_embeds
|
111 |
+
|
112 |
+
embeddings = self.LayerNorm(embeddings)
|
113 |
+
embeddings = self.dropout(embeddings)
|
114 |
+
return embeddings
|
115 |
+
|
116 |
+
|
117 |
+
class BertSelfAttention(nn.Module):
|
118 |
+
def __init__(self, config, is_cross_attention):
|
119 |
+
super().__init__()
|
120 |
+
self.config = config
|
121 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
122 |
+
config, "embedding_size"
|
123 |
+
):
|
124 |
+
raise ValueError(
|
125 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
126 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
127 |
+
)
|
128 |
+
|
129 |
+
self.num_attention_heads = config.num_attention_heads
|
130 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
131 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
132 |
+
|
133 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
134 |
+
if is_cross_attention:
|
135 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
136 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
137 |
+
else:
|
138 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
139 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
140 |
+
|
141 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
142 |
+
self.position_embedding_type = getattr(
|
143 |
+
config, "position_embedding_type", "absolute"
|
144 |
+
)
|
145 |
+
if (
|
146 |
+
self.position_embedding_type == "relative_key"
|
147 |
+
or self.position_embedding_type == "relative_key_query"
|
148 |
+
):
|
149 |
+
self.max_position_embeddings = config.max_position_embeddings
|
150 |
+
self.distance_embedding = nn.Embedding(
|
151 |
+
2 * config.max_position_embeddings - 1, self.attention_head_size
|
152 |
+
)
|
153 |
+
self.save_attention = False
|
154 |
+
|
155 |
+
def save_attn_gradients(self, attn_gradients):
|
156 |
+
self.attn_gradients = attn_gradients
|
157 |
+
|
158 |
+
def get_attn_gradients(self):
|
159 |
+
return self.attn_gradients
|
160 |
+
|
161 |
+
def save_attention_map(self, attention_map):
|
162 |
+
self.attention_map = attention_map
|
163 |
+
|
164 |
+
def get_attention_map(self):
|
165 |
+
return self.attention_map
|
166 |
+
|
167 |
+
def transpose_for_scores(self, x):
|
168 |
+
new_x_shape = x.size()[:-1] + (
|
169 |
+
self.num_attention_heads,
|
170 |
+
self.attention_head_size,
|
171 |
+
)
|
172 |
+
x = x.view(*new_x_shape)
|
173 |
+
return x.permute(0, 2, 1, 3)
|
174 |
+
|
175 |
+
def forward(
|
176 |
+
self,
|
177 |
+
hidden_states,
|
178 |
+
attention_mask=None,
|
179 |
+
head_mask=None,
|
180 |
+
encoder_hidden_states=None,
|
181 |
+
encoder_attention_mask=None,
|
182 |
+
past_key_value=None,
|
183 |
+
output_attentions=False,
|
184 |
+
):
|
185 |
+
|
186 |
+
# If this is instantiated as a cross-attention module, the keys
|
187 |
+
# and values come from an encoder; the attention mask needs to be
|
188 |
+
# such that the encoder's padding tokens are not attended to.
|
189 |
+
is_cross_attention = encoder_hidden_states is not None
|
190 |
+
|
191 |
+
if is_cross_attention:
|
192 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
193 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
194 |
+
attention_mask = encoder_attention_mask
|
195 |
+
elif past_key_value is not None:
|
196 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
197 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
198 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
199 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
200 |
+
else:
|
201 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
202 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
203 |
+
|
204 |
+
mixed_query_layer = self.query(hidden_states)
|
205 |
+
|
206 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
207 |
+
|
208 |
+
past_key_value = (key_layer, value_layer)
|
209 |
+
|
210 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
211 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
212 |
+
|
213 |
+
if (
|
214 |
+
self.position_embedding_type == "relative_key"
|
215 |
+
or self.position_embedding_type == "relative_key_query"
|
216 |
+
):
|
217 |
+
seq_length = hidden_states.size()[1]
|
218 |
+
position_ids_l = torch.arange(
|
219 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
220 |
+
).view(-1, 1)
|
221 |
+
position_ids_r = torch.arange(
|
222 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
223 |
+
).view(1, -1)
|
224 |
+
distance = position_ids_l - position_ids_r
|
225 |
+
positional_embedding = self.distance_embedding(
|
226 |
+
distance + self.max_position_embeddings - 1
|
227 |
+
)
|
228 |
+
positional_embedding = positional_embedding.to(
|
229 |
+
dtype=query_layer.dtype
|
230 |
+
) # fp16 compatibility
|
231 |
+
|
232 |
+
if self.position_embedding_type == "relative_key":
|
233 |
+
relative_position_scores = torch.einsum(
|
234 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
235 |
+
)
|
236 |
+
attention_scores = attention_scores + relative_position_scores
|
237 |
+
elif self.position_embedding_type == "relative_key_query":
|
238 |
+
relative_position_scores_query = torch.einsum(
|
239 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
240 |
+
)
|
241 |
+
relative_position_scores_key = torch.einsum(
|
242 |
+
"bhrd,lrd->bhlr", key_layer, positional_embedding
|
243 |
+
)
|
244 |
+
attention_scores = (
|
245 |
+
attention_scores
|
246 |
+
+ relative_position_scores_query
|
247 |
+
+ relative_position_scores_key
|
248 |
+
)
|
249 |
+
|
250 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
251 |
+
if attention_mask is not None:
|
252 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
253 |
+
attention_scores = attention_scores + attention_mask
|
254 |
+
|
255 |
+
# Normalize the attention scores to probabilities.
|
256 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
257 |
+
|
258 |
+
if is_cross_attention and self.save_attention:
|
259 |
+
self.save_attention_map(attention_probs)
|
260 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
261 |
+
|
262 |
+
# This is actually dropping out entire tokens to attend to, which might
|
263 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
264 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
265 |
+
|
266 |
+
# Mask heads if we want to
|
267 |
+
if head_mask is not None:
|
268 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
269 |
+
|
270 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
271 |
+
|
272 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
273 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
274 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
275 |
+
|
276 |
+
outputs = (
|
277 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
278 |
+
)
|
279 |
+
|
280 |
+
outputs = outputs + (past_key_value,)
|
281 |
+
return outputs
|
282 |
+
|
283 |
+
|
284 |
+
class BertSelfOutput(nn.Module):
|
285 |
+
def __init__(self, config):
|
286 |
+
super().__init__()
|
287 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
288 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
289 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
290 |
+
|
291 |
+
def forward(self, hidden_states, input_tensor):
|
292 |
+
hidden_states = self.dense(hidden_states)
|
293 |
+
hidden_states = self.dropout(hidden_states)
|
294 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
295 |
+
return hidden_states
|
296 |
+
|
297 |
+
|
298 |
+
class BertAttention(nn.Module):
|
299 |
+
def __init__(self, config, is_cross_attention=False):
|
300 |
+
super().__init__()
|
301 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
302 |
+
self.output = BertSelfOutput(config)
|
303 |
+
self.pruned_heads = set()
|
304 |
+
|
305 |
+
def prune_heads(self, heads):
|
306 |
+
if len(heads) == 0:
|
307 |
+
return
|
308 |
+
heads, index = find_pruneable_heads_and_indices(
|
309 |
+
heads,
|
310 |
+
self.self.num_attention_heads,
|
311 |
+
self.self.attention_head_size,
|
312 |
+
self.pruned_heads,
|
313 |
+
)
|
314 |
+
|
315 |
+
# Prune linear layers
|
316 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
317 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
318 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
319 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
320 |
+
|
321 |
+
# Update hyper params and store pruned heads
|
322 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
323 |
+
self.self.all_head_size = (
|
324 |
+
self.self.attention_head_size * self.self.num_attention_heads
|
325 |
+
)
|
326 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
327 |
+
|
328 |
+
def forward(
|
329 |
+
self,
|
330 |
+
hidden_states,
|
331 |
+
attention_mask=None,
|
332 |
+
head_mask=None,
|
333 |
+
encoder_hidden_states=None,
|
334 |
+
encoder_attention_mask=None,
|
335 |
+
past_key_value=None,
|
336 |
+
output_attentions=False,
|
337 |
+
):
|
338 |
+
self_outputs = self.self(
|
339 |
+
hidden_states,
|
340 |
+
attention_mask,
|
341 |
+
head_mask,
|
342 |
+
encoder_hidden_states,
|
343 |
+
encoder_attention_mask,
|
344 |
+
past_key_value,
|
345 |
+
output_attentions,
|
346 |
+
)
|
347 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
348 |
+
|
349 |
+
outputs = (attention_output,) + self_outputs[
|
350 |
+
1:
|
351 |
+
] # add attentions if we output them
|
352 |
+
return outputs
|
353 |
+
|
354 |
+
|
355 |
+
class BertIntermediate(nn.Module):
|
356 |
+
def __init__(self, config):
|
357 |
+
super().__init__()
|
358 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
359 |
+
if isinstance(config.hidden_act, str):
|
360 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
361 |
+
else:
|
362 |
+
self.intermediate_act_fn = config.hidden_act
|
363 |
+
|
364 |
+
def forward(self, hidden_states):
|
365 |
+
hidden_states = self.dense(hidden_states)
|
366 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
367 |
+
return hidden_states
|
368 |
+
|
369 |
+
|
370 |
+
class BertOutput(nn.Module):
|
371 |
+
def __init__(self, config):
|
372 |
+
super().__init__()
|
373 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
374 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
375 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
376 |
+
|
377 |
+
def forward(self, hidden_states, input_tensor):
|
378 |
+
hidden_states = self.dense(hidden_states)
|
379 |
+
hidden_states = self.dropout(hidden_states)
|
380 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
381 |
+
return hidden_states
|
382 |
+
|
383 |
+
|
384 |
+
class BertLayer(nn.Module):
|
385 |
+
def __init__(self, config, layer_num):
|
386 |
+
super().__init__()
|
387 |
+
self.config = config
|
388 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
389 |
+
self.seq_len_dim = 1
|
390 |
+
self.attention = BertAttention(config)
|
391 |
+
self.layer_num = layer_num
|
392 |
+
if (
|
393 |
+
self.config.add_cross_attention
|
394 |
+
and layer_num % self.config.cross_attention_freq == 0
|
395 |
+
):
|
396 |
+
self.crossattention = BertAttention(
|
397 |
+
config, is_cross_attention=self.config.add_cross_attention
|
398 |
+
)
|
399 |
+
self.has_cross_attention = True
|
400 |
+
else:
|
401 |
+
self.has_cross_attention = False
|
402 |
+
self.intermediate = BertIntermediate(config)
|
403 |
+
self.output = BertOutput(config)
|
404 |
+
|
405 |
+
self.intermediate_query = BertIntermediate(config)
|
406 |
+
self.output_query = BertOutput(config)
|
407 |
+
|
408 |
+
def forward(
|
409 |
+
self,
|
410 |
+
hidden_states,
|
411 |
+
attention_mask=None,
|
412 |
+
head_mask=None,
|
413 |
+
encoder_hidden_states=None,
|
414 |
+
encoder_attention_mask=None,
|
415 |
+
past_key_value=None,
|
416 |
+
output_attentions=False,
|
417 |
+
query_length=0,
|
418 |
+
):
|
419 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
420 |
+
self_attn_past_key_value = (
|
421 |
+
past_key_value[:2] if past_key_value is not None else None
|
422 |
+
)
|
423 |
+
self_attention_outputs = self.attention(
|
424 |
+
hidden_states,
|
425 |
+
attention_mask,
|
426 |
+
head_mask,
|
427 |
+
output_attentions=output_attentions,
|
428 |
+
past_key_value=self_attn_past_key_value,
|
429 |
+
)
|
430 |
+
attention_output = self_attention_outputs[0]
|
431 |
+
outputs = self_attention_outputs[1:-1]
|
432 |
+
|
433 |
+
present_key_value = self_attention_outputs[-1]
|
434 |
+
|
435 |
+
if query_length > 0:
|
436 |
+
query_attention_output = attention_output[:, :query_length, :]
|
437 |
+
|
438 |
+
if self.has_cross_attention:
|
439 |
+
assert (
|
440 |
+
encoder_hidden_states is not None
|
441 |
+
), "encoder_hidden_states must be given for cross-attention layers"
|
442 |
+
cross_attention_outputs = self.crossattention(
|
443 |
+
query_attention_output,
|
444 |
+
attention_mask,
|
445 |
+
head_mask,
|
446 |
+
encoder_hidden_states,
|
447 |
+
encoder_attention_mask,
|
448 |
+
output_attentions=output_attentions,
|
449 |
+
)
|
450 |
+
query_attention_output = cross_attention_outputs[0]
|
451 |
+
outputs = (
|
452 |
+
outputs + cross_attention_outputs[1:-1]
|
453 |
+
) # add cross attentions if we output attention weights
|
454 |
+
|
455 |
+
layer_output = apply_chunking_to_forward(
|
456 |
+
self.feed_forward_chunk_query,
|
457 |
+
self.chunk_size_feed_forward,
|
458 |
+
self.seq_len_dim,
|
459 |
+
query_attention_output,
|
460 |
+
)
|
461 |
+
if attention_output.shape[1] > query_length:
|
462 |
+
layer_output_text = apply_chunking_to_forward(
|
463 |
+
self.feed_forward_chunk,
|
464 |
+
self.chunk_size_feed_forward,
|
465 |
+
self.seq_len_dim,
|
466 |
+
attention_output[:, query_length:, :],
|
467 |
+
)
|
468 |
+
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
469 |
+
else:
|
470 |
+
layer_output = apply_chunking_to_forward(
|
471 |
+
self.feed_forward_chunk,
|
472 |
+
self.chunk_size_feed_forward,
|
473 |
+
self.seq_len_dim,
|
474 |
+
attention_output,
|
475 |
+
)
|
476 |
+
outputs = (layer_output,) + outputs
|
477 |
+
|
478 |
+
outputs = outputs + (present_key_value,)
|
479 |
+
|
480 |
+
return outputs
|
481 |
+
|
482 |
+
def feed_forward_chunk(self, attention_output):
|
483 |
+
intermediate_output = self.intermediate(attention_output)
|
484 |
+
layer_output = self.output(intermediate_output, attention_output)
|
485 |
+
return layer_output
|
486 |
+
|
487 |
+
def feed_forward_chunk_query(self, attention_output):
|
488 |
+
intermediate_output = self.intermediate_query(attention_output)
|
489 |
+
layer_output = self.output_query(intermediate_output, attention_output)
|
490 |
+
return layer_output
|
491 |
+
|
492 |
+
|
493 |
+
class BertEncoder(nn.Module):
|
494 |
+
def __init__(self, config):
|
495 |
+
super().__init__()
|
496 |
+
self.config = config
|
497 |
+
self.layer = nn.ModuleList(
|
498 |
+
[BertLayer(config, i) for i in range(config.num_hidden_layers)]
|
499 |
+
)
|
500 |
+
|
501 |
+
def forward(
|
502 |
+
self,
|
503 |
+
hidden_states,
|
504 |
+
attention_mask=None,
|
505 |
+
head_mask=None,
|
506 |
+
encoder_hidden_states=None,
|
507 |
+
encoder_attention_mask=None,
|
508 |
+
past_key_values=None,
|
509 |
+
use_cache=None,
|
510 |
+
output_attentions=False,
|
511 |
+
output_hidden_states=False,
|
512 |
+
return_dict=True,
|
513 |
+
query_length=0,
|
514 |
+
):
|
515 |
+
all_hidden_states = () if output_hidden_states else None
|
516 |
+
all_self_attentions = () if output_attentions else None
|
517 |
+
all_cross_attentions = (
|
518 |
+
() if output_attentions and self.config.add_cross_attention else None
|
519 |
+
)
|
520 |
+
|
521 |
+
next_decoder_cache = () if use_cache else None
|
522 |
+
|
523 |
+
for i in range(self.config.num_hidden_layers):
|
524 |
+
layer_module = self.layer[i]
|
525 |
+
if output_hidden_states:
|
526 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
527 |
+
|
528 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
529 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
530 |
+
|
531 |
+
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
532 |
+
|
533 |
+
if use_cache:
|
534 |
+
logger.warn(
|
535 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
536 |
+
)
|
537 |
+
use_cache = False
|
538 |
+
|
539 |
+
def create_custom_forward(module):
|
540 |
+
def custom_forward(*inputs):
|
541 |
+
return module(
|
542 |
+
*inputs, past_key_value, output_attentions, query_length
|
543 |
+
)
|
544 |
+
|
545 |
+
return custom_forward
|
546 |
+
|
547 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
548 |
+
create_custom_forward(layer_module),
|
549 |
+
hidden_states,
|
550 |
+
attention_mask,
|
551 |
+
layer_head_mask,
|
552 |
+
encoder_hidden_states,
|
553 |
+
encoder_attention_mask,
|
554 |
+
)
|
555 |
+
else:
|
556 |
+
layer_outputs = layer_module(
|
557 |
+
hidden_states,
|
558 |
+
attention_mask,
|
559 |
+
layer_head_mask,
|
560 |
+
encoder_hidden_states,
|
561 |
+
encoder_attention_mask,
|
562 |
+
past_key_value,
|
563 |
+
output_attentions,
|
564 |
+
query_length,
|
565 |
+
)
|
566 |
+
|
567 |
+
hidden_states = layer_outputs[0]
|
568 |
+
if use_cache:
|
569 |
+
next_decoder_cache += (layer_outputs[-1],)
|
570 |
+
if output_attentions:
|
571 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
572 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
573 |
+
|
574 |
+
if output_hidden_states:
|
575 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
576 |
+
|
577 |
+
if not return_dict:
|
578 |
+
return tuple(
|
579 |
+
v
|
580 |
+
for v in [
|
581 |
+
hidden_states,
|
582 |
+
next_decoder_cache,
|
583 |
+
all_hidden_states,
|
584 |
+
all_self_attentions,
|
585 |
+
all_cross_attentions,
|
586 |
+
]
|
587 |
+
if v is not None
|
588 |
+
)
|
589 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
590 |
+
last_hidden_state=hidden_states,
|
591 |
+
past_key_values=next_decoder_cache,
|
592 |
+
hidden_states=all_hidden_states,
|
593 |
+
attentions=all_self_attentions,
|
594 |
+
cross_attentions=all_cross_attentions,
|
595 |
+
)
|
596 |
+
|
597 |
+
|
598 |
+
class BertPooler(nn.Module):
|
599 |
+
def __init__(self, config):
|
600 |
+
super().__init__()
|
601 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
602 |
+
self.activation = nn.Tanh()
|
603 |
+
|
604 |
+
def forward(self, hidden_states):
|
605 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
606 |
+
# to the first token.
|
607 |
+
first_token_tensor = hidden_states[:, 0]
|
608 |
+
pooled_output = self.dense(first_token_tensor)
|
609 |
+
pooled_output = self.activation(pooled_output)
|
610 |
+
return pooled_output
|
611 |
+
|
612 |
+
|
613 |
+
class BertPredictionHeadTransform(nn.Module):
|
614 |
+
def __init__(self, config):
|
615 |
+
super().__init__()
|
616 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
617 |
+
if isinstance(config.hidden_act, str):
|
618 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
619 |
+
else:
|
620 |
+
self.transform_act_fn = config.hidden_act
|
621 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
622 |
+
|
623 |
+
def forward(self, hidden_states):
|
624 |
+
hidden_states = self.dense(hidden_states)
|
625 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
626 |
+
hidden_states = self.LayerNorm(hidden_states)
|
627 |
+
return hidden_states
|
628 |
+
|
629 |
+
|
630 |
+
class BertLMPredictionHead(nn.Module):
|
631 |
+
def __init__(self, config):
|
632 |
+
super().__init__()
|
633 |
+
self.transform = BertPredictionHeadTransform(config)
|
634 |
+
|
635 |
+
# The output weights are the same as the input embeddings, but there is
|
636 |
+
# an output-only bias for each token.
|
637 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
638 |
+
|
639 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
640 |
+
|
641 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
642 |
+
self.decoder.bias = self.bias
|
643 |
+
|
644 |
+
def forward(self, hidden_states):
|
645 |
+
hidden_states = self.transform(hidden_states)
|
646 |
+
hidden_states = self.decoder(hidden_states)
|
647 |
+
return hidden_states
|
648 |
+
|
649 |
+
|
650 |
+
class BertOnlyMLMHead(nn.Module):
|
651 |
+
def __init__(self, config):
|
652 |
+
super().__init__()
|
653 |
+
self.predictions = BertLMPredictionHead(config)
|
654 |
+
|
655 |
+
def forward(self, sequence_output):
|
656 |
+
prediction_scores = self.predictions(sequence_output)
|
657 |
+
return prediction_scores
|
658 |
+
|
659 |
+
|
660 |
+
class BertPreTrainedModel(PreTrainedModel):
|
661 |
+
"""
|
662 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
663 |
+
models.
|
664 |
+
"""
|
665 |
+
|
666 |
+
config_class = BertConfig
|
667 |
+
base_model_prefix = "bert"
|
668 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
669 |
+
|
670 |
+
def _init_weights(self, module):
|
671 |
+
"""Initialize the weights"""
|
672 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
673 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
674 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
675 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
676 |
+
elif isinstance(module, nn.LayerNorm):
|
677 |
+
module.bias.data.zero_()
|
678 |
+
module.weight.data.fill_(1.0)
|
679 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
680 |
+
module.bias.data.zero_()
|
681 |
+
|
682 |
+
|
683 |
+
class BertModel(BertPreTrainedModel):
|
684 |
+
"""
|
685 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
686 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
687 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
688 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
689 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
690 |
+
input to the forward pass.
|
691 |
+
"""
|
692 |
+
|
693 |
+
def __init__(self, config, add_pooling_layer=False):
|
694 |
+
super().__init__(config)
|
695 |
+
self.config = config
|
696 |
+
|
697 |
+
self.embeddings = BertEmbeddings(config)
|
698 |
+
|
699 |
+
self.encoder = BertEncoder(config)
|
700 |
+
|
701 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
702 |
+
|
703 |
+
self.init_weights()
|
704 |
+
|
705 |
+
def get_input_embeddings(self):
|
706 |
+
return self.embeddings.word_embeddings
|
707 |
+
|
708 |
+
def set_input_embeddings(self, value):
|
709 |
+
self.embeddings.word_embeddings = value
|
710 |
+
|
711 |
+
def _prune_heads(self, heads_to_prune):
|
712 |
+
"""
|
713 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
714 |
+
class PreTrainedModel
|
715 |
+
"""
|
716 |
+
for layer, heads in heads_to_prune.items():
|
717 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
718 |
+
|
719 |
+
def get_extended_attention_mask(
|
720 |
+
self,
|
721 |
+
attention_mask: Tensor,
|
722 |
+
input_shape: Tuple[int],
|
723 |
+
device: device,
|
724 |
+
is_decoder: bool,
|
725 |
+
has_query: bool = False,
|
726 |
+
) -> Tensor:
|
727 |
+
"""
|
728 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
729 |
+
|
730 |
+
Arguments:
|
731 |
+
attention_mask (:obj:`torch.Tensor`):
|
732 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
733 |
+
input_shape (:obj:`Tuple[int]`):
|
734 |
+
The shape of the input to the model.
|
735 |
+
device: (:obj:`torch.device`):
|
736 |
+
The device of the input to the model.
|
737 |
+
|
738 |
+
Returns:
|
739 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
740 |
+
"""
|
741 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
742 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
743 |
+
if attention_mask.dim() == 3:
|
744 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
745 |
+
elif attention_mask.dim() == 2:
|
746 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
747 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
748 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
749 |
+
if is_decoder:
|
750 |
+
batch_size, seq_length = input_shape
|
751 |
+
|
752 |
+
seq_ids = torch.arange(seq_length, device=device)
|
753 |
+
causal_mask = (
|
754 |
+
seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
|
755 |
+
<= seq_ids[None, :, None]
|
756 |
+
)
|
757 |
+
|
758 |
+
# add a prefix ones mask to the causal mask
|
759 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
760 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
761 |
+
|
762 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
763 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
764 |
+
if has_query: # UniLM style attention mask
|
765 |
+
causal_mask = torch.cat(
|
766 |
+
[
|
767 |
+
torch.zeros(
|
768 |
+
(batch_size, prefix_seq_len, seq_length),
|
769 |
+
device=device,
|
770 |
+
dtype=causal_mask.dtype,
|
771 |
+
),
|
772 |
+
causal_mask,
|
773 |
+
],
|
774 |
+
axis=1,
|
775 |
+
)
|
776 |
+
causal_mask = torch.cat(
|
777 |
+
[
|
778 |
+
torch.ones(
|
779 |
+
(batch_size, causal_mask.shape[1], prefix_seq_len),
|
780 |
+
device=device,
|
781 |
+
dtype=causal_mask.dtype,
|
782 |
+
),
|
783 |
+
causal_mask,
|
784 |
+
],
|
785 |
+
axis=-1,
|
786 |
+
)
|
787 |
+
extended_attention_mask = (
|
788 |
+
causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
789 |
+
)
|
790 |
+
else:
|
791 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
792 |
+
else:
|
793 |
+
raise ValueError(
|
794 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
795 |
+
input_shape, attention_mask.shape
|
796 |
+
)
|
797 |
+
)
|
798 |
+
|
799 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
800 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
801 |
+
# positions we want to attend and -10000.0 for masked positions.
|
802 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
803 |
+
# effectively the same as removing these entirely.
|
804 |
+
extended_attention_mask = extended_attention_mask.to(
|
805 |
+
dtype=self.dtype
|
806 |
+
) # fp16 compatibility
|
807 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
808 |
+
return extended_attention_mask
|
809 |
+
|
810 |
+
def forward(
|
811 |
+
self,
|
812 |
+
input_ids=None,
|
813 |
+
attention_mask=None,
|
814 |
+
position_ids=None,
|
815 |
+
head_mask=None,
|
816 |
+
query_embeds=None,
|
817 |
+
encoder_hidden_states=None,
|
818 |
+
encoder_attention_mask=None,
|
819 |
+
past_key_values=None,
|
820 |
+
use_cache=None,
|
821 |
+
output_attentions=None,
|
822 |
+
output_hidden_states=None,
|
823 |
+
return_dict=None,
|
824 |
+
is_decoder=False,
|
825 |
+
):
|
826 |
+
r"""
|
827 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
828 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
829 |
+
the model is configured as a decoder.
|
830 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
831 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
832 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
833 |
+
- 1 for tokens that are **not masked**,
|
834 |
+
- 0 for tokens that are **masked**.
|
835 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
836 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
837 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
838 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
839 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
840 |
+
use_cache (:obj:`bool`, `optional`):
|
841 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
842 |
+
decoding (see :obj:`past_key_values`).
|
843 |
+
"""
|
844 |
+
output_attentions = (
|
845 |
+
output_attentions
|
846 |
+
if output_attentions is not None
|
847 |
+
else self.config.output_attentions
|
848 |
+
)
|
849 |
+
output_hidden_states = (
|
850 |
+
output_hidden_states
|
851 |
+
if output_hidden_states is not None
|
852 |
+
else self.config.output_hidden_states
|
853 |
+
)
|
854 |
+
return_dict = (
|
855 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
856 |
+
)
|
857 |
+
|
858 |
+
# use_cache = use_cache if use_cache is not None else self.config.use_cache
|
859 |
+
|
860 |
+
if input_ids is None:
|
861 |
+
assert (
|
862 |
+
query_embeds is not None
|
863 |
+
), "You have to specify query_embeds when input_ids is None"
|
864 |
+
|
865 |
+
# past_key_values_length
|
866 |
+
past_key_values_length = (
|
867 |
+
past_key_values[0][0].shape[2] - self.config.query_length
|
868 |
+
if past_key_values is not None
|
869 |
+
else 0
|
870 |
+
)
|
871 |
+
|
872 |
+
query_length = query_embeds.shape[1] if query_embeds is not None else 0
|
873 |
+
|
874 |
+
embedding_output = self.embeddings(
|
875 |
+
input_ids=input_ids,
|
876 |
+
position_ids=position_ids,
|
877 |
+
query_embeds=query_embeds,
|
878 |
+
past_key_values_length=past_key_values_length,
|
879 |
+
)
|
880 |
+
|
881 |
+
input_shape = embedding_output.size()[:-1]
|
882 |
+
batch_size, seq_length = input_shape
|
883 |
+
device = embedding_output.device
|
884 |
+
|
885 |
+
if attention_mask is None:
|
886 |
+
attention_mask = torch.ones(
|
887 |
+
((batch_size, seq_length + past_key_values_length)), device=device
|
888 |
+
)
|
889 |
+
|
890 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
891 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
892 |
+
if is_decoder:
|
893 |
+
extended_attention_mask = self.get_extended_attention_mask(
|
894 |
+
attention_mask,
|
895 |
+
input_ids.shape,
|
896 |
+
device,
|
897 |
+
is_decoder,
|
898 |
+
has_query=(query_embeds is not None),
|
899 |
+
)
|
900 |
+
else:
|
901 |
+
extended_attention_mask = self.get_extended_attention_mask(
|
902 |
+
attention_mask, input_shape, device, is_decoder
|
903 |
+
)
|
904 |
+
|
905 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
906 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
907 |
+
if encoder_hidden_states is not None:
|
908 |
+
if type(encoder_hidden_states) == list:
|
909 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
|
910 |
+
0
|
911 |
+
].size()
|
912 |
+
else:
|
913 |
+
(
|
914 |
+
encoder_batch_size,
|
915 |
+
encoder_sequence_length,
|
916 |
+
_,
|
917 |
+
) = encoder_hidden_states.size()
|
918 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
919 |
+
|
920 |
+
if type(encoder_attention_mask) == list:
|
921 |
+
encoder_extended_attention_mask = [
|
922 |
+
self.invert_attention_mask(mask) for mask in encoder_attention_mask
|
923 |
+
]
|
924 |
+
elif encoder_attention_mask is None:
|
925 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
926 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
927 |
+
encoder_attention_mask
|
928 |
+
)
|
929 |
+
else:
|
930 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
931 |
+
encoder_attention_mask
|
932 |
+
)
|
933 |
+
else:
|
934 |
+
encoder_extended_attention_mask = None
|
935 |
+
|
936 |
+
# Prepare head mask if needed
|
937 |
+
# 1.0 in head_mask indicate we keep the head
|
938 |
+
# attention_probs has shape bsz x n_heads x N x N
|
939 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
940 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
941 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
942 |
+
|
943 |
+
encoder_outputs = self.encoder(
|
944 |
+
embedding_output,
|
945 |
+
attention_mask=extended_attention_mask,
|
946 |
+
head_mask=head_mask,
|
947 |
+
encoder_hidden_states=encoder_hidden_states,
|
948 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
949 |
+
past_key_values=past_key_values,
|
950 |
+
use_cache=use_cache,
|
951 |
+
output_attentions=output_attentions,
|
952 |
+
output_hidden_states=output_hidden_states,
|
953 |
+
return_dict=return_dict,
|
954 |
+
query_length=query_length,
|
955 |
+
)
|
956 |
+
sequence_output = encoder_outputs[0]
|
957 |
+
pooled_output = (
|
958 |
+
self.pooler(sequence_output) if self.pooler is not None else None
|
959 |
+
)
|
960 |
+
|
961 |
+
if not return_dict:
|
962 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
963 |
+
|
964 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
965 |
+
last_hidden_state=sequence_output,
|
966 |
+
pooler_output=pooled_output,
|
967 |
+
past_key_values=encoder_outputs.past_key_values,
|
968 |
+
hidden_states=encoder_outputs.hidden_states,
|
969 |
+
attentions=encoder_outputs.attentions,
|
970 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
971 |
+
)
|
972 |
+
|
973 |
+
|
974 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
975 |
+
|
976 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
977 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
978 |
+
|
979 |
+
def __init__(self, config):
|
980 |
+
super().__init__(config)
|
981 |
+
|
982 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
983 |
+
self.cls = BertOnlyMLMHead(config)
|
984 |
+
|
985 |
+
self.init_weights()
|
986 |
+
|
987 |
+
def get_output_embeddings(self):
|
988 |
+
return self.cls.predictions.decoder
|
989 |
+
|
990 |
+
def set_output_embeddings(self, new_embeddings):
|
991 |
+
self.cls.predictions.decoder = new_embeddings
|
992 |
+
|
993 |
+
def forward(
|
994 |
+
self,
|
995 |
+
input_ids=None,
|
996 |
+
attention_mask=None,
|
997 |
+
position_ids=None,
|
998 |
+
head_mask=None,
|
999 |
+
query_embeds=None,
|
1000 |
+
encoder_hidden_states=None,
|
1001 |
+
encoder_attention_mask=None,
|
1002 |
+
labels=None,
|
1003 |
+
past_key_values=None,
|
1004 |
+
use_cache=True,
|
1005 |
+
output_attentions=None,
|
1006 |
+
output_hidden_states=None,
|
1007 |
+
return_dict=None,
|
1008 |
+
return_logits=False,
|
1009 |
+
is_decoder=True,
|
1010 |
+
reduction="mean",
|
1011 |
+
):
|
1012 |
+
r"""
|
1013 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
1014 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
1015 |
+
the model is configured as a decoder.
|
1016 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1017 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
1018 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
1019 |
+
- 1 for tokens that are **not masked**,
|
1020 |
+
- 0 for tokens that are **masked**.
|
1021 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1022 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
1023 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
1024 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
1025 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1026 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1027 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
1028 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
1029 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
1030 |
+
use_cache (:obj:`bool`, `optional`):
|
1031 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
1032 |
+
decoding (see :obj:`past_key_values`).
|
1033 |
+
Returns:
|
1034 |
+
Example::
|
1035 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
1036 |
+
>>> import torch
|
1037 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
1038 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
1039 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
1040 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
1041 |
+
>>> outputs = model(**inputs)
|
1042 |
+
>>> prediction_logits = outputs.logits
|
1043 |
+
"""
|
1044 |
+
return_dict = (
|
1045 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1046 |
+
)
|
1047 |
+
if labels is not None:
|
1048 |
+
use_cache = False
|
1049 |
+
if past_key_values is not None:
|
1050 |
+
query_embeds = None
|
1051 |
+
|
1052 |
+
outputs = self.bert(
|
1053 |
+
input_ids,
|
1054 |
+
attention_mask=attention_mask,
|
1055 |
+
position_ids=position_ids,
|
1056 |
+
head_mask=head_mask,
|
1057 |
+
query_embeds=query_embeds,
|
1058 |
+
encoder_hidden_states=encoder_hidden_states,
|
1059 |
+
encoder_attention_mask=encoder_attention_mask,
|
1060 |
+
past_key_values=past_key_values,
|
1061 |
+
use_cache=use_cache,
|
1062 |
+
output_attentions=output_attentions,
|
1063 |
+
output_hidden_states=output_hidden_states,
|
1064 |
+
return_dict=return_dict,
|
1065 |
+
is_decoder=is_decoder,
|
1066 |
+
)
|
1067 |
+
|
1068 |
+
sequence_output = outputs[0]
|
1069 |
+
if query_embeds is not None:
|
1070 |
+
sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
|
1071 |
+
|
1072 |
+
prediction_scores = self.cls(sequence_output)
|
1073 |
+
|
1074 |
+
if return_logits:
|
1075 |
+
return prediction_scores[:, :-1, :].contiguous()
|
1076 |
+
|
1077 |
+
lm_loss = None
|
1078 |
+
if labels is not None:
|
1079 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1080 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
1081 |
+
labels = labels[:, 1:].contiguous()
|
1082 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
1083 |
+
lm_loss = loss_fct(
|
1084 |
+
shifted_prediction_scores.view(-1, self.config.vocab_size),
|
1085 |
+
labels.view(-1),
|
1086 |
+
)
|
1087 |
+
if reduction == "none":
|
1088 |
+
lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
|
1089 |
+
|
1090 |
+
if not return_dict:
|
1091 |
+
output = (prediction_scores,) + outputs[2:]
|
1092 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
1093 |
+
|
1094 |
+
return CausalLMOutputWithCrossAttentions(
|
1095 |
+
loss=lm_loss,
|
1096 |
+
logits=prediction_scores,
|
1097 |
+
past_key_values=outputs.past_key_values,
|
1098 |
+
hidden_states=outputs.hidden_states,
|
1099 |
+
attentions=outputs.attentions,
|
1100 |
+
cross_attentions=outputs.cross_attentions,
|
1101 |
+
)
|
1102 |
+
|
1103 |
+
def prepare_inputs_for_generation(
|
1104 |
+
self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
|
1105 |
+
):
|
1106 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
1107 |
+
if attention_mask is None:
|
1108 |
+
attention_mask = input_ids.new_ones(input_ids.shape)
|
1109 |
+
query_mask = input_ids.new_ones(query_embeds.shape[:-1])
|
1110 |
+
attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
|
1111 |
+
|
1112 |
+
# cut decoder_input_ids if past is used
|
1113 |
+
if past is not None:
|
1114 |
+
input_ids = input_ids[:, -1:]
|
1115 |
+
|
1116 |
+
return {
|
1117 |
+
"input_ids": input_ids,
|
1118 |
+
"query_embeds": query_embeds,
|
1119 |
+
"attention_mask": attention_mask,
|
1120 |
+
"past_key_values": past,
|
1121 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
1122 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
1123 |
+
"is_decoder": True,
|
1124 |
+
}
|
1125 |
+
|
1126 |
+
def _reorder_cache(self, past, beam_idx):
|
1127 |
+
reordered_past = ()
|
1128 |
+
for layer_past in past:
|
1129 |
+
reordered_past += (
|
1130 |
+
tuple(
|
1131 |
+
past_state.index_select(0, beam_idx) for past_state in layer_past
|
1132 |
+
),
|
1133 |
+
)
|
1134 |
+
return reordered_past
|
1135 |
+
|
1136 |
+
|
1137 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
1138 |
+
|
1139 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1140 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
1141 |
+
|
1142 |
+
def __init__(self, config):
|
1143 |
+
super().__init__(config)
|
1144 |
+
|
1145 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1146 |
+
self.cls = BertOnlyMLMHead(config)
|
1147 |
+
|
1148 |
+
self.init_weights()
|
1149 |
+
|
1150 |
+
def get_output_embeddings(self):
|
1151 |
+
return self.cls.predictions.decoder
|
1152 |
+
|
1153 |
+
def set_output_embeddings(self, new_embeddings):
|
1154 |
+
self.cls.predictions.decoder = new_embeddings
|
1155 |
+
|
1156 |
+
def forward(
|
1157 |
+
self,
|
1158 |
+
input_ids=None,
|
1159 |
+
attention_mask=None,
|
1160 |
+
position_ids=None,
|
1161 |
+
head_mask=None,
|
1162 |
+
query_embeds=None,
|
1163 |
+
encoder_hidden_states=None,
|
1164 |
+
encoder_attention_mask=None,
|
1165 |
+
labels=None,
|
1166 |
+
output_attentions=None,
|
1167 |
+
output_hidden_states=None,
|
1168 |
+
return_dict=None,
|
1169 |
+
return_logits=False,
|
1170 |
+
is_decoder=False,
|
1171 |
+
):
|
1172 |
+
r"""
|
1173 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1174 |
+
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
1175 |
+
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
1176 |
+
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
1177 |
+
"""
|
1178 |
+
|
1179 |
+
return_dict = (
|
1180 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1181 |
+
)
|
1182 |
+
|
1183 |
+
outputs = self.bert(
|
1184 |
+
input_ids,
|
1185 |
+
attention_mask=attention_mask,
|
1186 |
+
position_ids=position_ids,
|
1187 |
+
head_mask=head_mask,
|
1188 |
+
query_embeds=query_embeds,
|
1189 |
+
encoder_hidden_states=encoder_hidden_states,
|
1190 |
+
encoder_attention_mask=encoder_attention_mask,
|
1191 |
+
output_attentions=output_attentions,
|
1192 |
+
output_hidden_states=output_hidden_states,
|
1193 |
+
return_dict=return_dict,
|
1194 |
+
is_decoder=is_decoder,
|
1195 |
+
)
|
1196 |
+
|
1197 |
+
if query_embeds is not None:
|
1198 |
+
sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
|
1199 |
+
prediction_scores = self.cls(sequence_output)
|
1200 |
+
|
1201 |
+
if return_logits:
|
1202 |
+
return prediction_scores
|
1203 |
+
|
1204 |
+
masked_lm_loss = None
|
1205 |
+
if labels is not None:
|
1206 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
1207 |
+
masked_lm_loss = loss_fct(
|
1208 |
+
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
|
1209 |
+
)
|
1210 |
+
|
1211 |
+
if not return_dict:
|
1212 |
+
output = (prediction_scores,) + outputs[2:]
|
1213 |
+
return (
|
1214 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1215 |
+
)
|
1216 |
+
|
1217 |
+
return MaskedLMOutput(
|
1218 |
+
loss=masked_lm_loss,
|
1219 |
+
logits=prediction_scores,
|
1220 |
+
hidden_states=outputs.hidden_states,
|
1221 |
+
attentions=outputs.attentions,
|
1222 |
+
)
|
1223 |
+
|
1224 |
+
|
1225 |
+
class Qformer(nn.Module):
|
1226 |
+
def __init__(self, model_args, vision_tower):
|
1227 |
+
super().__init__()
|
1228 |
+
|
1229 |
+
self.depth = model_args.mm_qformer_depth
|
1230 |
+
self.num_latents = model_args.mm_qformer_latents
|
1231 |
+
self.pretrained = model_args.mm_qformer_pretrained
|
1232 |
+
|
1233 |
+
self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(vision_tower.hidden_size, self.depth, self.num_latents)
|
1234 |
+
|
1235 |
+
if self.pretrained is not None:
|
1236 |
+
pretrained_dict = torch.load(self.pretrained, map_location='cpu')['model']
|
1237 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith('t5_proj')}
|
1238 |
+
# import pdb;pdb.set_trace()
|
1239 |
+
_ = self.load_state_dict(pretrained_dict,strict=False)
|
1240 |
+
print(_)
|
1241 |
+
|
1242 |
+
def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
|
1243 |
+
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
|
1244 |
+
encoder_config.encoder_width = vision_width
|
1245 |
+
# insert cross-attention layer every other block
|
1246 |
+
encoder_config.add_cross_attention = True
|
1247 |
+
encoder_config.cross_attention_freq = cross_attention_freq
|
1248 |
+
encoder_config.query_length = num_query_token
|
1249 |
+
Qformer = BertLMHeadModel(config=encoder_config)
|
1250 |
+
query_tokens = nn.Parameter(
|
1251 |
+
torch.zeros(1, num_query_token, encoder_config.hidden_size)
|
1252 |
+
)
|
1253 |
+
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
|
1254 |
+
Qformer.cls = None
|
1255 |
+
Qformer.bert.embeddings.word_embeddings = None
|
1256 |
+
Qformer.bert.embeddings.position_embeddings = None
|
1257 |
+
for layer in Qformer.bert.encoder.layer:
|
1258 |
+
layer.output = None
|
1259 |
+
layer.intermediate = None
|
1260 |
+
return Qformer, query_tokens, nn.LayerNorm(vision_width)
|
1261 |
+
|
1262 |
+
def forward(self, image_features, *args, **kwargs):
|
1263 |
+
x = self.ln_vision(image_features)
|
1264 |
+
image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device)
|
1265 |
+
|
1266 |
+
query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
|
1267 |
+
query_output = self.Qformer.bert(
|
1268 |
+
query_embeds=query_tokens,
|
1269 |
+
encoder_hidden_states=x,
|
1270 |
+
encoder_attention_mask=image_atts,
|
1271 |
+
return_dict=True,
|
1272 |
+
)
|
1273 |
+
|
1274 |
+
return query_output.last_hidden_state
|
1275 |
+
|
1276 |
+
@property
|
1277 |
+
def hidden_size(self):
|
1278 |
+
return 768
|
1279 |
+
|
1280 |
+
@property
|
1281 |
+
def config(self):
|
1282 |
+
return {
|
1283 |
+
'mm_resampler_type': 'qformer',
|
1284 |
+
'mm_qformer_depth': self.depth,
|
1285 |
+
'mm_qformer_latents': self.num_latents,
|
1286 |
+
'mm_qformer_pretrained': self.pretrained,
|
1287 |
+
}
|
oryx/model/multimodal_resampler/spatial_pool.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
class SpatialPool(nn.Module):
|
7 |
+
def __init__(self, model_args, vision_tower):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
self.mode = model_args.mm_spatial_pool_mode
|
11 |
+
self.stride = model_args.mm_spatial_pool_stride
|
12 |
+
# import pdb; pdb.set_trace()
|
13 |
+
self.out_channels = getattr(model_args, 'mm_spatial_pool_out_channels', vision_tower.hidden_size)
|
14 |
+
|
15 |
+
if self.mode == 'average':
|
16 |
+
self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
|
17 |
+
elif self.mode == 'max':
|
18 |
+
self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
|
19 |
+
elif self.mode == 'conv':
|
20 |
+
self.pool = nn.Conv2d(in_channels=vision_tower.hidden_size, out_channels=self.out_channels, kernel_size=self.stride, stride=self.stride)
|
21 |
+
else:
|
22 |
+
raise ValueError(f'Unknown pooling mode: {self.pool}.')
|
23 |
+
|
24 |
+
def forward(self, image_features, images, *args, **kwargs):
|
25 |
+
ori_W = int(math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2]))
|
26 |
+
ori_H = int(ori_W * images.shape[2] // images.shape[3])
|
27 |
+
|
28 |
+
B, _, F = image_features.shape
|
29 |
+
|
30 |
+
image_features_spatial = image_features.view(B, ori_H, ori_H, F).permute(0, 3, 1, 2)
|
31 |
+
image_features_spatial_pool = self.pool(image_features_spatial)
|
32 |
+
|
33 |
+
return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
|
34 |
+
|
35 |
+
@property
|
36 |
+
def config(self):
|
37 |
+
return {
|
38 |
+
'mm_resampler_type': 'spatial_pool',
|
39 |
+
'mm_spatial_pool_stride': self.stride,
|
40 |
+
'mm_spatial_pool_mode': self.mode,
|
41 |
+
'mm_spatial_pool_out_channels': self.out_channels,
|
42 |
+
}
|
oryx/model/multimodal_resampler/vlm_attention.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from transformers import BertTokenizer
|
8 |
+
from transformers.models.bert.modeling_bert import BertLMHeadModel as BertLMHeadModelRaw
|
9 |
+
|
10 |
+
from .qformer import BertConfig
|
11 |
+
from .qformer import BertLMHeadModel as BertLMHeadModelQF
|
12 |
+
|
13 |
+
class VlmAttention(nn.Module):
|
14 |
+
def __init__(self, model_args, vision_tower):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
pretrain_mm_mlp_adapter = getattr(model_args, "pretrain_mm_mlp_adapter", None)
|
18 |
+
pretrain_qformer = getattr(model_args, "mm_vlmattention_pretrained", None)
|
19 |
+
self.bert_type = getattr(model_args, "mm_vlmattention_bert_type", "qformer")
|
20 |
+
self.num_query = getattr(model_args, "mm_vlmattention_num_query", 32)
|
21 |
+
self.compress_type = getattr(model_args, "mm_vlmattention_compress_type", None)
|
22 |
+
self.mm_hidden_size = self.hidden_size = vision_tower.hidden_size
|
23 |
+
self.mm_vision_select_feature = model_args.mm_vision_select_feature
|
24 |
+
self.language_hidden_size = 4096
|
25 |
+
for_eval = True
|
26 |
+
|
27 |
+
if 'pretrain' in self.bert_type:
|
28 |
+
# for qformer that use evaclip for prtrain
|
29 |
+
att_feat_size = 1408
|
30 |
+
else:
|
31 |
+
att_feat_size = self.mm_hidden_size
|
32 |
+
self.vlm_att_tokenlizer, self.vlm_att_encoder, self.vlm_att_query = self.init_bert(att_feat_size, truncation_side="left")
|
33 |
+
self.vlm_att_projector = torch.nn.Linear(self.vlm_att_encoder.config.hidden_size, self.mm_hidden_size)
|
34 |
+
self.vlm_att_key_projector = torch.nn.Linear(self.mm_hidden_size, self.mm_hidden_size)
|
35 |
+
self.vlm_att_val_projector = torch.nn.Linear(self.mm_hidden_size, self.language_hidden_size)
|
36 |
+
|
37 |
+
if "raw" in self.bert_type:
|
38 |
+
self.vlm_att_bert_proj = torch.nn.Linear(att_feat_size, self.vlm_att_encoder.config.hidden_size)
|
39 |
+
elif "pretrain" in self.bert_type and self.mm_hidden_size!=att_feat_size:
|
40 |
+
self.vlm_att_bert_proj = torch.nn.Linear(self.mm_hidden_size, att_feat_size)
|
41 |
+
else:
|
42 |
+
self.vlm_att_bert_proj = None
|
43 |
+
|
44 |
+
def get_w(weights, keyword):
|
45 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
46 |
+
|
47 |
+
if 'qformer_pretrain' in self.bert_type:
|
48 |
+
self.vlm_att_ln = torch.nn.LayerNorm(att_feat_size)
|
49 |
+
|
50 |
+
if pretrain_qformer is not None:
|
51 |
+
print("Loading pretrained qformer weights...")
|
52 |
+
qformer_weight = torch.load(pretrain_qformer, map_location='cpu')['model']
|
53 |
+
bert_weight = {_key: qformer_weight[_key] for _key in qformer_weight if 'bert' in _key}
|
54 |
+
self.vlm_att_encoder.load_state_dict(get_w(bert_weight, 'Qformer'))
|
55 |
+
self.vlm_att_ln.load_state_dict(get_w(qformer_weight, 'ln_vision'))
|
56 |
+
self.vlm_att_query.data = qformer_weight['query_tokens']
|
57 |
+
|
58 |
+
if 'freeze_all' in self.bert_type:
|
59 |
+
print("Freezing all qformer weights...")
|
60 |
+
self.vlm_att_encoder.requires_grad_(False)
|
61 |
+
self.vlm_att_ln.requires_grad_(False)
|
62 |
+
self.vlm_att_query.requires_grad_(False)
|
63 |
+
self.vlm_att_projector.requires_grad_(False)
|
64 |
+
self.vlm_att_key_projector.requires_grad_(False)
|
65 |
+
self.vlm_att_val_projector.requires_grad_(False)
|
66 |
+
elif 'freeze' in self.bert_type:
|
67 |
+
print("Freezing pretrained qformer weights...")
|
68 |
+
self.vlm_att_encoder.requires_grad_(False)
|
69 |
+
self.vlm_att_ln.requires_grad_(False)
|
70 |
+
self.vlm_att_query.requires_grad_(False)
|
71 |
+
|
72 |
+
|
73 |
+
if pretrain_mm_mlp_adapter is not None:
|
74 |
+
att_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
75 |
+
else:
|
76 |
+
trainable_module = ['vlm_att_encoder', 'vlm_att_projector', 'vlm_att_key_projector',
|
77 |
+
'vlm_att_val_projector', 'vlm_att_query', 'vlm_att_visual_proj',
|
78 |
+
'vlm_att_ln']
|
79 |
+
if hasattr(model_args, 'model_name_or_path'):
|
80 |
+
model_save_path = model_args.model_name_or_path
|
81 |
+
else:
|
82 |
+
model_save_path = model_args.model_path
|
83 |
+
model_idx_path = getattr(model_args, 'model_path', model_save_path)
|
84 |
+
weight_file = json.load(open(os.path.join(model_idx_path, 'pytorch_model.bin.index.json'), 'r'))['weight_map']
|
85 |
+
model_path = set([weight_file[_key] for _key in weight_file if any([_module in _key for _module in trainable_module])])
|
86 |
+
att_projector_weights = {}
|
87 |
+
for _model in model_path:
|
88 |
+
att_projector_weights.update(torch.load(os.path.join(model_idx_path, _model), map_location='cpu'))
|
89 |
+
if len(att_projector_weights) == 0:
|
90 |
+
return
|
91 |
+
|
92 |
+
bert_dict = get_w(att_projector_weights, 'vlm_att_encoder')
|
93 |
+
if "bert.embeddings.position_ids" not in bert_dict and "raw_bert" not in self.bert_type:
|
94 |
+
bert_dict["bert.embeddings.position_ids"] = self.vlm_att_encoder.bert.embeddings.position_ids
|
95 |
+
print('Loading pretrained weights...')
|
96 |
+
# import pdb;pdb.set_trace()
|
97 |
+
|
98 |
+
self.vlm_att_encoder.load_state_dict(bert_dict)
|
99 |
+
self.vlm_att_projector.load_state_dict(get_w(att_projector_weights, 'vlm_att_projector'))
|
100 |
+
self.vlm_att_key_projector.load_state_dict(get_w(att_projector_weights, 'vlm_att_key_projector'))
|
101 |
+
self.vlm_att_val_projector.load_state_dict(get_w(att_projector_weights, 'vlm_att_val_projector'))
|
102 |
+
|
103 |
+
if "qformer" in self.bert_type:
|
104 |
+
print('Loading vlm_att_query weights...')
|
105 |
+
self.vlm_att_query.data = att_projector_weights['model.vlm_att_query']
|
106 |
+
if "pretrain" in self.bert_type:
|
107 |
+
print('Loading vlm_att_ln weights...')
|
108 |
+
self.vlm_att_ln.load_state_dict(get_w(att_projector_weights, 'vlm_att_ln'))
|
109 |
+
|
110 |
+
if self.vlm_att_bert_proj is not None:
|
111 |
+
print('Loading vlm_att_bert_proj weights...')
|
112 |
+
self.vlm_att_bert_proj.load_state_dict(get_w(att_projector_weights, 'vlm_att_bert_proj'))
|
113 |
+
|
114 |
+
if for_eval:
|
115 |
+
weight_type = torch.float16
|
116 |
+
# import pdb;pdb.set_trace()
|
117 |
+
# device_type = self.mm_projector[0].weight.device
|
118 |
+
device_type = vision_tower.vision_tower.patch_embed.proj.weight.device
|
119 |
+
self.vlm_att_encoder = self.vlm_att_encoder.to(device=device_type, dtype=weight_type)
|
120 |
+
self.vlm_att_projector = self.vlm_att_projector.to(device=device_type, dtype=weight_type)
|
121 |
+
self.vlm_att_key_projector = self.vlm_att_key_projector.to(device=device_type, dtype=weight_type)
|
122 |
+
self.vlm_att_val_projector = self.vlm_att_val_projector.to(device=device_type, dtype=weight_type)
|
123 |
+
|
124 |
+
if "qformer" in self.bert_type:
|
125 |
+
self.vlm_att_query.data = self.vlm_att_query.data.to(device=device_type, dtype=weight_type)
|
126 |
+
if "pretrain" in self.bert_type:
|
127 |
+
self.vlm_att_ln = self.vlm_att_ln.to(device=device_type, dtype=weight_type)
|
128 |
+
|
129 |
+
if self.vlm_att_bert_proj is not None:
|
130 |
+
self.vlm_att_bert_proj = self.vlm_att_bert_proj.to(device=device_type, dtype=weight_type)
|
131 |
+
|
132 |
+
def forward(self, image_features, prompts=None, image_counts=None, long_video=False):
|
133 |
+
img_feat_lst = []
|
134 |
+
# import pdb;pdb.set_trace()
|
135 |
+
if image_counts is None:
|
136 |
+
assert len(image_features) == len(prompts), f"Size mismatch! image_features: {len(image_features)}, prompts: {len(prompts)}"
|
137 |
+
else:
|
138 |
+
assert len(prompts) == len(image_counts), f"Size mismatch! prompts: {len(prompts)}, image_counts: {len(image_counts)}"
|
139 |
+
image_atts = torch.ones(image_features.size()[:-1], dtype=torch.long).to(image_features.device)
|
140 |
+
|
141 |
+
total_count = 0
|
142 |
+
# calculate each image feat according to the prompt
|
143 |
+
# import pdb;pdb.set_trace()
|
144 |
+
for _idx in range(len(prompts)):
|
145 |
+
assert isinstance(prompts[_idx], list), f"Prompt should be a list, but got {type(prompts[_idx])}"
|
146 |
+
input_token = self.vlm_att_tokenlizer(
|
147 |
+
prompts[_idx],
|
148 |
+
padding='longest',
|
149 |
+
truncation=True,
|
150 |
+
max_length=256,
|
151 |
+
return_tensors="pt"
|
152 |
+
).to(image_features.device)
|
153 |
+
|
154 |
+
input_ids = input_token.input_ids
|
155 |
+
attention_masks = input_token.attention_mask
|
156 |
+
|
157 |
+
if image_counts is None:
|
158 |
+
img_feat_prompt = image_features[_idx, None].expand(len(prompts[_idx]), -1, -1)
|
159 |
+
img_att_prompt = image_atts[_idx, None].expand(len(prompts[_idx]), -1)
|
160 |
+
else:
|
161 |
+
# shape: [prompt_num*frame_num, image_shape, feat_dim]
|
162 |
+
img_feat_prompt = image_features[total_count:total_count+image_counts[_idx]]
|
163 |
+
img_feat_prompt = img_feat_prompt[None].expand(len(prompts[_idx]), -1, -1, -1).flatten(0,1)
|
164 |
+
img_att_prompt = image_atts[total_count:total_count+image_counts[_idx]]
|
165 |
+
img_att_prompt = img_att_prompt[None].expand(len(prompts[_idx]), -1, -1).flatten(0,1)
|
166 |
+
input_ids = input_ids[:,None].expand(-1, image_counts[_idx], -1).flatten(0,1)
|
167 |
+
attention_masks = attention_masks[:,None].expand(-1, image_counts[_idx], -1).flatten(0,1)
|
168 |
+
total_count += image_counts[_idx]
|
169 |
+
|
170 |
+
if "pretrain" in self.bert_type and self.vlm_att_bert_proj is not None:
|
171 |
+
bert_feat = self.vlm_att_bert_proj(img_feat_prompt)
|
172 |
+
else:
|
173 |
+
bert_feat = img_feat_prompt.clone()
|
174 |
+
|
175 |
+
# remove cls embedding
|
176 |
+
if self.mm_vision_select_feature == 'patch':
|
177 |
+
if img_feat_prompt.shape[1]%2 == 1:
|
178 |
+
img_feat_prompt = img_feat_prompt[:, 1:]
|
179 |
+
|
180 |
+
if "qformer" in self.bert_type:
|
181 |
+
query_tokens = self.vlm_att_query.expand(bert_feat.shape[0], -1, -1)
|
182 |
+
query_atts = torch.cat([torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(bert_feat.device),
|
183 |
+
attention_masks],dim=1)
|
184 |
+
|
185 |
+
if 'pretrain' in self.bert_type:
|
186 |
+
mm_img_in = self.vlm_att_ln(bert_feat)
|
187 |
+
else:
|
188 |
+
mm_img_in = bert_feat
|
189 |
+
|
190 |
+
if long_video:
|
191 |
+
outputs = []
|
192 |
+
block_size = 64
|
193 |
+
for L in range(0, len(input_ids), block_size):
|
194 |
+
R = L + block_size
|
195 |
+
mm_output = self.vlm_att_encoder.bert(
|
196 |
+
input_ids[L:R],
|
197 |
+
query_embeds=query_tokens[L:R],
|
198 |
+
attention_mask=query_atts[L:R],
|
199 |
+
encoder_hidden_states=mm_img_in[L:R],
|
200 |
+
encoder_attention_mask=img_att_prompt[L:R],
|
201 |
+
return_dict=True,
|
202 |
+
)
|
203 |
+
mm_output = mm_output.last_hidden_state[:,:query_tokens.shape[1]]
|
204 |
+
outputs.append(mm_output)
|
205 |
+
mm_output = torch.cat(outputs)
|
206 |
+
torch.cuda.empty_cache()
|
207 |
+
else:
|
208 |
+
mm_output = self.vlm_att_encoder.bert(
|
209 |
+
input_ids,
|
210 |
+
query_embeds=query_tokens,
|
211 |
+
attention_mask=query_atts,
|
212 |
+
encoder_hidden_states=mm_img_in,
|
213 |
+
encoder_attention_mask=img_att_prompt,
|
214 |
+
return_dict=True,
|
215 |
+
)
|
216 |
+
mm_output = mm_output.last_hidden_state[:,:query_tokens.shape[1]]
|
217 |
+
|
218 |
+
elif "raw" in self.bert_type:
|
219 |
+
if self.mm_vision_select_feature == 'patch' and bert_feat.shape[1]%2 == 1:
|
220 |
+
bert_feat = bert_feat[:, 1:]
|
221 |
+
img_att_prompt = img_att_prompt[:, 1:]
|
222 |
+
|
223 |
+
mm_output = self.vlm_att_encoder.bert(
|
224 |
+
input_ids,
|
225 |
+
attention_mask=attention_masks,
|
226 |
+
encoder_hidden_states=self.vlm_att_bert_proj(bert_feat),
|
227 |
+
encoder_attention_mask=img_att_prompt,
|
228 |
+
return_dict=True,
|
229 |
+
)
|
230 |
+
mm_output = mm_output.last_hidden_state
|
231 |
+
else:
|
232 |
+
raise ValueError(f'Unexpected bert type: {self.bert_type}')
|
233 |
+
|
234 |
+
text_q = self.vlm_att_projector(mm_output)
|
235 |
+
# shape: [prompt_num*frame_num, feat_dim]
|
236 |
+
# ctx_embed,vis_embed = self.token_generation(text_q, img_feat_prompt, long_video=long_video)
|
237 |
+
final_token = self.token_generation(text_q, img_feat_prompt, long_video=long_video)
|
238 |
+
|
239 |
+
if image_counts is not None:
|
240 |
+
# shape: [prompt_num, frame_num*image_shape, feat_dim]
|
241 |
+
final_token = final_token.reshape(len(prompts[_idx]), image_counts[_idx], *final_token.shape[-2:])
|
242 |
+
final_token = final_token.flatten(1,2)
|
243 |
+
img_feat_lst.append(final_token)
|
244 |
+
|
245 |
+
return img_feat_lst
|
246 |
+
|
247 |
+
def init_bert(self, vision_width, cross_attention_freq=2, truncation_side="right"):
|
248 |
+
# initialize BERT tokenizer
|
249 |
+
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side)
|
250 |
+
tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
251 |
+
# initialize BERT
|
252 |
+
encoder_config = BertConfig.from_pretrained("bert-base-uncased")
|
253 |
+
encoder_config.encoder_width = vision_width
|
254 |
+
# insert cross-attention layer every other block
|
255 |
+
encoder_config.add_cross_attention = True
|
256 |
+
encoder_config.cross_attention_freq = cross_attention_freq
|
257 |
+
query_tokens = None
|
258 |
+
|
259 |
+
if "qformer" in self.bert_type:
|
260 |
+
mm_model = BertLMHeadModelQF.from_pretrained(
|
261 |
+
"bert-base-uncased", config=encoder_config
|
262 |
+
)
|
263 |
+
query_tokens = nn.Parameter(
|
264 |
+
torch.zeros(1, self.num_query, encoder_config.hidden_size)
|
265 |
+
)
|
266 |
+
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
|
267 |
+
elif "raw" in self.bert_type:
|
268 |
+
encoder_config.is_decoder = True
|
269 |
+
mm_model = BertLMHeadModelRaw.from_pretrained(
|
270 |
+
"bert-base-uncased", config=encoder_config
|
271 |
+
)
|
272 |
+
else:
|
273 |
+
raise NotImplementedError("BERT type not implemented...")
|
274 |
+
|
275 |
+
mm_model.resize_token_embeddings(len(tokenizer))
|
276 |
+
mm_model.cls = None
|
277 |
+
|
278 |
+
if "layer" in self.bert_type:
|
279 |
+
layer_num = int(self.bert_type.split(':')[-1])
|
280 |
+
mm_model.bert.encoder.layer = mm_model.bert.encoder.layer[:layer_num]
|
281 |
+
print(f"Only use {layer_num} layers in BERT...")
|
282 |
+
|
283 |
+
return tokenizer, mm_model, query_tokens
|
284 |
+
|
285 |
+
|
286 |
+
def token_generation(self, text_q, vis_embed, long_video=False):
|
287 |
+
ctx_embed = self.vlm_att_key_projector(vis_embed)
|
288 |
+
# Key part 1: calculate context-related embedding
|
289 |
+
ctx_embed = text_q @ ctx_embed.transpose(-1,-2)
|
290 |
+
ctx_embed = ctx_embed / (vis_embed.shape[-1] ** 0.5)
|
291 |
+
if not long_video:
|
292 |
+
ctx_embed = (ctx_embed.softmax(-1) @ vis_embed).mean(1)
|
293 |
+
else:
|
294 |
+
block_size = 64
|
295 |
+
outputs = []
|
296 |
+
ctx_score = ctx_embed.softmax(-1)
|
297 |
+
for L in range(0, len(ctx_score), block_size):
|
298 |
+
R = L + block_size
|
299 |
+
sub_embed = (ctx_score[L:R] @ vis_embed[L:R]).mean(1)
|
300 |
+
outputs.append(sub_embed)
|
301 |
+
ctx_embed = torch.cat(outputs)
|
302 |
+
torch.cuda.empty_cache()
|
303 |
+
ctx_embed = self.vlm_att_val_projector(ctx_embed[:,None])
|
304 |
+
|
305 |
+
# Key part 2: calculate visual embedding
|
306 |
+
if self.compress_type is not None:
|
307 |
+
if 'grid' in self.compress_type:
|
308 |
+
grid_size = int(self.compress_type.split('grid:')[-1])
|
309 |
+
cur_shape = int(vis_embed.shape[1]**0.5)
|
310 |
+
assert grid_size > 1, f'Grid size should be larger than 1, but got {grid_size}'
|
311 |
+
vis_embed = vis_embed.reshape(vis_embed.shape[0], cur_shape, cur_shape, -1)
|
312 |
+
grid_stride = cur_shape // grid_size
|
313 |
+
vis_embed = F.avg_pool2d(vis_embed.permute(0, 3, 1, 2),
|
314 |
+
padding=0,
|
315 |
+
kernel_size=grid_stride,
|
316 |
+
stride=grid_stride)
|
317 |
+
|
318 |
+
vis_embed = vis_embed.permute(0, 2, 3, 1).flatten(1,2)
|
319 |
+
elif 'mean' in self.compress_type:
|
320 |
+
# import pdb;pdb.set_trace()
|
321 |
+
vis_embed = vis_embed.mean(dim=1, keepdim=True)
|
322 |
+
|
323 |
+
# import pdb ; pdb.set_trace()
|
324 |
+
# concat token in shape (B, n+1, C)
|
325 |
+
vis_embed = self.mm_projector(vis_embed)
|
326 |
+
final_token = torch.cat([ctx_embed, vis_embed], dim=1)
|
327 |
+
return final_token
|
328 |
+
|
329 |
+
@property
|
330 |
+
def config(self):
|
331 |
+
return {
|
332 |
+
'mm_resampler_type': 'vlm_attention',
|
333 |
+
'mm_vlmattention_bert_type': self.bert_type,
|
334 |
+
'mm_vlmattention_num_query': self.num_query,
|
335 |
+
'mm_vlmattention_compress_type': self.compress_type,
|
336 |
+
}
|
337 |
+
|
oryx/model/oryx_arch.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .multimodal_encoder.builder import build_vision_tower
|
7 |
+
from .multimodal_resampler.builder import build_vision_resampler
|
8 |
+
from .multimodal_projector.builder import build_vision_projector
|
9 |
+
|
10 |
+
from oryx.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
11 |
+
|
12 |
+
import ast
|
13 |
+
import torch.distributed as dist
|
14 |
+
|
15 |
+
class OryxMetaModel:
|
16 |
+
|
17 |
+
def __init__(self, config):
|
18 |
+
super(OryxMetaModel, self).__init__(config)
|
19 |
+
|
20 |
+
if hasattr(config, "mm_vision_tower"):
|
21 |
+
self.vision_tower = build_vision_tower(config, delay_load=True)
|
22 |
+
self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
|
23 |
+
self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
|
24 |
+
def get_vision_tower(self):
|
25 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
26 |
+
if type(vision_tower) is list:
|
27 |
+
vision_tower = vision_tower[0]
|
28 |
+
return vision_tower
|
29 |
+
|
30 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
31 |
+
vision_tower = model_args.vision_tower
|
32 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
33 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
34 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
35 |
+
|
36 |
+
self.config.mm_vision_tower = vision_tower
|
37 |
+
|
38 |
+
if self.get_vision_tower() is None:
|
39 |
+
vision_tower = build_vision_tower(model_args)
|
40 |
+
vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
|
41 |
+
## Get the mm_spatial_pool_mode and mm_spatial_pool_stride
|
42 |
+
for k, v in vision_resampler.config.items():
|
43 |
+
setattr(self.config, k, v)
|
44 |
+
|
45 |
+
if fsdp is not None and len(fsdp) > 0:
|
46 |
+
self.vision_tower = [vision_tower]
|
47 |
+
self.vision_resampler = [vision_resampler]
|
48 |
+
else:
|
49 |
+
self.vision_tower = vision_tower
|
50 |
+
self.vision_resampler = vision_resampler
|
51 |
+
else:
|
52 |
+
if fsdp is not None and len(fsdp) > 0:
|
53 |
+
vision_resampler = self.vision_resampler[0]
|
54 |
+
vision_tower = self.vision_tower[0]
|
55 |
+
else:
|
56 |
+
vision_resampler = self.vision_resampler
|
57 |
+
vision_tower = self.vision_tower
|
58 |
+
vision_tower.load_model()
|
59 |
+
|
60 |
+
# In case it is frozen by LoRA
|
61 |
+
for p in self.vision_resampler.parameters():
|
62 |
+
p.requires_grad = True
|
63 |
+
|
64 |
+
self.config.use_mm_proj = True
|
65 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
66 |
+
self.config.mm_hidden_size = getattr(vision_resampler, 'hidden_size', vision_tower.hidden_size)
|
67 |
+
|
68 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
69 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
70 |
+
|
71 |
+
if getattr(self, 'mm_projector', None) is None:
|
72 |
+
self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
|
73 |
+
else:
|
74 |
+
for p in self.mm_projector.parameters():
|
75 |
+
p.requires_grad = True
|
76 |
+
|
77 |
+
if pretrain_mm_mlp_adapter is not None:
|
78 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
79 |
+
def get_w(weights, keyword):
|
80 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
|
81 |
+
|
82 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
|
83 |
+
incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, 'vision_resampler'), strict=False)
|
84 |
+
print(incompatible_keys)
|
85 |
+
|
86 |
+
|
87 |
+
class OryxMetaForCausalLM(ABC):
|
88 |
+
|
89 |
+
@abstractmethod
|
90 |
+
def get_model(self):
|
91 |
+
pass
|
92 |
+
|
93 |
+
def get_vision_tower(self):
|
94 |
+
return self.get_model().get_vision_tower()
|
95 |
+
|
96 |
+
def prepare_inputs_labels_for_multimodal(
|
97 |
+
self, input_ids, position_ids, attention_mask, past_key_values, labels,
|
98 |
+
images, modalities, image_sizes=None, images_highres=None):
|
99 |
+
# print(modalities, len(images), len(images_highres), len(input_ids))
|
100 |
+
vision_tower = self.get_vision_tower()
|
101 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
102 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
103 |
+
|
104 |
+
if isinstance(modalities, str):
|
105 |
+
modalities = [modalities]
|
106 |
+
|
107 |
+
video_idx_in_batch = []
|
108 |
+
for modal in range(len(modalities)):
|
109 |
+
if 'video' in modalities[modal]:
|
110 |
+
video_idx_in_batch.append(modal)
|
111 |
+
|
112 |
+
# Fix training with deepspeed zero3
|
113 |
+
num_modality = len(modalities)
|
114 |
+
# try:
|
115 |
+
# world_size = dist.get_world_size()
|
116 |
+
# tensor_in = torch.zeros(1, dtype=torch.int64, device=images[0].device).fill_(num_modality)
|
117 |
+
# tensor_out = torch.zeros(world_size, dtype=torch.int64, device=images[0].device)
|
118 |
+
# dist.all_gather_into_tensor(tensor_out, tensor_in)
|
119 |
+
# max_num_modality = tensor_out.max().item()
|
120 |
+
# except:
|
121 |
+
max_num_modality = num_modality
|
122 |
+
|
123 |
+
aimg = images[-1]
|
124 |
+
lowres_img = []
|
125 |
+
for idx, img_feat in enumerate(images):
|
126 |
+
if idx in video_idx_in_batch:
|
127 |
+
img_feat = aimg.new(1, 3, 128, 128).fill_(0)
|
128 |
+
lowres_img.append(img_feat)
|
129 |
+
|
130 |
+
# Fix training with deepspeed zero3
|
131 |
+
if max_num_modality > num_modality:
|
132 |
+
for _ in range(max_num_modality - num_modality):
|
133 |
+
lowres_img.append(aimg.new(1, 3, 64, 64).fill_(0))
|
134 |
+
images_highres.append(aimg.new(1, 3, 64, 64).fill_(0))
|
135 |
+
modalities.append('image')
|
136 |
+
|
137 |
+
lowres_img_features, lowres_img_sizes = self.get_model().get_vision_tower()(lowres_img)
|
138 |
+
highres_img_features = []
|
139 |
+
highres_img_sizes = []
|
140 |
+
for idx, img_feat in enumerate(images_highres):
|
141 |
+
if img_feat.ndim == 5:
|
142 |
+
img_feat = img_feat.squeeze(1)
|
143 |
+
highres_img_feature, highres_img_size = self.get_model().get_vision_tower()(img_feat)
|
144 |
+
highres_img_features.append(highres_img_feature)
|
145 |
+
highres_img_sizes.append(highres_img_size)
|
146 |
+
image_features = []
|
147 |
+
for idx in range(len(modalities)):
|
148 |
+
img_feat_highres, img_size_highres = self.get_model().vision_resampler(highres_img_features[idx],
|
149 |
+
modalities[idx],
|
150 |
+
highres_img_sizes[idx])
|
151 |
+
img_feat_lowres, img_size_lowres = self.get_model().vision_resampler(lowres_img_features[idx],
|
152 |
+
modalities[idx],
|
153 |
+
lowres_img_sizes[idx])
|
154 |
+
img_feat = self.get_model().mm_projector(img_feat_lowres,
|
155 |
+
img_size_lowres,
|
156 |
+
img_feat_highres,
|
157 |
+
img_size_highres,
|
158 |
+
modalities[idx])
|
159 |
+
image_features.append(img_feat.flatten(0, 1))
|
160 |
+
|
161 |
+
if max_num_modality > num_modality:
|
162 |
+
image_features = image_features[:num_modality]
|
163 |
+
modalities = modalities[:num_modality]
|
164 |
+
|
165 |
+
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
|
166 |
+
raise NotImplementedError
|
167 |
+
|
168 |
+
# Let's just add dummy tensors if they do not exist,
|
169 |
+
# it is a headache to deal with None all the time.
|
170 |
+
# But it is not ideal, and if you have a better idea,
|
171 |
+
# please open an issue / submit a PR, thanks.
|
172 |
+
_labels = labels
|
173 |
+
_position_ids = position_ids
|
174 |
+
_attention_mask = attention_mask
|
175 |
+
if attention_mask is None:
|
176 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
177 |
+
else:
|
178 |
+
attention_mask = attention_mask.bool()
|
179 |
+
if position_ids is None:
|
180 |
+
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
181 |
+
if labels is None:
|
182 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
183 |
+
|
184 |
+
# remove the padding using attention_mask -- FIXME
|
185 |
+
_input_ids = input_ids
|
186 |
+
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
187 |
+
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
188 |
+
|
189 |
+
new_input_embeds = []
|
190 |
+
new_labels = []
|
191 |
+
cur_image_idx = 0
|
192 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
193 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
194 |
+
if num_images == 0:
|
195 |
+
cur_image_features = image_features[cur_image_idx]
|
196 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
197 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
198 |
+
new_input_embeds.append(cur_input_embeds)
|
199 |
+
new_labels.append(labels[batch_idx])
|
200 |
+
cur_image_idx += 1
|
201 |
+
continue
|
202 |
+
|
203 |
+
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
|
204 |
+
cur_input_ids_noim = []
|
205 |
+
cur_labels = labels[batch_idx]
|
206 |
+
cur_labels_noim = []
|
207 |
+
for i in range(len(image_token_indices) - 1):
|
208 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
|
209 |
+
cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
|
210 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
211 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
212 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
213 |
+
cur_new_input_embeds = []
|
214 |
+
cur_new_labels = []
|
215 |
+
|
216 |
+
for i in range(num_images + 1):
|
217 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
218 |
+
cur_new_labels.append(cur_labels_noim[i])
|
219 |
+
if i < num_images:
|
220 |
+
cur_image_features = image_features[cur_image_idx]
|
221 |
+
cur_image_idx += 1
|
222 |
+
cur_new_input_embeds.append(cur_image_features)
|
223 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
224 |
+
|
225 |
+
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
|
226 |
+
|
227 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
228 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
229 |
+
|
230 |
+
new_input_embeds.append(cur_new_input_embeds)
|
231 |
+
new_labels.append(cur_new_labels)
|
232 |
+
|
233 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
234 |
+
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
|
235 |
+
modality_max_length = getattr(self.config, 'modality_max_length', None)
|
236 |
+
|
237 |
+
if modality_max_length is None or modality_max_length == "None":
|
238 |
+
if tokenizer_model_max_length is not None:
|
239 |
+
# if new_input_embeds[0] > tokenizer_model_max_length:
|
240 |
+
# print(f"Embeds length ({new_input_embeds.shape[0]}) larger than max length")
|
241 |
+
new_input_embeds =[x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
|
242 |
+
new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
|
243 |
+
else:
|
244 |
+
modality_max_length = ast.literal_eval(modality_max_length)
|
245 |
+
modality_max_length_dict = {"image": modality_max_length[0], "text": modality_max_length[1], "video": modality_max_length[2]}
|
246 |
+
new_input_embeds =[x[: modality_max_length_dict[modality]] for x, modality in zip(new_input_embeds, modalities)]
|
247 |
+
new_labels = [x[: modality_max_length_dict[modality]] for x, modality in zip(new_labels, modalities)]
|
248 |
+
|
249 |
+
# Combine them
|
250 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
251 |
+
batch_size = len(new_input_embeds)
|
252 |
+
|
253 |
+
new_input_embeds_padded = []
|
254 |
+
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
|
255 |
+
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
256 |
+
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
257 |
+
|
258 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
259 |
+
cur_len = cur_new_embed.shape[0]
|
260 |
+
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
|
261 |
+
new_input_embeds_padded.append(torch.cat((
|
262 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
|
263 |
+
cur_new_embed
|
264 |
+
), dim=0))
|
265 |
+
if cur_len > 0:
|
266 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
267 |
+
attention_mask[i, -cur_len:] = True
|
268 |
+
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
269 |
+
else:
|
270 |
+
new_input_embeds_padded.append(torch.cat((
|
271 |
+
cur_new_embed,
|
272 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
|
273 |
+
), dim=0))
|
274 |
+
if cur_len > 0:
|
275 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
276 |
+
attention_mask[i, :cur_len] = True
|
277 |
+
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
278 |
+
|
279 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
280 |
+
|
281 |
+
if _labels is None:
|
282 |
+
new_labels = None
|
283 |
+
else:
|
284 |
+
new_labels = new_labels_padded
|
285 |
+
|
286 |
+
if _attention_mask is None:
|
287 |
+
attention_mask = None
|
288 |
+
else:
|
289 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
290 |
+
|
291 |
+
if _position_ids is None:
|
292 |
+
position_ids = None
|
293 |
+
|
294 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
295 |
+
|
296 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
297 |
+
if model_args.mm_use_im_patch_token:
|
298 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
299 |
+
self.resize_token_embeddings(len(tokenizer))
|
300 |
+
|
301 |
+
if model_args.mm_use_im_start_end:
|
302 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
303 |
+
self.resize_token_embeddings(len(tokenizer))
|
304 |
+
|
305 |
+
if num_new_tokens > 0:
|
306 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
307 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
308 |
+
|
309 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
310 |
+
dim=0, keepdim=True)
|
311 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
312 |
+
dim=0, keepdim=True)
|
313 |
+
|
314 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
315 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
316 |
+
|
317 |
+
if model_args.tune_mm_mlp_adapter:
|
318 |
+
for p in self.get_input_embeddings().parameters():
|
319 |
+
p.requires_grad = True
|
320 |
+
for p in self.get_output_embeddings().parameters():
|
321 |
+
p.requires_grad = False
|
322 |
+
|
323 |
+
if model_args.pretrain_mm_mlp_adapter:
|
324 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
|
325 |
+
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
|
326 |
+
assert num_new_tokens == 2
|
327 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
328 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
329 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
330 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
331 |
+
else:
|
332 |
+
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
333 |
+
elif model_args.mm_use_im_patch_token:
|
334 |
+
if model_args.tune_mm_mlp_adapter:
|
335 |
+
for p in self.get_input_embeddings().parameters():
|
336 |
+
p.requires_grad = False
|
337 |
+
for p in self.get_output_embeddings().parameters():
|
338 |
+
p.requires_grad = False
|
oryx/train/llama_flash_attn_monkey_patch.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import transformers
|
7 |
+
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
8 |
+
|
9 |
+
try:
|
10 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
11 |
+
except ImportError:
|
12 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
|
13 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
14 |
+
|
15 |
+
|
16 |
+
def forward(
|
17 |
+
self,
|
18 |
+
hidden_states: torch.Tensor,
|
19 |
+
attention_mask: Optional[torch.Tensor] = None,
|
20 |
+
position_ids: Optional[torch.Tensor] = None,
|
21 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
22 |
+
output_attentions: bool = False,
|
23 |
+
use_cache: bool = False,
|
24 |
+
padding_mask: Optional[torch.Tensor] = None,
|
25 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
26 |
+
if output_attentions:
|
27 |
+
warnings.warn(
|
28 |
+
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
29 |
+
)
|
30 |
+
|
31 |
+
bsz, q_len, _ = hidden_states.size()
|
32 |
+
|
33 |
+
query_states = (
|
34 |
+
self.q_proj(hidden_states)
|
35 |
+
.view(bsz, q_len, self.num_heads, self.head_dim)
|
36 |
+
.transpose(1, 2)
|
37 |
+
)
|
38 |
+
key_states = (
|
39 |
+
self.k_proj(hidden_states)
|
40 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
41 |
+
.transpose(1, 2)
|
42 |
+
)
|
43 |
+
value_states = (
|
44 |
+
self.v_proj(hidden_states)
|
45 |
+
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
46 |
+
.transpose(1, 2)
|
47 |
+
) # shape: (b, num_heads, s, head_dim)
|
48 |
+
|
49 |
+
kv_seq_len = key_states.shape[-2]
|
50 |
+
if past_key_value is not None:
|
51 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
52 |
+
|
53 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
54 |
+
query_states, key_states = apply_rotary_pos_emb(
|
55 |
+
query_states, key_states, cos, sin, position_ids
|
56 |
+
)
|
57 |
+
|
58 |
+
if past_key_value is not None:
|
59 |
+
# reuse k, v
|
60 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
61 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
62 |
+
|
63 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
64 |
+
|
65 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
66 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
67 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
68 |
+
|
69 |
+
# Transform the data into the format required by flash attention
|
70 |
+
qkv = torch.stack([query_states, key_states, value_states], dim=2)
|
71 |
+
qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
|
72 |
+
key_padding_mask = attention_mask
|
73 |
+
|
74 |
+
if key_padding_mask is None:
|
75 |
+
qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
|
76 |
+
cu_q_lens = torch.arange(
|
77 |
+
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
|
78 |
+
)
|
79 |
+
max_s = q_len
|
80 |
+
output = flash_attn_unpadded_qkvpacked_func(
|
81 |
+
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
82 |
+
)
|
83 |
+
output = output.view(bsz, q_len, -1)
|
84 |
+
else:
|
85 |
+
qkv = qkv.reshape(bsz, q_len, -1)
|
86 |
+
qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
|
87 |
+
qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
|
88 |
+
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
89 |
+
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
90 |
+
)
|
91 |
+
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
|
92 |
+
output = pad_input(output_unpad, indices, bsz, q_len)
|
93 |
+
|
94 |
+
return self.o_proj(output), None, past_key_value
|
95 |
+
|
96 |
+
|
97 |
+
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
98 |
+
# requires the attention mask to be the same as the key_padding_mask
|
99 |
+
def _prepare_decoder_attention_mask(
|
100 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
101 |
+
):
|
102 |
+
# [bsz, seq_len]
|
103 |
+
return attention_mask
|
104 |
+
|
105 |
+
|
106 |
+
def replace_llama_attn_with_flash_attn():
|
107 |
+
cuda_major, cuda_minor = torch.cuda.get_device_capability()
|
108 |
+
if cuda_major < 8:
|
109 |
+
warnings.warn(
|
110 |
+
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
|
111 |
+
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
|
112 |
+
)
|
113 |
+
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
|
114 |
+
_prepare_decoder_attention_mask
|
115 |
+
)
|
116 |
+
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
oryx/train/oryx_trainer.py
ADDED
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
import importlib.metadata
|
6 |
+
|
7 |
+
from torch.utils.data import Sampler
|
8 |
+
|
9 |
+
|
10 |
+
from transformers import Trainer
|
11 |
+
from transformers.trainer import (
|
12 |
+
is_sagemaker_mp_enabled,
|
13 |
+
get_parameter_names,
|
14 |
+
has_length,
|
15 |
+
ALL_LAYERNORM_LAYERS,
|
16 |
+
logger,
|
17 |
+
)
|
18 |
+
from transformers.trainer_pt_utils import get_length_grouped_indices as get_length_grouped_indices_hf
|
19 |
+
from typing import List, Optional
|
20 |
+
|
21 |
+
from transformers.trainer_pt_utils import (
|
22 |
+
get_dataloader_sampler,
|
23 |
+
get_model_param_count,
|
24 |
+
get_parameter_names,
|
25 |
+
)
|
26 |
+
|
27 |
+
from transformers.training_args import ParallelMode
|
28 |
+
from transformers.utils import (
|
29 |
+
is_peft_available,
|
30 |
+
is_accelerate_available,
|
31 |
+
is_sagemaker_mp_enabled,
|
32 |
+
is_torch_xla_available,
|
33 |
+
)
|
34 |
+
|
35 |
+
from transformers.trainer_utils import (
|
36 |
+
HPSearchBackend,
|
37 |
+
TrainOutput,
|
38 |
+
has_length,
|
39 |
+
speed_metrics,
|
40 |
+
)
|
41 |
+
|
42 |
+
from packaging import version
|
43 |
+
|
44 |
+
from peft import PeftModel
|
45 |
+
|
46 |
+
TIME_STAMP = os.environ.get('TIME_STAMP', 'default_value')
|
47 |
+
BYTENAS = os.environ.get('BYTENAS', 'vl-research')
|
48 |
+
|
49 |
+
def maybe_zero_3(param, ignore_status=False, name=None):
|
50 |
+
from deepspeed import zero
|
51 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
52 |
+
if hasattr(param, "ds_id"):
|
53 |
+
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
|
54 |
+
if not ignore_status:
|
55 |
+
print(name, 'no ignore status')
|
56 |
+
with zero.GatheredParameters([param]):
|
57 |
+
param = param.data.detach().cpu().clone()
|
58 |
+
else:
|
59 |
+
param = param.detach().cpu().clone()
|
60 |
+
return param
|
61 |
+
|
62 |
+
|
63 |
+
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
64 |
+
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
|
65 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
|
66 |
+
return to_return
|
67 |
+
|
68 |
+
|
69 |
+
def split_to_even_chunks(indices, lengths, num_chunks):
|
70 |
+
"""
|
71 |
+
Split a list of indices into `chunks` chunks of roughly equal lengths.
|
72 |
+
"""
|
73 |
+
|
74 |
+
if len(indices) % num_chunks != 0:
|
75 |
+
return [indices[i::num_chunks] for i in range(num_chunks)]
|
76 |
+
|
77 |
+
num_indices_per_chunk = len(indices) // num_chunks
|
78 |
+
|
79 |
+
chunks = [[] for _ in range(num_chunks)]
|
80 |
+
chunks_lengths = [0 for _ in range(num_chunks)]
|
81 |
+
for index in indices:
|
82 |
+
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
|
83 |
+
chunks[shortest_chunk].append(index)
|
84 |
+
chunks_lengths[shortest_chunk] += lengths[index]
|
85 |
+
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
|
86 |
+
chunks_lengths[shortest_chunk] = float("inf")
|
87 |
+
|
88 |
+
return chunks
|
89 |
+
|
90 |
+
|
91 |
+
def get_variable_length_grouped_indices(lengths, batch_size, world_size, megabatch_mult = 8, generator=None):
|
92 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
93 |
+
indices = torch.randperm(len(lengths), generator=generator)
|
94 |
+
sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
|
95 |
+
megabatch_size = world_size * batch_size * megabatch_mult
|
96 |
+
megabatches = [sorted_indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
|
97 |
+
megabatches = [sorted(megabatch, key=lambda i: indices[i], reverse=True) for megabatch in megabatches]
|
98 |
+
shuffled_indices = [i for megabatch in megabatches for i in megabatch]
|
99 |
+
world_batch_size = world_size * batch_size
|
100 |
+
batches = [shuffled_indices[i : i + world_batch_size] for i in range(0, len(lengths), world_batch_size)]
|
101 |
+
batch_indices = torch.randperm(len(batches), generator=generator)
|
102 |
+
batches = [batches[i] for i in batch_indices]
|
103 |
+
|
104 |
+
return [i for batch in batches for i in batch]
|
105 |
+
|
106 |
+
|
107 |
+
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
|
108 |
+
"""
|
109 |
+
Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
|
110 |
+
lengths. To do this, the indices are:
|
111 |
+
|
112 |
+
- randomly permuted
|
113 |
+
- grouped in mega-batches of size `mega_batch_mult * batch_size`
|
114 |
+
- reorder by length in each mega-batch
|
115 |
+
|
116 |
+
The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
|
117 |
+
maximum length placed first, so that an OOM happens sooner rather than later.
|
118 |
+
"""
|
119 |
+
|
120 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
121 |
+
assert all(l != 0 for l in lengths), "Should not have zero length."
|
122 |
+
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
|
123 |
+
# all samples are in the same modality
|
124 |
+
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
|
125 |
+
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
|
126 |
+
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
|
127 |
+
|
128 |
+
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
|
129 |
+
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
|
130 |
+
megabatch_size = world_size * batch_size
|
131 |
+
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
|
132 |
+
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
|
133 |
+
|
134 |
+
last_mm = mm_megabatches[-1]
|
135 |
+
last_lang = lang_megabatches[-1]
|
136 |
+
additional_batch = last_mm + last_lang
|
137 |
+
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
|
138 |
+
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
|
139 |
+
megabatches = [megabatches[i] for i in megabatch_indices]
|
140 |
+
|
141 |
+
if len(additional_batch) > 0:
|
142 |
+
megabatches.append(sorted(additional_batch))
|
143 |
+
|
144 |
+
return [i for megabatch in megabatches for i in megabatch]
|
145 |
+
|
146 |
+
|
147 |
+
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
|
148 |
+
"""
|
149 |
+
Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
|
150 |
+
lengths. To do this, the indices are:
|
151 |
+
|
152 |
+
- randomly permuted
|
153 |
+
- grouped in mega-batches of size `mega_batch_mult * batch_size`
|
154 |
+
- reorder by length in each mega-batch
|
155 |
+
|
156 |
+
The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
|
157 |
+
maximum length placed first, so that an OOM happens sooner rather than later.
|
158 |
+
"""
|
159 |
+
|
160 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
161 |
+
indices = torch.randperm(len(lengths), generator=generator)
|
162 |
+
megabatch_size = world_size * batch_size
|
163 |
+
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
|
164 |
+
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
|
165 |
+
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
|
166 |
+
|
167 |
+
return [i for megabatch in megabatches for batch in megabatch for i in batch]
|
168 |
+
|
169 |
+
|
170 |
+
def get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=None):
|
171 |
+
indices = get_length_grouped_indices_hf(lengths, batch_size * world_size, generator=generator)
|
172 |
+
|
173 |
+
megabatch_size = world_size * batch_size
|
174 |
+
megabatches = [indices[i : i + megabatch_size] for i in range(0, len(lengths), megabatch_size)]
|
175 |
+
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
|
176 |
+
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
|
177 |
+
|
178 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
179 |
+
batch_indices = torch.randperm(len(megabatches), generator=generator)
|
180 |
+
megabatches = [megabatches[i] for i in batch_indices]
|
181 |
+
|
182 |
+
return [i for megabatch in megabatches for batch in megabatch for i in batch]
|
183 |
+
|
184 |
+
|
185 |
+
def get_modality_length_grouped_indices_auto(lengths, batch_size, world_size, generator=None):
|
186 |
+
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
187 |
+
assert all(l != 0 for l in lengths), "Should not have zero length."
|
188 |
+
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
|
189 |
+
# all samples are in the same modality
|
190 |
+
return get_length_grouped_indices_auto_single(lengths, batch_size, world_size, generator=generator)
|
191 |
+
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
|
192 |
+
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
|
193 |
+
|
194 |
+
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices_auto_single(mm_lengths, batch_size, world_size, generator=None)]
|
195 |
+
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices_auto_single(lang_lengths, batch_size, world_size, generator=None)]
|
196 |
+
megabatch_size = world_size * batch_size
|
197 |
+
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
|
198 |
+
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
|
199 |
+
|
200 |
+
last_mm = mm_megabatches[-1]
|
201 |
+
last_lang = lang_megabatches[-1]
|
202 |
+
additional_batch = last_mm + last_lang
|
203 |
+
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
|
204 |
+
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
|
205 |
+
megabatches = [megabatches[i] for i in megabatch_indices]
|
206 |
+
|
207 |
+
if len(additional_batch) > 0:
|
208 |
+
megabatches.append(sorted(additional_batch))
|
209 |
+
|
210 |
+
return [i for megabatch in megabatches for i in megabatch]
|
211 |
+
|
212 |
+
|
213 |
+
class LengthGroupedSampler(Sampler):
|
214 |
+
r"""
|
215 |
+
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
|
216 |
+
keeping a bit of randomness.
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(
|
220 |
+
self,
|
221 |
+
batch_size: int,
|
222 |
+
world_size: int,
|
223 |
+
lengths: Optional[List[int]] = None,
|
224 |
+
generator=None,
|
225 |
+
variable_length: bool = False,
|
226 |
+
group_by_modality: bool = False,
|
227 |
+
group_by_modality_auto: bool = False,
|
228 |
+
):
|
229 |
+
if lengths is None:
|
230 |
+
raise ValueError("Lengths must be provided.")
|
231 |
+
|
232 |
+
self.batch_size = batch_size
|
233 |
+
self.world_size = world_size
|
234 |
+
self.lengths = lengths
|
235 |
+
self.generator = generator
|
236 |
+
self.variable_length = variable_length
|
237 |
+
self.group_by_modality = group_by_modality
|
238 |
+
self.group_by_modality_auto = group_by_modality_auto
|
239 |
+
|
240 |
+
def __len__(self):
|
241 |
+
return len(self.lengths)
|
242 |
+
|
243 |
+
def __iter__(self):
|
244 |
+
if self.variable_length:
|
245 |
+
assert not self.group_by_modality, "Variable length grouping is not supported with modality grouping."
|
246 |
+
indices = get_variable_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
|
247 |
+
else:
|
248 |
+
if self.group_by_modality:
|
249 |
+
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
|
250 |
+
elif self.group_by_modality_auto:
|
251 |
+
indices = get_modality_length_grouped_indices_auto(self.lengths, self.batch_size, self.world_size, generator=self.generator)
|
252 |
+
else:
|
253 |
+
indices = get_length_grouped_indices_auto_single(self.lengths, self.batch_size, self.world_size, generator=self.generator)
|
254 |
+
return iter(indices)
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
def _is_peft_model(model):
|
259 |
+
if is_peft_available():
|
260 |
+
classes_to_check = (PeftModel,) if is_peft_available() else ()
|
261 |
+
# Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
|
262 |
+
if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
|
263 |
+
from peft import PeftMixedModel
|
264 |
+
|
265 |
+
classes_to_check = (*classes_to_check, PeftMixedModel)
|
266 |
+
return isinstance(model, classes_to_check)
|
267 |
+
return False
|
268 |
+
|
269 |
+
|
270 |
+
TRAINER_STATE_NAME = "trainer_state.json"
|
271 |
+
|
272 |
+
class OryxTrainer(Trainer):
|
273 |
+
|
274 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
275 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
276 |
+
return None
|
277 |
+
|
278 |
+
if self.args.group_by_length:
|
279 |
+
lengths = self.train_dataset.lengths
|
280 |
+
return LengthGroupedSampler(
|
281 |
+
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
|
282 |
+
self.args.train_batch_size,
|
283 |
+
# world_size=self.args.world_size,
|
284 |
+
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
|
285 |
+
lengths=lengths,
|
286 |
+
)
|
287 |
+
elif self.args.group_by_modality_length:
|
288 |
+
lengths = self.train_dataset.modality_lengths
|
289 |
+
return LengthGroupedSampler(
|
290 |
+
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
|
291 |
+
self.args.train_batch_size,
|
292 |
+
# world_size=self.args.world_size,
|
293 |
+
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
|
294 |
+
lengths=lengths,
|
295 |
+
group_by_modality=True,
|
296 |
+
)
|
297 |
+
elif self.args.group_by_modality_length_auto:
|
298 |
+
lengths = self.train_dataset.modality_lengths
|
299 |
+
return LengthGroupedSampler(
|
300 |
+
# self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
|
301 |
+
self.args.train_batch_size,
|
302 |
+
# world_size=self.args.world_size,
|
303 |
+
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
|
304 |
+
lengths=lengths,
|
305 |
+
group_by_modality_auto=True,
|
306 |
+
)
|
307 |
+
elif self.args.group_by_varlen:
|
308 |
+
lengths = self.train_dataset.lengths
|
309 |
+
return LengthGroupedSampler(
|
310 |
+
self.args.train_batch_size * self.args.gradient_accumulation_steps,
|
311 |
+
# self.args.train_batch_size, # TODO: seems that we should have gradient_accumulation_steps
|
312 |
+
# world_size=self.args.world_size,
|
313 |
+
world_size=self.args.world_size * self.args.gradient_accumulation_steps, # TODO: seems that this may work?
|
314 |
+
lengths=lengths,
|
315 |
+
variable_length=True,
|
316 |
+
)
|
317 |
+
else:
|
318 |
+
return super()._get_train_sampler()
|
319 |
+
|
320 |
+
def create_optimizer(self):
|
321 |
+
"""
|
322 |
+
Setup the optimizer.
|
323 |
+
|
324 |
+
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
325 |
+
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
326 |
+
"""
|
327 |
+
if is_sagemaker_mp_enabled():
|
328 |
+
return super().create_optimizer()
|
329 |
+
|
330 |
+
opt_model = self.model
|
331 |
+
|
332 |
+
if self.optimizer is None:
|
333 |
+
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
334 |
+
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
335 |
+
lr_mapper = {}
|
336 |
+
if self.args.mm_projector_lr is not None:
|
337 |
+
lr_mapper['mm_projector'] = self.args.mm_projector_lr
|
338 |
+
if self.args.mm_vision_tower_lr is not None:
|
339 |
+
lr_mapper['vision_tower'] = self.args.mm_vision_tower_lr
|
340 |
+
if len(lr_mapper) > 0:
|
341 |
+
special_lr_parameters = [name for name, _ in opt_model.named_parameters() if any(module_keyword in name for module_keyword in lr_mapper)]
|
342 |
+
optimizer_grouped_parameters = [
|
343 |
+
{
|
344 |
+
"params": [
|
345 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)
|
346 |
+
],
|
347 |
+
"weight_decay": self.args.weight_decay,
|
348 |
+
},
|
349 |
+
{
|
350 |
+
"params": [
|
351 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)
|
352 |
+
],
|
353 |
+
"weight_decay": 0.0,
|
354 |
+
},
|
355 |
+
]
|
356 |
+
for module_keyword, lr in lr_mapper.items():
|
357 |
+
module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
|
358 |
+
optimizer_grouped_parameters.extend([
|
359 |
+
{
|
360 |
+
"params": [
|
361 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in module_parameters and p.requires_grad)
|
362 |
+
],
|
363 |
+
"weight_decay": self.args.weight_decay,
|
364 |
+
"lr": lr,
|
365 |
+
},
|
366 |
+
{
|
367 |
+
"params": [
|
368 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in module_parameters and p.requires_grad)
|
369 |
+
],
|
370 |
+
"weight_decay": 0.0,
|
371 |
+
"lr": lr,
|
372 |
+
},
|
373 |
+
])
|
374 |
+
else:
|
375 |
+
optimizer_grouped_parameters = [
|
376 |
+
{
|
377 |
+
"params": [
|
378 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
|
379 |
+
],
|
380 |
+
"weight_decay": self.args.weight_decay,
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"params": [
|
384 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
|
385 |
+
],
|
386 |
+
"weight_decay": 0.0,
|
387 |
+
},
|
388 |
+
]
|
389 |
+
|
390 |
+
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
391 |
+
|
392 |
+
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
393 |
+
if optimizer_cls.__name__ == "Adam8bit":
|
394 |
+
import bitsandbytes
|
395 |
+
|
396 |
+
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
397 |
+
|
398 |
+
skipped = 0
|
399 |
+
for module in opt_model.modules():
|
400 |
+
if isinstance(module, nn.Embedding):
|
401 |
+
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
402 |
+
logger.info(f"skipped {module}: {skipped/2**20}M params")
|
403 |
+
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
404 |
+
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
405 |
+
logger.info(f"skipped: {skipped/2**20}M params")
|
406 |
+
|
407 |
+
return self.optimizer
|
408 |
+
|
409 |
+
def _save_checkpoint(self, model, trial, metrics=None):
|
410 |
+
if getattr(self.args, 'tune_mm_mlp_adapter', False):
|
411 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
412 |
+
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
413 |
+
|
414 |
+
run_dir = self._get_output_dir(trial=trial)
|
415 |
+
output_dir = os.path.join(run_dir, checkpoint_folder)
|
416 |
+
|
417 |
+
# Only save Adapter
|
418 |
+
keys_to_match = ['mm_projector', 'vision_resampler']
|
419 |
+
if getattr(self.args, "use_im_start_end", False):
|
420 |
+
keys_to_match.extend(['embed_tokens', 'embed_in'])
|
421 |
+
|
422 |
+
weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
|
423 |
+
|
424 |
+
if self.args.local_rank == 0 or self.args.local_rank == -1:
|
425 |
+
self.model.config.save_pretrained(output_dir)
|
426 |
+
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
|
427 |
+
else:
|
428 |
+
print("self.is_local_process_zero()",self.is_local_process_zero())
|
429 |
+
super(OryxTrainer, self)._save_checkpoint(model, trial, metrics)
|
430 |
+
|
431 |
+
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
432 |
+
if getattr(self.args, 'tune_mm_mlp_adapter', False):
|
433 |
+
pass
|
434 |
+
else:
|
435 |
+
super(OryxTrainer, self)._save(output_dir, state_dict)
|
oryx/train/train.py
ADDED
@@ -0,0 +1,1686 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import copy
|
4 |
+
from dataclasses import dataclass, field
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import pathlib
|
8 |
+
from typing import Dict, Optional, Sequence, List
|
9 |
+
import ast
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import time
|
13 |
+
import random
|
14 |
+
import cv2
|
15 |
+
|
16 |
+
import transformers
|
17 |
+
import tokenizers
|
18 |
+
|
19 |
+
from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_INDEX
|
20 |
+
from torch.utils.data import Dataset
|
21 |
+
from oryx.train.oryx_trainer import OryxTrainer
|
22 |
+
|
23 |
+
from oryx import conversation as conversation_lib
|
24 |
+
from oryx.model import *
|
25 |
+
from oryx.mm_utils import tokenizer_image_token, process_anyres_highres_image_genli, process_anyres_video_genli, process_anyres_video_genli_long
|
26 |
+
|
27 |
+
from PIL import Image
|
28 |
+
import io
|
29 |
+
import base64
|
30 |
+
|
31 |
+
from packaging import version
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
|
35 |
+
from transformers import AutoConfig
|
36 |
+
|
37 |
+
import math
|
38 |
+
import copy
|
39 |
+
|
40 |
+
|
41 |
+
local_rank = None
|
42 |
+
|
43 |
+
|
44 |
+
def rank0_print(*args):
|
45 |
+
if local_rank == 0:
|
46 |
+
print(*args)
|
47 |
+
|
48 |
+
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
|
49 |
+
|
50 |
+
@dataclass
|
51 |
+
class ModelArguments:
|
52 |
+
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
|
53 |
+
version: Optional[str] = field(default="v0")
|
54 |
+
freeze_backbone: bool = field(default=False)
|
55 |
+
tune_mm_mlp_adapter: bool = field(default=False)
|
56 |
+
tune_mm_vision_resampler: bool = field(default=False)
|
57 |
+
vision_tower: Optional[str] = field(default=None)
|
58 |
+
image_processor: Optional[str] = field(default=None)
|
59 |
+
unfreeze_mm_vision_tower: bool = field(default=False)
|
60 |
+
mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
|
61 |
+
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
|
62 |
+
mm_projector_type: Optional[str] = field(default='linear')
|
63 |
+
mm_use_im_start_end: bool = field(default=False)
|
64 |
+
mm_use_im_patch_token: bool = field(default=True)
|
65 |
+
mm_vision_select_feature: Optional[str] = field(default="patch")
|
66 |
+
mm_resampler_type: Optional[str] = field(default=None)
|
67 |
+
mm_mask_drop_mode: str = field(default="fixed")
|
68 |
+
mm_mask_drop_skip_percentage: float = field(default=0.)
|
69 |
+
mm_mask_drop_ratio: float = field(default=0.25)
|
70 |
+
mm_mask_drop_ratio_upper: Optional[float] = field(default=None)
|
71 |
+
mm_mask_drop_ratio_lower: Optional[float] = field(default=None)
|
72 |
+
|
73 |
+
@dataclass
|
74 |
+
class DataArguments:
|
75 |
+
data_path: str = field(default=None,
|
76 |
+
metadata={"help": "Path to the training data."})
|
77 |
+
lazy_preprocess: bool = False
|
78 |
+
is_multimodal: bool = False
|
79 |
+
video_fps: Optional[int] = field(default=1)
|
80 |
+
frames_upbound: Optional[int] = field(default=0)
|
81 |
+
|
82 |
+
@dataclass
|
83 |
+
class TrainingArguments(transformers.TrainingArguments):
|
84 |
+
cache_dir: Optional[str] = field(default=None)
|
85 |
+
optim: str = field(default="adamw_torch")
|
86 |
+
remove_unused_columns: bool = field(default=False)
|
87 |
+
freeze_mm_mlp_adapter: bool = field(default=False)
|
88 |
+
freeze_mm_vision_resampler: bool = field(default=False)
|
89 |
+
mpt_attn_impl: Optional[str] = field(default="triton")
|
90 |
+
model_max_length: int = field(
|
91 |
+
default=512,
|
92 |
+
metadata={
|
93 |
+
"help":
|
94 |
+
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
|
95 |
+
},
|
96 |
+
)
|
97 |
+
double_quant: bool = field(
|
98 |
+
default=True,
|
99 |
+
metadata={"help": "Compress the quantization statistics through double quantization."}
|
100 |
+
)
|
101 |
+
quant_type: str = field(
|
102 |
+
default="nf4",
|
103 |
+
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
|
104 |
+
)
|
105 |
+
bits: int = field(
|
106 |
+
default=16,
|
107 |
+
metadata={"help": "How many bits to use."}
|
108 |
+
)
|
109 |
+
lora_enable: bool = False
|
110 |
+
lora_r: int = 64
|
111 |
+
lora_alpha: int = 16
|
112 |
+
lora_dropout: float = 0.05
|
113 |
+
lora_weight_path: str = ""
|
114 |
+
lora_bias: str = "none"
|
115 |
+
mm_projector_lr: Optional[float] = None
|
116 |
+
mm_vision_tower_lr: Optional[float] = None
|
117 |
+
group_by_varlen: bool = field(default=False)
|
118 |
+
group_by_modality_length: bool = field(default=False)
|
119 |
+
group_by_modality_length_auto: bool = field(default=False)
|
120 |
+
do_resize: bool = field(default=False)
|
121 |
+
do_center_crop: bool = field(default=False)
|
122 |
+
|
123 |
+
|
124 |
+
def maybe_zero_3(param, ignore_status=False, name=None):
|
125 |
+
from deepspeed import zero
|
126 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
127 |
+
if hasattr(param, "ds_id"):
|
128 |
+
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
|
129 |
+
if not ignore_status:
|
130 |
+
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
|
131 |
+
with zero.GatheredParameters([param]):
|
132 |
+
param = param.data.detach().cpu().clone()
|
133 |
+
else:
|
134 |
+
param = param.detach().cpu().clone()
|
135 |
+
return param
|
136 |
+
|
137 |
+
|
138 |
+
# Borrowed from peft.utils.get_peft_model_state_dict
|
139 |
+
def get_peft_state_maybe_zero_3(named_params, bias):
|
140 |
+
if bias == "none":
|
141 |
+
to_return = {k: t for k, t in named_params if "lora_" in k}
|
142 |
+
elif bias == "all":
|
143 |
+
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
|
144 |
+
elif bias == "lora_only":
|
145 |
+
to_return = {}
|
146 |
+
maybe_lora_bias = {}
|
147 |
+
lora_bias_names = set()
|
148 |
+
for k, t in named_params:
|
149 |
+
if "lora_" in k:
|
150 |
+
to_return[k] = t
|
151 |
+
bias_name = k.split("lora_")[0] + "bias"
|
152 |
+
lora_bias_names.add(bias_name)
|
153 |
+
elif "bias" in k:
|
154 |
+
maybe_lora_bias[k] = t
|
155 |
+
for k, t in maybe_lora_bias:
|
156 |
+
if bias_name in lora_bias_names:
|
157 |
+
to_return[bias_name] = t
|
158 |
+
else:
|
159 |
+
raise NotImplementedError
|
160 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
|
161 |
+
return to_return
|
162 |
+
|
163 |
+
|
164 |
+
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
|
165 |
+
to_return = {k: t for k, t in named_params if "lora_" not in k}
|
166 |
+
if require_grad_only:
|
167 |
+
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
|
168 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
|
169 |
+
return to_return
|
170 |
+
|
171 |
+
|
172 |
+
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
173 |
+
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
|
174 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
|
175 |
+
return to_return
|
176 |
+
|
177 |
+
|
178 |
+
def find_all_linear_names(model):
|
179 |
+
cls = torch.nn.Linear
|
180 |
+
lora_module_names = set()
|
181 |
+
multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
|
182 |
+
for name, module in model.named_modules():
|
183 |
+
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
|
184 |
+
continue
|
185 |
+
if isinstance(module, cls):
|
186 |
+
names = name.split('.')
|
187 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
188 |
+
|
189 |
+
|
190 |
+
if 'lm_head' in lora_module_names: # needed for 16-bit
|
191 |
+
lora_module_names.remove('lm_head')
|
192 |
+
return list(lora_module_names)
|
193 |
+
|
194 |
+
|
195 |
+
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
|
196 |
+
output_dir: str):
|
197 |
+
"""Collects the state dict and dump to disk."""
|
198 |
+
|
199 |
+
if getattr(trainer.args, "tune_mm_mlp_adapter", False):
|
200 |
+
# Only save Adapter
|
201 |
+
keys_to_match = ['mm_projector', 'vision_resampler']
|
202 |
+
if getattr(trainer.args, "use_im_start_end", False):
|
203 |
+
keys_to_match.extend(['embed_tokens', 'embed_in'])
|
204 |
+
|
205 |
+
weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
|
206 |
+
trainer.model.config.save_pretrained(output_dir)
|
207 |
+
|
208 |
+
current_folder = output_dir.split('/')[-1]
|
209 |
+
parent_folder = os.path.dirname(output_dir)
|
210 |
+
if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
|
211 |
+
if current_folder.startswith('checkpoint-'):
|
212 |
+
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
|
213 |
+
os.makedirs(mm_projector_folder, exist_ok=True)
|
214 |
+
torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
|
215 |
+
else:
|
216 |
+
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
|
217 |
+
return
|
218 |
+
|
219 |
+
if trainer.deepspeed:
|
220 |
+
torch.cuda.synchronize()
|
221 |
+
trainer.save_model(output_dir)
|
222 |
+
return
|
223 |
+
|
224 |
+
state_dict = trainer.model.state_dict()
|
225 |
+
if trainer.args.should_save:
|
226 |
+
cpu_state_dict = {
|
227 |
+
key: value.cpu()
|
228 |
+
for key, value in state_dict.items()
|
229 |
+
}
|
230 |
+
del state_dict
|
231 |
+
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
|
232 |
+
|
233 |
+
|
234 |
+
def smart_tokenizer_and_embedding_resize(
|
235 |
+
special_tokens_dict: Dict,
|
236 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
237 |
+
model: transformers.PreTrainedModel,
|
238 |
+
):
|
239 |
+
"""Resize tokenizer and embedding.
|
240 |
+
|
241 |
+
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
242 |
+
"""
|
243 |
+
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
244 |
+
model.resize_token_embeddings(len(tokenizer))
|
245 |
+
|
246 |
+
if num_new_tokens > 0:
|
247 |
+
input_embeddings = model.get_input_embeddings().weight.data
|
248 |
+
output_embeddings = model.get_output_embeddings().weight.data
|
249 |
+
|
250 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
251 |
+
dim=0, keepdim=True)
|
252 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
253 |
+
dim=0, keepdim=True)
|
254 |
+
|
255 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
256 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
257 |
+
|
258 |
+
|
259 |
+
def _tokenize_fn(strings: Sequence[str],
|
260 |
+
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
261 |
+
"""Tokenize a list of strings."""
|
262 |
+
tokenized_list = [
|
263 |
+
tokenizer(
|
264 |
+
text,
|
265 |
+
return_tensors="pt",
|
266 |
+
padding="longest",
|
267 |
+
max_length=tokenizer.model_max_length,
|
268 |
+
truncation=True,
|
269 |
+
) for text in strings
|
270 |
+
]
|
271 |
+
input_ids = labels = [
|
272 |
+
tokenized.input_ids[0] for tokenized in tokenized_list
|
273 |
+
]
|
274 |
+
input_ids_lens = labels_lens = [
|
275 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
276 |
+
for tokenized in tokenized_list
|
277 |
+
]
|
278 |
+
return dict(
|
279 |
+
input_ids=input_ids,
|
280 |
+
labels=labels,
|
281 |
+
input_ids_lens=input_ids_lens,
|
282 |
+
labels_lens=labels_lens,
|
283 |
+
)
|
284 |
+
|
285 |
+
|
286 |
+
def _mask_targets(target, tokenized_lens, speakers):
|
287 |
+
# cur_idx = 0
|
288 |
+
cur_idx = tokenized_lens[0]
|
289 |
+
tokenized_lens = tokenized_lens[1:]
|
290 |
+
target[:cur_idx] = IGNORE_INDEX
|
291 |
+
for tokenized_len, speaker in zip(tokenized_lens, speakers):
|
292 |
+
if speaker == "human":
|
293 |
+
target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
|
294 |
+
cur_idx += tokenized_len
|
295 |
+
|
296 |
+
|
297 |
+
def _add_speaker_and_signal(header, source, get_conversation=True):
|
298 |
+
"""Add speaker and start/end signal on each round."""
|
299 |
+
BEGIN_SIGNAL = "### "
|
300 |
+
END_SIGNAL = "\n"
|
301 |
+
conversation = header
|
302 |
+
for sentence in source:
|
303 |
+
from_str = sentence["from"]
|
304 |
+
if from_str.lower() == "human":
|
305 |
+
from_str = conversation_lib.default_conversation.roles[0]
|
306 |
+
elif from_str.lower() == "gpt":
|
307 |
+
from_str = conversation_lib.default_conversation.roles[1]
|
308 |
+
else:
|
309 |
+
from_str = 'unknown'
|
310 |
+
sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
|
311 |
+
sentence["value"] + END_SIGNAL)
|
312 |
+
if get_conversation:
|
313 |
+
conversation += sentence["value"]
|
314 |
+
conversation += BEGIN_SIGNAL
|
315 |
+
return conversation
|
316 |
+
|
317 |
+
|
318 |
+
def preprocess_multimodal(
|
319 |
+
sources: Sequence[str],
|
320 |
+
data_args: DataArguments,
|
321 |
+
) -> Dict:
|
322 |
+
is_multimodal = data_args.is_multimodal
|
323 |
+
if not is_multimodal:
|
324 |
+
return sources
|
325 |
+
|
326 |
+
for source in sources:
|
327 |
+
for sentence in source:
|
328 |
+
if DEFAULT_IMAGE_TOKEN in sentence['value'] and not sentence['value'].startswith(DEFAULT_IMAGE_TOKEN):
|
329 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
|
330 |
+
sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
|
331 |
+
sentence['value'] = sentence['value'].strip()
|
332 |
+
if "mmtag" in conversation_lib.default_conversation.version:
|
333 |
+
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
|
334 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
335 |
+
if data_args.mm_use_im_start_end:
|
336 |
+
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
337 |
+
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
338 |
+
|
339 |
+
return sources
|
340 |
+
|
341 |
+
def preprocess_multimodal_movie(
|
342 |
+
sources: Sequence[str],
|
343 |
+
data_args: DataArguments,
|
344 |
+
video_inputs: str
|
345 |
+
) -> Dict:
|
346 |
+
is_multimodal = data_args.is_multimodal
|
347 |
+
if not is_multimodal:
|
348 |
+
return sources
|
349 |
+
|
350 |
+
for source in sources:
|
351 |
+
for sentence in source:
|
352 |
+
if DEFAULT_IMAGE_TOKEN in sentence['value']:
|
353 |
+
prompt = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
|
354 |
+
replace_token = video_inputs
|
355 |
+
if data_args.mm_use_im_start_end:
|
356 |
+
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
357 |
+
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
358 |
+
|
359 |
+
return sources, prompt
|
360 |
+
|
361 |
+
|
362 |
+
def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
|
363 |
+
roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
|
364 |
+
|
365 |
+
# im_start, im_end = tokenizer.additional_special_tokens_ids
|
366 |
+
|
367 |
+
im_start = tokenizer("<|im_start|>").input_ids[0]
|
368 |
+
im_end = tokenizer("<|im_end|>").input_ids[0]
|
369 |
+
nl_tokens = tokenizer("\n").input_ids
|
370 |
+
_system = tokenizer("system").input_ids + nl_tokens
|
371 |
+
|
372 |
+
# Apply prompt templates
|
373 |
+
input_ids, targets = [], []
|
374 |
+
for i, source in enumerate(sources):
|
375 |
+
if roles[source[0]["from"]] != roles["human"]:
|
376 |
+
source = source[1:]
|
377 |
+
|
378 |
+
input_id, target = [], []
|
379 |
+
system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
|
380 |
+
input_id += system
|
381 |
+
target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
|
382 |
+
assert len(input_id) == len(target)
|
383 |
+
for j, sentence in enumerate(source):
|
384 |
+
role = roles[sentence["from"]]
|
385 |
+
if has_image and "<image>" in sentence["value"]:
|
386 |
+
# assert sentence["value"].startswith("<image>"), print(sentence["value"])
|
387 |
+
if sentence["value"].startswith("<image>"):
|
388 |
+
_input_id = tokenizer(role).input_ids + nl_tokens + [IMAGE_TOKEN_INDEX] + nl_tokens + tokenizer(sentence["value"][len("<image>") :]).input_ids + [im_end] + nl_tokens
|
389 |
+
else:
|
390 |
+
_input_id = []
|
391 |
+
split_value = sentence["value"].split('<image>\n')
|
392 |
+
_input_id += tokenizer(role).input_ids + nl_tokens
|
393 |
+
for idx, cur_value in enumerate(split_value):
|
394 |
+
if idx == len(split_value) - 1:
|
395 |
+
_input_id = _input_id + tokenizer(cur_value).input_ids + [im_end] + nl_tokens
|
396 |
+
else:
|
397 |
+
_input_id = _input_id + tokenizer(cur_value).input_ids + [IMAGE_TOKEN_INDEX] + nl_tokens
|
398 |
+
# # add end of text token
|
399 |
+
# if PACK_SEQ > 0:
|
400 |
+
# if j > 0:
|
401 |
+
# _input_id = _end_of_text + _input_id
|
402 |
+
else:
|
403 |
+
_input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
|
404 |
+
# # add end of text token for pure text data
|
405 |
+
# if PACK_SEQ > 0:
|
406 |
+
# if sentence['from'] == 'human' and j > 0:
|
407 |
+
# _input_id = _end_of_text + _input_id
|
408 |
+
input_id += _input_id
|
409 |
+
if role == "<|im_start|>user":
|
410 |
+
_target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
|
411 |
+
elif role == "<|im_start|>assistant":
|
412 |
+
_target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
|
413 |
+
else:
|
414 |
+
raise NotImplementedError
|
415 |
+
target += _target
|
416 |
+
assert len(input_id) == len(target)
|
417 |
+
# input_id += [tokenizer.pad_token_id] * (max_len - len(input_id))
|
418 |
+
# target += [IGNORE_INDEX] * (max_len - len(target))
|
419 |
+
input_ids.append(input_id)
|
420 |
+
targets.append(target)
|
421 |
+
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
422 |
+
targets = torch.tensor(targets, dtype=torch.long)
|
423 |
+
|
424 |
+
return dict(
|
425 |
+
input_ids=input_ids, # tensor(bs x seq_len)
|
426 |
+
labels=targets, # tensor(bs x seq_len)
|
427 |
+
# attention_mask=input_ids.ne(tokenizer.pad_token_id), # tensor(bs x seq_len)
|
428 |
+
)
|
429 |
+
|
430 |
+
def preprocess_llama_2(
|
431 |
+
sources,
|
432 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
433 |
+
has_image: bool = False
|
434 |
+
) -> Dict:
|
435 |
+
conv = conversation_lib.default_conversation.copy()
|
436 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
437 |
+
|
438 |
+
# Apply prompt templates
|
439 |
+
conversations = []
|
440 |
+
for i, source in enumerate(sources):
|
441 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
442 |
+
# Skip the first one if it is not from human
|
443 |
+
source = source[1:]
|
444 |
+
|
445 |
+
conv.messages = []
|
446 |
+
for j, sentence in enumerate(source):
|
447 |
+
role = roles[sentence["from"]]
|
448 |
+
assert role == conv.roles[j % 2], f"{i}"
|
449 |
+
conv.append_message(role, sentence["value"])
|
450 |
+
conversations.append(conv.get_prompt())
|
451 |
+
|
452 |
+
# Tokenize conversations
|
453 |
+
|
454 |
+
if has_image:
|
455 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
456 |
+
else:
|
457 |
+
input_ids = tokenizer(
|
458 |
+
conversations,
|
459 |
+
return_tensors="pt",
|
460 |
+
padding="longest",
|
461 |
+
max_length=tokenizer.model_max_length,
|
462 |
+
truncation=True,
|
463 |
+
).input_ids
|
464 |
+
|
465 |
+
targets = input_ids.clone()
|
466 |
+
|
467 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
|
468 |
+
|
469 |
+
# Mask targets
|
470 |
+
sep = "[/INST] "
|
471 |
+
for conversation, target in zip(conversations, targets):
|
472 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
473 |
+
|
474 |
+
rounds = conversation.split(conv.sep2)
|
475 |
+
cur_len = 1
|
476 |
+
target[:cur_len] = IGNORE_INDEX
|
477 |
+
for i, rou in enumerate(rounds):
|
478 |
+
if rou == "":
|
479 |
+
break
|
480 |
+
|
481 |
+
parts = rou.split(sep)
|
482 |
+
if len(parts) != 2:
|
483 |
+
break
|
484 |
+
parts[0] += sep
|
485 |
+
|
486 |
+
if has_image:
|
487 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
488 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
489 |
+
else:
|
490 |
+
round_len = len(tokenizer(rou).input_ids)
|
491 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
492 |
+
|
493 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
494 |
+
|
495 |
+
cur_len += round_len
|
496 |
+
target[cur_len:] = IGNORE_INDEX
|
497 |
+
|
498 |
+
if cur_len < tokenizer.model_max_length:
|
499 |
+
if cur_len != total_len:
|
500 |
+
target[:] = IGNORE_INDEX
|
501 |
+
print(
|
502 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
503 |
+
f" (ignored)"
|
504 |
+
)
|
505 |
+
|
506 |
+
return dict(
|
507 |
+
input_ids=input_ids,
|
508 |
+
labels=targets,
|
509 |
+
)
|
510 |
+
|
511 |
+
def preprocess_llama_3(
|
512 |
+
sources,
|
513 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
514 |
+
has_image: bool = False
|
515 |
+
) -> Dict:
|
516 |
+
conv = copy.deepcopy(conversation_lib.conv_llava_llama_3)
|
517 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
518 |
+
|
519 |
+
# Apply prompt templates
|
520 |
+
conversations = []
|
521 |
+
for i, source in enumerate(sources):
|
522 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
523 |
+
# Skip the first one if it is not from human
|
524 |
+
source = source[1:]
|
525 |
+
|
526 |
+
conv.messages = []
|
527 |
+
for j, sentence in enumerate(source):
|
528 |
+
role = roles[sentence["from"]]
|
529 |
+
assert role == conv.roles[j % 2], f"{i}"
|
530 |
+
conv.append_message(role, sentence["value"])
|
531 |
+
conversations.append(conv.get_prompt())
|
532 |
+
|
533 |
+
# Tokenize conversations
|
534 |
+
|
535 |
+
if has_image:
|
536 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
537 |
+
else:
|
538 |
+
input_ids = tokenizer(
|
539 |
+
conversations,
|
540 |
+
return_tensors="pt",
|
541 |
+
padding="longest",
|
542 |
+
max_length=tokenizer.model_max_length,
|
543 |
+
truncation=True,
|
544 |
+
).input_ids
|
545 |
+
|
546 |
+
targets = input_ids.clone()
|
547 |
+
|
548 |
+
offset = 0 if input_ids[0][0] != tokenizer.bos_token_id else 1
|
549 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3
|
550 |
+
# Mask targets
|
551 |
+
# sep = conv.sep + conv.roles[1] + ":"
|
552 |
+
sep = '<|start_header_id|>assistant<|end_header_id|>\n\n'
|
553 |
+
sep2 = '<|start_header_id|>user<|end_header_id|>\n\n'
|
554 |
+
# Llama3 tokenizer has the token for whitespace
|
555 |
+
# Typically, the token after whitespace will be naturally encoded as one token with whitespace
|
556 |
+
# some special cases like ": 3" will be encoded as :, whitespace, 3; 3 tokens. Only in this case, the loss on whitespace will be calculated
|
557 |
+
|
558 |
+
for conversation, target in zip(conversations, targets):
|
559 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
560 |
+
|
561 |
+
rounds = conversation.split(sep2)
|
562 |
+
cur_len = 1
|
563 |
+
target[:cur_len] = IGNORE_INDEX
|
564 |
+
|
565 |
+
# process system prompt
|
566 |
+
try:
|
567 |
+
rounds[1] = rounds[0] + sep2 + rounds[1]
|
568 |
+
del rounds[0]
|
569 |
+
except:
|
570 |
+
print('no user found')
|
571 |
+
raise ValueError
|
572 |
+
|
573 |
+
# add user
|
574 |
+
for i, rou in enumerate(rounds):
|
575 |
+
if i != 0:
|
576 |
+
rounds[i] = sep2 + rou
|
577 |
+
|
578 |
+
for i, rou in enumerate(rounds):
|
579 |
+
if rou == "":
|
580 |
+
break
|
581 |
+
|
582 |
+
parts = rou.split(sep)
|
583 |
+
if len(parts) != 2:
|
584 |
+
break
|
585 |
+
# parts[0] += sep
|
586 |
+
|
587 |
+
# supervise assistant: from pp's report
|
588 |
+
parts[1] = sep + parts[1]
|
589 |
+
# parts[0] = parts[0] + sep
|
590 |
+
|
591 |
+
if has_image:
|
592 |
+
round_len = len(tokenizer_image_token(rou, tokenizer)) - offset
|
593 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
|
594 |
+
else:
|
595 |
+
round_len = len(tokenizer(rou).input_ids) - offset
|
596 |
+
instruction_len = len(tokenizer(parts[0]).input_ids)
|
597 |
+
|
598 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
599 |
+
|
600 |
+
cur_len += round_len + (1 - offset) #starting from index 0, then cur_len will not cover eos token
|
601 |
+
|
602 |
+
if cur_len < tokenizer.model_max_length:
|
603 |
+
if cur_len != total_len:
|
604 |
+
target[:] = IGNORE_INDEX
|
605 |
+
print(
|
606 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
607 |
+
f" (ignored)"
|
608 |
+
)
|
609 |
+
|
610 |
+
if input_ids[0][0] != tokenizer.bos_token_id:
|
611 |
+
input_ids = [torch.cat([torch.LongTensor([tokenizer.bos_token_id]), i]) for i in input_ids]
|
612 |
+
targets = [torch.cat([torch.LongTensor([IGNORE_INDEX]), i]) for i in targets]
|
613 |
+
|
614 |
+
return dict(
|
615 |
+
input_ids=input_ids,
|
616 |
+
labels=targets,
|
617 |
+
)
|
618 |
+
|
619 |
+
|
620 |
+
def preprocess_v1(
|
621 |
+
sources,
|
622 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
623 |
+
has_image: bool = False
|
624 |
+
) -> Dict:
|
625 |
+
conv = conversation_lib.default_conversation.copy()
|
626 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
627 |
+
|
628 |
+
# Apply prompt templates
|
629 |
+
conversations = []
|
630 |
+
for i, source in enumerate(sources):
|
631 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
632 |
+
# Skip the first one if it is not from human
|
633 |
+
source = source[1:]
|
634 |
+
|
635 |
+
conv.messages = []
|
636 |
+
for j, sentence in enumerate(source):
|
637 |
+
role = roles[sentence["from"]]
|
638 |
+
assert role == conv.roles[j % 2], f"{i}"
|
639 |
+
conv.append_message(role, sentence["value"])
|
640 |
+
conversations.append(conv.get_prompt())
|
641 |
+
|
642 |
+
# Tokenize conversations
|
643 |
+
|
644 |
+
if has_image:
|
645 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
646 |
+
else:
|
647 |
+
input_ids = tokenizer(
|
648 |
+
conversations,
|
649 |
+
return_tensors="pt",
|
650 |
+
padding="longest",
|
651 |
+
max_length=tokenizer.model_max_length,
|
652 |
+
truncation=True,
|
653 |
+
).input_ids
|
654 |
+
|
655 |
+
targets = input_ids.clone()
|
656 |
+
|
657 |
+
if conv.sep_style == conversation_lib.SeparatorStyle.TWO:
|
658 |
+
|
659 |
+
# Mask targets
|
660 |
+
sep = conv.sep + conv.roles[1] + ": "
|
661 |
+
for conversation, target in zip(conversations, targets):
|
662 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
663 |
+
|
664 |
+
rounds = conversation.split(conv.sep2)
|
665 |
+
cur_len = 1
|
666 |
+
target[:cur_len] = IGNORE_INDEX
|
667 |
+
for i, rou in enumerate(rounds):
|
668 |
+
if rou == "":
|
669 |
+
break
|
670 |
+
|
671 |
+
parts = rou.split(sep)
|
672 |
+
if len(parts) != 2:
|
673 |
+
break
|
674 |
+
parts[0] += sep
|
675 |
+
|
676 |
+
if has_image:
|
677 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
678 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
679 |
+
else:
|
680 |
+
round_len = len(tokenizer(rou).input_ids)
|
681 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
682 |
+
|
683 |
+
if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
|
684 |
+
round_len -= 1
|
685 |
+
instruction_len -= 1
|
686 |
+
|
687 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
688 |
+
|
689 |
+
cur_len += round_len
|
690 |
+
target[cur_len:] = IGNORE_INDEX
|
691 |
+
|
692 |
+
if cur_len < tokenizer.model_max_length:
|
693 |
+
if cur_len != total_len:
|
694 |
+
target[:] = IGNORE_INDEX
|
695 |
+
print(
|
696 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
697 |
+
f" (ignored)"
|
698 |
+
)
|
699 |
+
|
700 |
+
elif conv.sep_style == conversation_lib.SeparatorStyle.QWEN2:
|
701 |
+
# Mask targets
|
702 |
+
sep = '<|im_start|>assistant\n'
|
703 |
+
for conversation, target in zip(conversations, targets):
|
704 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
705 |
+
|
706 |
+
raw_rounds = conversation.split('<|im_end|>\n')
|
707 |
+
cur_len = 0
|
708 |
+
rounds = []
|
709 |
+
now_str = ''
|
710 |
+
for rou in raw_rounds:
|
711 |
+
if len(rou) > 0:
|
712 |
+
rou = rou + '<|im_end|>\n'
|
713 |
+
if rou.startswith('<|endoftext|>'):
|
714 |
+
rounds[-1] = rounds[-1] + '<|endoftext|>'
|
715 |
+
rou = rou.replace('<|endoftext|>', '')
|
716 |
+
if len(rou.strip()) == 0:
|
717 |
+
continue
|
718 |
+
if '<|im_start|>assistant\n' in rou:
|
719 |
+
now_str += rou
|
720 |
+
rounds.append(now_str)
|
721 |
+
now_str = ''
|
722 |
+
else:
|
723 |
+
now_str += rou
|
724 |
+
|
725 |
+
for i, rou in enumerate(rounds):
|
726 |
+
if rou == "":
|
727 |
+
break
|
728 |
+
|
729 |
+
parts = rou.split(sep)
|
730 |
+
if len(parts) != 2:
|
731 |
+
break
|
732 |
+
parts[0] += sep
|
733 |
+
|
734 |
+
if has_image:
|
735 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
736 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
737 |
+
else:
|
738 |
+
round_len = len(tokenizer(rou).input_ids)
|
739 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
740 |
+
|
741 |
+
try:
|
742 |
+
is_legacy = tokenizer.legacy
|
743 |
+
except:
|
744 |
+
is_legacy = True
|
745 |
+
|
746 |
+
if i != 0 and not is_legacy and IS_TOKENIZER_GREATER_THAN_0_14:
|
747 |
+
round_len -= 1
|
748 |
+
instruction_len -= 1
|
749 |
+
|
750 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
751 |
+
|
752 |
+
cur_len += round_len
|
753 |
+
target[cur_len:] = IGNORE_INDEX
|
754 |
+
|
755 |
+
if cur_len < tokenizer.model_max_length:
|
756 |
+
if cur_len != total_len:
|
757 |
+
target[:] = IGNORE_INDEX
|
758 |
+
print(
|
759 |
+
f"WARNING: tokenization mismatch for QWEN2: {cur_len} vs. {total_len}."
|
760 |
+
f" (ignored)"
|
761 |
+
)
|
762 |
+
|
763 |
+
return dict(
|
764 |
+
input_ids=input_ids,
|
765 |
+
labels=targets,
|
766 |
+
)
|
767 |
+
|
768 |
+
def preprocess_imgsp_v1(
|
769 |
+
sources,
|
770 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
771 |
+
has_image: bool = False,
|
772 |
+
img_token: str = '<image>',
|
773 |
+
refine_prompt: bool = False,
|
774 |
+
) -> Dict:
|
775 |
+
conv = conversation_lib.default_conversation.copy()
|
776 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
777 |
+
|
778 |
+
# Apply prompt templates
|
779 |
+
conversations = []
|
780 |
+
guided_prompt = []
|
781 |
+
for i, source in enumerate(sources):
|
782 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
783 |
+
# Skip the first one if it is not from human
|
784 |
+
source = source[1:]
|
785 |
+
|
786 |
+
conv.messages = []
|
787 |
+
img_in_text = False
|
788 |
+
for j, sentence in enumerate(source):
|
789 |
+
role = roles[sentence["from"]]
|
790 |
+
assert role == conv.roles[j % 2], f"{i}"
|
791 |
+
|
792 |
+
# add guided prompt
|
793 |
+
if role==conv.roles[0]:
|
794 |
+
guided_sent = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, '').replace('\n', '')
|
795 |
+
if refine_prompt:
|
796 |
+
# only keep the useful part of the prompt
|
797 |
+
if '\n' in guided_sent:
|
798 |
+
for _sent in guided_sent.split('\n'):
|
799 |
+
if '?' in _sent:
|
800 |
+
guided_sent = _sent
|
801 |
+
break
|
802 |
+
guided_prompt.append(guided_sent)
|
803 |
+
# check if image token in text
|
804 |
+
if img_token in sentence["value"]:
|
805 |
+
img_in_text = True
|
806 |
+
# add image token to all sentence if multimoal input
|
807 |
+
if role==conv.roles[0] and img_in_text and img_token not in sentence["value"]:
|
808 |
+
# randomly add image token to the beginning or end of the sentence
|
809 |
+
if random.randint(0,1)==0:
|
810 |
+
img_conv = img_token + '\n' + sentence["value"]
|
811 |
+
else:
|
812 |
+
img_conv = sentence["value"] + '\n' + img_token
|
813 |
+
|
814 |
+
conv.append_message(role, img_conv)
|
815 |
+
else:
|
816 |
+
conv.append_message(role, sentence["value"])
|
817 |
+
conversations.append(conv.get_prompt())
|
818 |
+
|
819 |
+
# Tokenize conversations
|
820 |
+
if has_image:
|
821 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
822 |
+
else:
|
823 |
+
input_ids = tokenizer(
|
824 |
+
conversations,
|
825 |
+
return_tensors="pt",
|
826 |
+
padding="longest",
|
827 |
+
max_length=tokenizer.model_max_length,
|
828 |
+
truncation=True,
|
829 |
+
).input_ids
|
830 |
+
|
831 |
+
targets = input_ids.clone()
|
832 |
+
|
833 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
834 |
+
|
835 |
+
# Mask targets
|
836 |
+
sep = conv.sep + conv.roles[1] + ": "
|
837 |
+
for conversation, target in zip(conversations, targets):
|
838 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
839 |
+
|
840 |
+
rounds = conversation.split(conv.sep2)
|
841 |
+
cur_len = 1
|
842 |
+
target[:cur_len] = IGNORE_INDEX
|
843 |
+
for i, rou in enumerate(rounds):
|
844 |
+
if rou == "":
|
845 |
+
break
|
846 |
+
|
847 |
+
parts = rou.split(sep)
|
848 |
+
if len(parts) != 2:
|
849 |
+
break
|
850 |
+
parts[0] += sep
|
851 |
+
|
852 |
+
if has_image:
|
853 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
854 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
|
855 |
+
else:
|
856 |
+
round_len = len(tokenizer(rou).input_ids)
|
857 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
858 |
+
|
859 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
860 |
+
|
861 |
+
cur_len += round_len
|
862 |
+
target[cur_len:] = IGNORE_INDEX
|
863 |
+
|
864 |
+
if cur_len < tokenizer.model_max_length:
|
865 |
+
if cur_len != total_len:
|
866 |
+
target[:] = IGNORE_INDEX
|
867 |
+
print(
|
868 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
869 |
+
f" (ignored)"
|
870 |
+
)
|
871 |
+
return dict(
|
872 |
+
input_ids=input_ids,
|
873 |
+
labels=targets,
|
874 |
+
prompt=guided_prompt,
|
875 |
+
)
|
876 |
+
|
877 |
+
|
878 |
+
def preprocess_mpt(
|
879 |
+
sources,
|
880 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
881 |
+
has_image: bool = False
|
882 |
+
) -> Dict:
|
883 |
+
conv = conversation_lib.default_conversation.copy()
|
884 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
885 |
+
|
886 |
+
# Apply prompt templates
|
887 |
+
conversations = []
|
888 |
+
for i, source in enumerate(sources):
|
889 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
890 |
+
# Skip the first one if it is not from human
|
891 |
+
source = source[1:]
|
892 |
+
|
893 |
+
conv.messages = []
|
894 |
+
for j, sentence in enumerate(source):
|
895 |
+
role = roles[sentence["from"]]
|
896 |
+
assert role == conv.roles[j % 2], f"{i}"
|
897 |
+
conv.append_message(role, sentence["value"])
|
898 |
+
conversations.append(conv.get_prompt())
|
899 |
+
|
900 |
+
# Tokenize conversations
|
901 |
+
|
902 |
+
if has_image:
|
903 |
+
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
904 |
+
else:
|
905 |
+
input_ids = tokenizer(
|
906 |
+
conversations,
|
907 |
+
return_tensors="pt",
|
908 |
+
padding="longest",
|
909 |
+
max_length=tokenizer.model_max_length,
|
910 |
+
truncation=True,
|
911 |
+
).input_ids
|
912 |
+
|
913 |
+
targets = input_ids.clone()
|
914 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
|
915 |
+
|
916 |
+
# Mask targets
|
917 |
+
sep = conv.sep + conv.roles[1]
|
918 |
+
for conversation, target in zip(conversations, targets):
|
919 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
920 |
+
|
921 |
+
rounds = conversation.split(conv.sep)
|
922 |
+
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
|
923 |
+
for conv_idx in range(3, len(rounds), 2):
|
924 |
+
re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
|
925 |
+
cur_len = 1
|
926 |
+
target[:cur_len] = IGNORE_INDEX
|
927 |
+
for i, rou in enumerate(re_rounds):
|
928 |
+
if rou == "":
|
929 |
+
break
|
930 |
+
|
931 |
+
parts = rou.split(sep)
|
932 |
+
if len(parts) != 2:
|
933 |
+
break
|
934 |
+
parts[0] += sep
|
935 |
+
|
936 |
+
if has_image:
|
937 |
+
round_len = len(tokenizer_image_token(rou, tokenizer))
|
938 |
+
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
|
939 |
+
else:
|
940 |
+
round_len = len(tokenizer(rou).input_ids)
|
941 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
|
942 |
+
|
943 |
+
if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
|
944 |
+
round_len += 1
|
945 |
+
instruction_len += 1
|
946 |
+
|
947 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
948 |
+
|
949 |
+
cur_len += round_len
|
950 |
+
target[cur_len:] = IGNORE_INDEX
|
951 |
+
|
952 |
+
if cur_len < tokenizer.model_max_length:
|
953 |
+
if cur_len != total_len:
|
954 |
+
target[:] = IGNORE_INDEX
|
955 |
+
print(
|
956 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
957 |
+
f"(#turns={len(re_rounds)} ignored)"
|
958 |
+
)
|
959 |
+
|
960 |
+
return dict(
|
961 |
+
input_ids=input_ids,
|
962 |
+
labels=targets,
|
963 |
+
)
|
964 |
+
|
965 |
+
|
966 |
+
def preprocess_plain(
|
967 |
+
sources: Sequence[str],
|
968 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
969 |
+
) -> Dict:
|
970 |
+
# add end signal and concatenate together
|
971 |
+
conversations = []
|
972 |
+
for source in sources:
|
973 |
+
assert len(source) == 2
|
974 |
+
assert DEFAULT_IMAGE_TOKEN in source[0]['value']
|
975 |
+
source[0]['value'] = DEFAULT_IMAGE_TOKEN
|
976 |
+
conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
|
977 |
+
conversations.append(conversation)
|
978 |
+
# tokenize conversations
|
979 |
+
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
|
980 |
+
targets = copy.deepcopy(input_ids)
|
981 |
+
for target, source in zip(targets, sources):
|
982 |
+
tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
|
983 |
+
target[:tokenized_len] = IGNORE_INDEX
|
984 |
+
|
985 |
+
return dict(input_ids=input_ids, labels=targets)
|
986 |
+
|
987 |
+
|
988 |
+
def preprocess_plain_guided(
|
989 |
+
sources: Sequence[str],
|
990 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
991 |
+
prompt: str = None,
|
992 |
+
) -> Dict:
|
993 |
+
# add end signal and concatenate together
|
994 |
+
guided_prompt = []
|
995 |
+
conversations = []
|
996 |
+
for source in sources:
|
997 |
+
assert len(source) == 2
|
998 |
+
assert DEFAULT_IMAGE_TOKEN in source[0]['value']
|
999 |
+
guided_prompt.append(source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').replace('\n', ''))
|
1000 |
+
source[0]['value'] = DEFAULT_IMAGE_TOKEN
|
1001 |
+
conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
|
1002 |
+
conversations.append(conversation)
|
1003 |
+
# tokenize conversations
|
1004 |
+
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
|
1005 |
+
targets = copy.deepcopy(input_ids)
|
1006 |
+
for target, source in zip(targets, sources):
|
1007 |
+
tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
|
1008 |
+
target[:tokenized_len] = IGNORE_INDEX
|
1009 |
+
|
1010 |
+
|
1011 |
+
def preprocess(
|
1012 |
+
sources: Sequence[str],
|
1013 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
1014 |
+
has_image: bool = False,
|
1015 |
+
) -> Dict:
|
1016 |
+
"""
|
1017 |
+
Given a list of sources, each is a conversation list. This transform:
|
1018 |
+
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
|
1019 |
+
2. Concatenate conversations together;
|
1020 |
+
3. Tokenize the concatenated conversation;
|
1021 |
+
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
|
1022 |
+
"""
|
1023 |
+
if conversation_lib.default_conversation.version.startswith("plain_guided"):
|
1024 |
+
return preprocess_plain_guided(sources, tokenizer)
|
1025 |
+
elif conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
|
1026 |
+
return preprocess_plain(sources, tokenizer)
|
1027 |
+
if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
|
1028 |
+
return preprocess_llama_2(sources, tokenizer, has_image=has_image)
|
1029 |
+
if conversation_lib.default_conversation.version.startswith("v1"):
|
1030 |
+
return preprocess_v1(sources, tokenizer, has_image=has_image)
|
1031 |
+
if conversation_lib.default_conversation.version.startswith("llama_v3"): # for llama 3 tokenizer
|
1032 |
+
return preprocess_llama_3(sources, tokenizer, has_image=has_image)
|
1033 |
+
if conversation_lib.default_conversation.version == "qwen":
|
1034 |
+
return preprocess_qwen(sources, tokenizer, has_image=has_image)
|
1035 |
+
elif conversation_lib.default_conversation.version.startswith("imgsp"):
|
1036 |
+
return preprocess_imgsp_v1(sources, tokenizer, has_image=has_image)
|
1037 |
+
if conversation_lib.default_conversation.version == "mpt":
|
1038 |
+
return preprocess_mpt(sources, tokenizer, has_image=has_image)
|
1039 |
+
# add end signal and concatenate together
|
1040 |
+
conversations = []
|
1041 |
+
for source in sources:
|
1042 |
+
header = f"{conversation_lib.default_conversation.system}\n\n"
|
1043 |
+
conversation = _add_speaker_and_signal(header, source)
|
1044 |
+
conversations.append(conversation)
|
1045 |
+
# tokenize conversations
|
1046 |
+
def get_tokenize_len(prompts):
|
1047 |
+
return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
|
1048 |
+
|
1049 |
+
if has_image:
|
1050 |
+
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
|
1051 |
+
else:
|
1052 |
+
conversations_tokenized = _tokenize_fn(conversations, tokenizer)
|
1053 |
+
input_ids = conversations_tokenized["input_ids"]
|
1054 |
+
|
1055 |
+
targets = copy.deepcopy(input_ids)
|
1056 |
+
for target, source in zip(targets, sources):
|
1057 |
+
if has_image:
|
1058 |
+
tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
|
1059 |
+
else:
|
1060 |
+
tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
|
1061 |
+
speakers = [sentence["from"] for sentence in source]
|
1062 |
+
_mask_targets(target, tokenized_lens, speakers)
|
1063 |
+
|
1064 |
+
return dict(input_ids=input_ids, labels=targets)
|
1065 |
+
|
1066 |
+
|
1067 |
+
def read_image_patch(patch_info):
|
1068 |
+
if 'img_path' in patch_info.keys():
|
1069 |
+
image = Image.open(patch_info['img_path']).convert('RGB')
|
1070 |
+
else:
|
1071 |
+
image_file_name = patch_info['patch']
|
1072 |
+
start_bytes = int(patch_info['start_num'])
|
1073 |
+
file_size = int(patch_info['size'])
|
1074 |
+
|
1075 |
+
with open(image_file_name, 'rb') as f:
|
1076 |
+
f.seek(start_bytes)
|
1077 |
+
if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
|
1078 |
+
image = Image.open(io.BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
|
1079 |
+
else:
|
1080 |
+
image = Image.open(io.BytesIO(f.read(file_size))).convert("RGB")
|
1081 |
+
return image
|
1082 |
+
|
1083 |
+
|
1084 |
+
def read_video_patch(patch_info):
|
1085 |
+
if 'img_path' in patch_info.keys():
|
1086 |
+
image = Image.open(patch_info['img_path']).convert('RGB')
|
1087 |
+
else:
|
1088 |
+
image_file_name = patch_info['patch']
|
1089 |
+
start_bytes = int(patch_info['start_num'])
|
1090 |
+
file_size = patch_info['size'] # list of int
|
1091 |
+
total_file_size = 0
|
1092 |
+
images_all = []
|
1093 |
+
with open(image_file_name, 'rb') as f:
|
1094 |
+
for idx in range(len(file_size)):
|
1095 |
+
f.seek(start_bytes + total_file_size)
|
1096 |
+
if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
|
1097 |
+
image = Image.open(io.BytesIO(base64.b64decode(f.read(int(file_size[idx])).decode()))).convert("RGB")
|
1098 |
+
else:
|
1099 |
+
if 'sharegpt4o' in image_file_name or 'ShareGPT4Video/new_patch' in image_file_name or 'cinepile' in image_file_name or 'nextqa' in image_file_name or 'perceptiontest' in image_file_name:
|
1100 |
+
byte_str = io.BytesIO(f.read(int(file_size[idx])))
|
1101 |
+
array = np.frombuffer(byte_str.getvalue(), dtype=np.uint8)
|
1102 |
+
image = cv2.imdecode(array, cv2.IMREAD_COLOR)
|
1103 |
+
image = Image.fromarray(image)
|
1104 |
+
else:
|
1105 |
+
image = Image.open(io.BytesIO(f.read(int(file_size[idx])))).convert("RGB")
|
1106 |
+
images_all.append(image)
|
1107 |
+
total_file_size += int(file_size[idx])
|
1108 |
+
return images_all
|
1109 |
+
|
1110 |
+
class LazySupervisedDataset(Dataset):
|
1111 |
+
"""Dataset for supervised fine-tuning."""
|
1112 |
+
|
1113 |
+
def __init__(self, data_path: str,
|
1114 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
1115 |
+
data_args: DataArguments):
|
1116 |
+
super(LazySupervisedDataset, self).__init__()
|
1117 |
+
list_data_dict = json.load(open(data_path, "r"))
|
1118 |
+
|
1119 |
+
rank0_print("Formatting inputs...Skip in lazy mode")
|
1120 |
+
self.tokenizer = tokenizer
|
1121 |
+
self.list_data_dict = list_data_dict
|
1122 |
+
self.data_args = data_args
|
1123 |
+
|
1124 |
+
# if PRETRAIN:
|
1125 |
+
self.mapping_dict = json.load(open('/apdcephfs_jn/share_302244400/peterrao/nj3/data/llava/videodata/MovieNet/movienet_mapping.json', "r"))
|
1126 |
+
print('loadding mapping dict')
|
1127 |
+
|
1128 |
+
def __len__(self):
|
1129 |
+
return len(self.list_data_dict)
|
1130 |
+
|
1131 |
+
@property
|
1132 |
+
def lengths(self):
|
1133 |
+
length_list = []
|
1134 |
+
for sample in self.list_data_dict:
|
1135 |
+
img_tokens = 128 if 'image' in sample else 0
|
1136 |
+
length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
|
1137 |
+
return length_list
|
1138 |
+
|
1139 |
+
@property
|
1140 |
+
def modality_lengths(self):
|
1141 |
+
length_list = []
|
1142 |
+
for sample in self.list_data_dict:
|
1143 |
+
try:
|
1144 |
+
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
|
1145 |
+
except:
|
1146 |
+
cur_len = 1
|
1147 |
+
cur_len = cur_len if ('image' in sample) or ('video' in sample) or ('video_long' in sample) else -cur_len
|
1148 |
+
length_list.append(cur_len)
|
1149 |
+
return length_list
|
1150 |
+
|
1151 |
+
def process_image(self, image_file):
|
1152 |
+
if type(image_file) is str:
|
1153 |
+
image = Image.open(image_file).convert('RGB')
|
1154 |
+
elif type(image_file) is dict:
|
1155 |
+
image = read_image_patch(image_file)
|
1156 |
+
else:
|
1157 |
+
raise ValueError(f"Unknown image file type: {type(image_file)}, {image_file}")
|
1158 |
+
image_size = image.size
|
1159 |
+
image, image_padded = process_anyres_highres_image_genli(image, self.data_args.image_processor)
|
1160 |
+
|
1161 |
+
return (image, image_padded), image_size, "image"
|
1162 |
+
|
1163 |
+
def process_video(self, video_file):
|
1164 |
+
video = read_video_patch(video_file)
|
1165 |
+
video_processed = []
|
1166 |
+
|
1167 |
+
cur_frames_upbound = self.data_args.frames_upbound
|
1168 |
+
|
1169 |
+
if cur_frames_upbound > 0:
|
1170 |
+
if len(video) > cur_frames_upbound:
|
1171 |
+
uniform_sampled_frames = np.linspace(0, len(video) - 1, cur_frames_upbound, dtype=int)
|
1172 |
+
frame_idx = uniform_sampled_frames.tolist()
|
1173 |
+
else:
|
1174 |
+
frame_idx = None
|
1175 |
+
|
1176 |
+
for idx, frame in enumerate(video):
|
1177 |
+
frame = process_anyres_video_genli(frame, self.data_args.image_processor)
|
1178 |
+
if frame_idx is not None and idx in frame_idx:
|
1179 |
+
video_processed.append(frame.unsqueeze(0))
|
1180 |
+
elif frame_idx is None:
|
1181 |
+
video_processed.append(frame.unsqueeze(0))
|
1182 |
+
|
1183 |
+
if frame_idx is None:
|
1184 |
+
frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
|
1185 |
+
|
1186 |
+
video_processed = torch.cat(video_processed, dim=0)
|
1187 |
+
|
1188 |
+
video_processed = (video_processed, video_processed)
|
1189 |
+
return (video_processed, (384, 384), "video"), frame_idx
|
1190 |
+
|
1191 |
+
def process_video_pretrain(self, video_file, target_idx):
|
1192 |
+
video = read_video_patch(video_file)
|
1193 |
+
|
1194 |
+
cur_frames_upbound = random.randint(self.data_args.frames_upbound * 3, self.data_args.frames_upbound * 4)
|
1195 |
+
video_processed = []
|
1196 |
+
if cur_frames_upbound > 0:
|
1197 |
+
if len(video) > cur_frames_upbound:
|
1198 |
+
uniform_sampled_frames = np.linspace(0, len(video) - 1, cur_frames_upbound, dtype=int)
|
1199 |
+
frame_idx = uniform_sampled_frames.tolist()
|
1200 |
+
|
1201 |
+
# process longer case
|
1202 |
+
target_idx_new = []
|
1203 |
+
target_frame = []
|
1204 |
+
if len(target_idx) == 1:
|
1205 |
+
target_idx_new.append(np.random.randint(0, len(uniform_sampled_frames)))
|
1206 |
+
target_frame.append(video[target_idx[0]])
|
1207 |
+
elif len(target_idx) == 2:
|
1208 |
+
num1 = np.random.randint(0, len(uniform_sampled_frames) // 2)
|
1209 |
+
num2 = np.random.randint(num1 + 1, len(uniform_sampled_frames))
|
1210 |
+
target_idx_new.append(num1)
|
1211 |
+
target_idx_new.append(num2)
|
1212 |
+
target_frame.append(video[target_idx[0]])
|
1213 |
+
target_frame.append(video[target_idx[1]])
|
1214 |
+
|
1215 |
+
else:
|
1216 |
+
frame_idx = None
|
1217 |
+
target_idx_new = target_idx
|
1218 |
+
target_frame = None
|
1219 |
+
|
1220 |
+
for idx, frame in enumerate(video):
|
1221 |
+
frame = process_anyres_video_genli_long(frame, self.data_args.image_processor)
|
1222 |
+
|
1223 |
+
if frame_idx is not None and idx in frame_idx:
|
1224 |
+
video_processed.append(frame.unsqueeze(0))
|
1225 |
+
elif frame_idx is None:
|
1226 |
+
video_processed.append(frame.unsqueeze(0))
|
1227 |
+
|
1228 |
+
# process longer case
|
1229 |
+
if target_frame is not None:
|
1230 |
+
for idx in target_idx_new:
|
1231 |
+
frame = target_frame.pop(0)
|
1232 |
+
frame = process_anyres_video_genli_long(frame, self.data_args.image_processor)
|
1233 |
+
video_processed[idx] = frame.unsqueeze(0)
|
1234 |
+
|
1235 |
+
if frame_idx is None:
|
1236 |
+
frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
|
1237 |
+
|
1238 |
+
video_processed = torch.cat(video_processed, dim=0)
|
1239 |
+
|
1240 |
+
video_processed = (video_processed, video_processed)
|
1241 |
+
|
1242 |
+
return (video_processed, (384, 384), "video_long"), target_idx_new
|
1243 |
+
|
1244 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
1245 |
+
# TODO: define number of retries somewhere else
|
1246 |
+
num_base_retries = 3
|
1247 |
+
num_final_retries = 300
|
1248 |
+
# try the current sample first
|
1249 |
+
for attempt_idx in range(num_base_retries):
|
1250 |
+
try:
|
1251 |
+
sample = self._get_item(i)
|
1252 |
+
return sample
|
1253 |
+
except Exception as e:
|
1254 |
+
# sleep 1s in case it is a cloud disk issue
|
1255 |
+
print(f'[try #{attempt_idx}] Failed to fetch sample {i}. Exception:', e)
|
1256 |
+
time.sleep(1)
|
1257 |
+
|
1258 |
+
# try other samples, in case it is file corruption issue
|
1259 |
+
for attempt_idx in range(num_base_retries):
|
1260 |
+
try:
|
1261 |
+
sample_idx = random.choice(range(len(self)))
|
1262 |
+
sample = self._get_item(sample_idx)
|
1263 |
+
return sample
|
1264 |
+
except Exception as e:
|
1265 |
+
# no need to sleep
|
1266 |
+
print(f'[try other #{attempt_idx}] Failed to fetch sample {sample_idx}. Exception:', e)
|
1267 |
+
pass
|
1268 |
+
|
1269 |
+
# still fail, most likely to be path issue or cloud disk issue, retry the same sample for longer
|
1270 |
+
for attempt_idx in range(num_final_retries):
|
1271 |
+
try:
|
1272 |
+
sample = self._get_item(i)
|
1273 |
+
return sample
|
1274 |
+
except Exception as e:
|
1275 |
+
# sleep 1s in case it is a cloud disk issue
|
1276 |
+
print(f'[final try #{attempt_idx}] Failed to fetch sample {i}. Exception:', e)
|
1277 |
+
time.sleep(1)
|
1278 |
+
|
1279 |
+
# Finally raise exception on failing.
|
1280 |
+
assert False, "Failed to fetch sample."
|
1281 |
+
|
1282 |
+
def _get_item(self, i) -> Dict[str, torch.Tensor]:
|
1283 |
+
sources = self.list_data_dict[i]
|
1284 |
+
if isinstance(i, int):
|
1285 |
+
sources = [sources]
|
1286 |
+
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
|
1287 |
+
|
1288 |
+
if 'image' in sources[0]:
|
1289 |
+
image_file = self.list_data_dict[i]['image']
|
1290 |
+
if type(image_file) is list:
|
1291 |
+
image = [self.process_image(f) for f in image_file]
|
1292 |
+
else:
|
1293 |
+
image = [self.process_image(image_file)]
|
1294 |
+
num_frames = 0
|
1295 |
+
sources = preprocess_multimodal(
|
1296 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
1297 |
+
self.data_args
|
1298 |
+
)
|
1299 |
+
elif 'video' in sources[0]:
|
1300 |
+
video_file = self.list_data_dict[i]['video']
|
1301 |
+
video, _ = self.process_video(video_file)
|
1302 |
+
video = [video]
|
1303 |
+
num_frames = len(video[0][0])
|
1304 |
+
sources = preprocess_multimodal(
|
1305 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
1306 |
+
self.data_args)
|
1307 |
+
|
1308 |
+
elif 'video_long' in sources[0]:
|
1309 |
+
video_file = self.mapping_dict[self.list_data_dict[i]['video_long']]['video']
|
1310 |
+
video, target_idx = self.process_video_pretrain(video_file, self.list_data_dict[i]['idx'])
|
1311 |
+
video = [video]
|
1312 |
+
num_frames = len(video[0][0][0])
|
1313 |
+
question = sources[0]['question']
|
1314 |
+
answer = sources[0]['answer']
|
1315 |
+
if sources[0]['type'] == 'diff':
|
1316 |
+
question = question.replace('<idx1>', str(target_idx[0]))
|
1317 |
+
question = question.replace('<idx2>', str(target_idx[1]))
|
1318 |
+
elif sources[0]['type'] == 'caption':
|
1319 |
+
question = question.replace('<idx>', str(target_idx[0]))
|
1320 |
+
else:
|
1321 |
+
raise NotImplementedError
|
1322 |
+
|
1323 |
+
sources[0]['conversations'] = [{'from': 'human', 'value': f'<image>\nThis is a extremely long video with a total of {num_frames} frames sampled from the video. Please carefully read every given frame in this video, identifying the detailed contents in every frame. '+ question},
|
1324 |
+
{'from': 'gpt', 'value': answer}]
|
1325 |
+
sources = preprocess_multimodal(
|
1326 |
+
copy.deepcopy([e["conversations"] for e in sources]),
|
1327 |
+
self.data_args)
|
1328 |
+
else:
|
1329 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
1330 |
+
|
1331 |
+
has_image = ('image' in self.list_data_dict[i]) or ('video' in self.list_data_dict[i]) or ('video_long' in self.list_data_dict[i])
|
1332 |
+
data_dict = preprocess(
|
1333 |
+
sources,
|
1334 |
+
self.tokenizer,
|
1335 |
+
has_image=has_image)
|
1336 |
+
|
1337 |
+
if isinstance(i, int):
|
1338 |
+
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
1339 |
+
labels=data_dict["labels"][0])
|
1340 |
+
|
1341 |
+
# image exist in the data
|
1342 |
+
if 'image' in self.list_data_dict[i]:
|
1343 |
+
data_dict['image'] = image
|
1344 |
+
elif 'video' in self.list_data_dict[i]:
|
1345 |
+
data_dict['image'] = video
|
1346 |
+
elif 'video_long' in self.list_data_dict[i]:
|
1347 |
+
data_dict['image'] = video
|
1348 |
+
elif self.data_args.is_multimodal:
|
1349 |
+
# image does not exist in the data, but the model is multimodal
|
1350 |
+
crop_size = self.data_args.image_processor.crop_size
|
1351 |
+
data_dict['image'] = [
|
1352 |
+
(
|
1353 |
+
(torch.zeros(1, 3, crop_size['height'], crop_size['width']), torch.zeros(1, 3, crop_size['height'], crop_size['width'])),
|
1354 |
+
(crop_size['width'], crop_size['height']),
|
1355 |
+
"text"
|
1356 |
+
),
|
1357 |
+
]
|
1358 |
+
return data_dict
|
1359 |
+
|
1360 |
+
|
1361 |
+
@dataclass
|
1362 |
+
class DataCollatorForSupervisedDataset(object):
|
1363 |
+
"""Collate examples for supervised fine-tuning."""
|
1364 |
+
|
1365 |
+
tokenizer: transformers.PreTrainedTokenizer
|
1366 |
+
|
1367 |
+
def pad_sequence(self, input_ids, batch_first, padding_value):
|
1368 |
+
if self.tokenizer.padding_side == "left":
|
1369 |
+
input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
|
1370 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
1371 |
+
input_ids,
|
1372 |
+
batch_first=batch_first,
|
1373 |
+
padding_value=padding_value)
|
1374 |
+
if self.tokenizer.padding_side == "left":
|
1375 |
+
input_ids = torch.flip(input_ids, [1])
|
1376 |
+
return input_ids
|
1377 |
+
|
1378 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
1379 |
+
# input_ids, labels = tuple([instance[key] for instance in instances]
|
1380 |
+
# for key in ("input_ids", "labels"))
|
1381 |
+
input_ids, labels = tuple([instance[key] for instance in instances]
|
1382 |
+
for key in ("input_ids", "labels"))
|
1383 |
+
input_ids = [_input_ids[:self.tokenizer.model_max_length] for _input_ids in input_ids]
|
1384 |
+
labels = [_labels[:self.tokenizer.model_max_length] for _labels in labels]
|
1385 |
+
if self.tokenizer.pad_token_id is None:
|
1386 |
+
if "qwen" in self.tokenizer.name_or_path.lower():
|
1387 |
+
print("Setting pad token to bos token for qwen model.")
|
1388 |
+
self.tokenizer.pad_token_id = 151643
|
1389 |
+
else:
|
1390 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id # FIXME: this could only be triggered for llama3 model.
|
1391 |
+
input_ids = self.pad_sequence(
|
1392 |
+
input_ids,
|
1393 |
+
batch_first=True,
|
1394 |
+
padding_value=self.tokenizer.pad_token_id)
|
1395 |
+
labels = self.pad_sequence(labels,
|
1396 |
+
batch_first=True,
|
1397 |
+
padding_value=IGNORE_INDEX)
|
1398 |
+
|
1399 |
+
batch = dict(
|
1400 |
+
input_ids=input_ids,
|
1401 |
+
labels=labels,
|
1402 |
+
attention_mask=input_ids.ne(self.tokenizer.pad_token_id)
|
1403 |
+
)
|
1404 |
+
|
1405 |
+
if 'image' in instances[0]:
|
1406 |
+
images = [instance['image'] for instance in instances]
|
1407 |
+
batch['image_sizes'] = [im[1] for im_list in images for im in im_list]
|
1408 |
+
batch['modalities'] = [im[2] for im_list in images for im in im_list]
|
1409 |
+
images_lowres = [im[0][0] for im_list in images for im in im_list]
|
1410 |
+
images_highres = [im[0][1] for im_list in images for im in im_list]
|
1411 |
+
batch['images_highres'] = images_highres
|
1412 |
+
if all(x is not None and x.shape == images_lowres[0].shape for x in images_lowres):
|
1413 |
+
batch['images'] = torch.stack(images_lowres)
|
1414 |
+
else:
|
1415 |
+
batch['images'] = images_lowres
|
1416 |
+
return batch
|
1417 |
+
|
1418 |
+
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
|
1419 |
+
data_args) -> Dict:
|
1420 |
+
"""Make dataset and collator for supervised fine-tuning."""
|
1421 |
+
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
|
1422 |
+
data_path=data_args.data_path,
|
1423 |
+
data_args=data_args)
|
1424 |
+
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
1425 |
+
return dict(train_dataset=train_dataset,
|
1426 |
+
eval_dataset=None,
|
1427 |
+
data_collator=data_collator)
|
1428 |
+
|
1429 |
+
|
1430 |
+
def train():
|
1431 |
+
global local_rank
|
1432 |
+
|
1433 |
+
parser = transformers.HfArgumentParser(
|
1434 |
+
(ModelArguments, DataArguments, TrainingArguments))
|
1435 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
1436 |
+
local_rank = training_args.local_rank
|
1437 |
+
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
1438 |
+
|
1439 |
+
bnb_model_from_pretrained_args = {}
|
1440 |
+
if training_args.bits in [4, 8]:
|
1441 |
+
from transformers import BitsAndBytesConfig
|
1442 |
+
bnb_model_from_pretrained_args.update(dict(
|
1443 |
+
device_map={"": training_args.device},
|
1444 |
+
load_in_4bit=training_args.bits == 4,
|
1445 |
+
load_in_8bit=training_args.bits == 8,
|
1446 |
+
quantization_config=BitsAndBytesConfig(
|
1447 |
+
load_in_4bit=training_args.bits == 4,
|
1448 |
+
load_in_8bit=training_args.bits == 8,
|
1449 |
+
llm_int8_threshold=6.0,
|
1450 |
+
llm_int8_has_fp16_weight=False,
|
1451 |
+
bnb_4bit_compute_dtype=compute_dtype,
|
1452 |
+
bnb_4bit_use_double_quant=training_args.double_quant,
|
1453 |
+
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
|
1454 |
+
)
|
1455 |
+
))
|
1456 |
+
|
1457 |
+
if model_args.vision_tower is not None:
|
1458 |
+
print(model_args.vision_tower)
|
1459 |
+
if 'qwen' in model_args.model_name_or_path.lower():
|
1460 |
+
|
1461 |
+
if not model_args.pretrain_mm_mlp_adapter:
|
1462 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
1463 |
+
overwrite_config = {}
|
1464 |
+
overwrite_config["mm_resampler_type"] = model_args.mm_resampler_type
|
1465 |
+
|
1466 |
+
print(f"Overwriting config with {overwrite_config}")
|
1467 |
+
for k, v in overwrite_config.items():
|
1468 |
+
setattr(cfg_pretrained, k, v)
|
1469 |
+
|
1470 |
+
model = OryxQwenForCausalLM.from_pretrained(
|
1471 |
+
model_args.model_name_or_path,
|
1472 |
+
config=cfg_pretrained,
|
1473 |
+
cache_dir=training_args.cache_dir,
|
1474 |
+
attn_implementation="flash_attention_2",
|
1475 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
1476 |
+
**bnb_model_from_pretrained_args
|
1477 |
+
)
|
1478 |
+
else:
|
1479 |
+
model = OryxQwenForCausalLM.from_pretrained(
|
1480 |
+
model_args.model_name_or_path,
|
1481 |
+
cache_dir=training_args.cache_dir,
|
1482 |
+
attn_implementation="flash_attention_2",
|
1483 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
1484 |
+
**bnb_model_from_pretrained_args
|
1485 |
+
)
|
1486 |
+
|
1487 |
+
else:
|
1488 |
+
# finetune from a image trained model
|
1489 |
+
# if not model_args.pretrain_mm_mlp_adapter:
|
1490 |
+
cfg_pretrained = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
1491 |
+
overwrite_config = {}
|
1492 |
+
overwrite_config["mm_resampler_type"] = model_args.mm_resampler_type
|
1493 |
+
|
1494 |
+
print(f"Overwriting config with {overwrite_config}")
|
1495 |
+
for k, v in overwrite_config.items():
|
1496 |
+
setattr(cfg_pretrained, k, v)
|
1497 |
+
|
1498 |
+
model = OryxLlamaForCausalLM.from_pretrained(
|
1499 |
+
model_args.model_name_or_path,
|
1500 |
+
config=cfg_pretrained,
|
1501 |
+
cache_dir=training_args.cache_dir,
|
1502 |
+
attn_implementation="flash_attention_2",
|
1503 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
1504 |
+
**bnb_model_from_pretrained_args
|
1505 |
+
)
|
1506 |
+
|
1507 |
+
else:
|
1508 |
+
model = transformers.LlamaForCausalLM.from_pretrained(
|
1509 |
+
model_args.model_name_or_path,
|
1510 |
+
cache_dir=training_args.cache_dir,
|
1511 |
+
attn_implementation="flash_attention_2",
|
1512 |
+
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
|
1513 |
+
**bnb_model_from_pretrained_args
|
1514 |
+
)
|
1515 |
+
model.config.use_cache = False
|
1516 |
+
|
1517 |
+
if model_args.freeze_backbone:
|
1518 |
+
model.model.requires_grad_(False)
|
1519 |
+
|
1520 |
+
if training_args.bits in [4, 8]:
|
1521 |
+
from peft import prepare_model_for_kbit_training
|
1522 |
+
model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
|
1523 |
+
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
|
1524 |
+
|
1525 |
+
if training_args.gradient_checkpointing:
|
1526 |
+
if hasattr(model, "enable_input_require_grads"):
|
1527 |
+
model.enable_input_require_grads()
|
1528 |
+
else:
|
1529 |
+
def make_inputs_require_grad(module, input, output):
|
1530 |
+
output.requires_grad_(True)
|
1531 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
1532 |
+
|
1533 |
+
if training_args.lora_enable:
|
1534 |
+
from peft import LoraConfig, get_peft_model
|
1535 |
+
lora_config = LoraConfig(
|
1536 |
+
r=training_args.lora_r,
|
1537 |
+
lora_alpha=training_args.lora_alpha,
|
1538 |
+
target_modules=find_all_linear_names(model),
|
1539 |
+
lora_dropout=training_args.lora_dropout,
|
1540 |
+
bias=training_args.lora_bias,
|
1541 |
+
task_type="CAUSAL_LM",
|
1542 |
+
)
|
1543 |
+
if training_args.bits == 16:
|
1544 |
+
if training_args.bf16:
|
1545 |
+
model.to(torch.bfloat16)
|
1546 |
+
if training_args.fp16:
|
1547 |
+
model.to(torch.float16)
|
1548 |
+
rank0_print("Adding LoRA adapters...")
|
1549 |
+
model = get_peft_model(model, lora_config)
|
1550 |
+
|
1551 |
+
if "qwen" in model_args.model_name_or_path.lower():
|
1552 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
1553 |
+
model_args.model_name_or_path,
|
1554 |
+
cache_dir=training_args.cache_dir,
|
1555 |
+
model_max_length=training_args.model_max_length,
|
1556 |
+
padding_side="right")
|
1557 |
+
else:
|
1558 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
1559 |
+
model_args.model_name_or_path,
|
1560 |
+
cache_dir=training_args.cache_dir,
|
1561 |
+
model_max_length=training_args.model_max_length,
|
1562 |
+
padding_side="right",
|
1563 |
+
use_fast=False,
|
1564 |
+
)
|
1565 |
+
|
1566 |
+
if model_args.version == "v0":
|
1567 |
+
if tokenizer.pad_token is None:
|
1568 |
+
smart_tokenizer_and_embedding_resize(
|
1569 |
+
special_tokens_dict=dict(pad_token="[PAD]"),
|
1570 |
+
tokenizer=tokenizer,
|
1571 |
+
model=model,
|
1572 |
+
)
|
1573 |
+
elif model_args.version == "v0.5":
|
1574 |
+
tokenizer.pad_token = tokenizer.unk_token
|
1575 |
+
elif model_args.version == "llava_llama_3":
|
1576 |
+
tokenizer.pad_token = "<|reserved_special_token_0|>" # only for llama3
|
1577 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates["llava_llama_3"]
|
1578 |
+
else:
|
1579 |
+
if 'llama-3' in model_args.model_name_or_path.lower():
|
1580 |
+
tokenizer.pad_token = "<|reserved_special_token_0|>"
|
1581 |
+
else:
|
1582 |
+
tokenizer.pad_token = tokenizer.unk_token
|
1583 |
+
if model_args.version in conversation_lib.conv_templates:
|
1584 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
|
1585 |
+
else:
|
1586 |
+
conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
|
1587 |
+
|
1588 |
+
if model_args.vision_tower is not None:
|
1589 |
+
model.get_model().initialize_vision_modules(
|
1590 |
+
model_args=model_args,
|
1591 |
+
fsdp=training_args.fsdp
|
1592 |
+
)
|
1593 |
+
|
1594 |
+
vision_tower = model.get_vision_tower()
|
1595 |
+
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
|
1596 |
+
|
1597 |
+
vision_tower.image_processor.do_resize = training_args.do_resize
|
1598 |
+
vision_tower.image_processor.do_center_crop = training_args.do_center_crop
|
1599 |
+
|
1600 |
+
data_args.image_processor = vision_tower.image_processor
|
1601 |
+
data_args.is_multimodal = True
|
1602 |
+
|
1603 |
+
model.config.tokenizer_padding_side = tokenizer.padding_side
|
1604 |
+
model.config.tokenizer_model_max_length = tokenizer.model_max_length
|
1605 |
+
|
1606 |
+
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
|
1607 |
+
model.config.tune_mm_vision_resampler = training_args.tune_mm_vision_resampler = model_args.tune_mm_vision_resampler
|
1608 |
+
if model_args.tune_mm_mlp_adapter or model_args.tune_mm_vision_resampler:
|
1609 |
+
model.requires_grad_(False)
|
1610 |
+
if model_args.tune_mm_mlp_adapter:
|
1611 |
+
for p in model.get_model().mm_projector.parameters():
|
1612 |
+
p.requires_grad = True
|
1613 |
+
if model_args.tune_mm_vision_resampler:
|
1614 |
+
for p in model.get_model().vision_resampler.parameters():
|
1615 |
+
p.requires_grad = True
|
1616 |
+
|
1617 |
+
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
|
1618 |
+
if training_args.freeze_mm_mlp_adapter:
|
1619 |
+
for p in model.get_model().mm_projector.parameters():
|
1620 |
+
p.requires_grad = False
|
1621 |
+
|
1622 |
+
model.config.freeze_mm_vision_resampler = training_args.freeze_mm_vision_resampler
|
1623 |
+
if training_args.freeze_mm_vision_resampler:
|
1624 |
+
for p in model.get_model().vision_resampler.parameters():
|
1625 |
+
p.requires_grad = False
|
1626 |
+
|
1627 |
+
model.config.unfreeze_mm_vision_tower = model_args.unfreeze_mm_vision_tower
|
1628 |
+
if model_args.unfreeze_mm_vision_tower:
|
1629 |
+
vision_tower.requires_grad_(True)
|
1630 |
+
|
1631 |
+
if training_args.bits in [4, 8]:
|
1632 |
+
model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
|
1633 |
+
|
1634 |
+
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
|
1635 |
+
model.config.mm_projector_lr = training_args.mm_projector_lr
|
1636 |
+
model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr
|
1637 |
+
training_args.use_im_start_end = model_args.mm_use_im_start_end
|
1638 |
+
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
|
1639 |
+
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
|
1640 |
+
|
1641 |
+
if training_args.bits in [4, 8]:
|
1642 |
+
from peft.tuners.lora import LoraLayer
|
1643 |
+
for name, module in model.named_modules():
|
1644 |
+
if isinstance(module, LoraLayer):
|
1645 |
+
if training_args.bf16:
|
1646 |
+
module = module.to(torch.bfloat16)
|
1647 |
+
if 'norm' in name:
|
1648 |
+
module = module.to(torch.float32)
|
1649 |
+
if 'lm_head' in name or 'embed_tokens' in name:
|
1650 |
+
if hasattr(module, 'weight'):
|
1651 |
+
if training_args.bf16 and module.weight.dtype == torch.float32:
|
1652 |
+
module = module.to(torch.bfloat16)
|
1653 |
+
|
1654 |
+
data_module = make_supervised_data_module(tokenizer=tokenizer,
|
1655 |
+
data_args=data_args)
|
1656 |
+
trainer = OryxTrainer(model=model,
|
1657 |
+
tokenizer=tokenizer,
|
1658 |
+
args=training_args,
|
1659 |
+
**data_module)
|
1660 |
+
|
1661 |
+
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
1662 |
+
trainer.train(resume_from_checkpoint=True)
|
1663 |
+
else:
|
1664 |
+
trainer.train()
|
1665 |
+
trainer.save_state()
|
1666 |
+
|
1667 |
+
model.config.use_cache = True
|
1668 |
+
|
1669 |
+
if training_args.lora_enable:
|
1670 |
+
state_dict = get_peft_state_maybe_zero_3(
|
1671 |
+
model.named_parameters(), training_args.lora_bias
|
1672 |
+
)
|
1673 |
+
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
|
1674 |
+
model.named_parameters()
|
1675 |
+
)
|
1676 |
+
if training_args.local_rank == 0 or training_args.local_rank == -1:
|
1677 |
+
model.config.save_pretrained(training_args.output_dir)
|
1678 |
+
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
|
1679 |
+
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
|
1680 |
+
else:
|
1681 |
+
safe_save_model_for_hf_trainer(trainer=trainer,
|
1682 |
+
output_dir=training_args.output_dir)
|
1683 |
+
|
1684 |
+
|
1685 |
+
if __name__ == "__main__":
|
1686 |
+
train()
|
oryx/train/train_mem.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from oryx.train.train import train
|
3 |
+
|
4 |
+
if __name__ == "__main__":
|
5 |
+
train()
|
oryx/utils.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import logging.handlers
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from oryx.constants import LOGDIR
|
10 |
+
|
11 |
+
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
12 |
+
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
13 |
+
|
14 |
+
handler = None
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
def rank0_print(*args):
|
19 |
+
if dist.is_initialized():
|
20 |
+
if dist.get_rank() == 0:
|
21 |
+
print(f"Rank {dist.get_rank()}: ", *args)
|
22 |
+
else:
|
23 |
+
print(*args)
|
24 |
+
|
25 |
+
def build_logger(logger_name, logger_filename):
|
26 |
+
global handler
|
27 |
+
|
28 |
+
formatter = logging.Formatter(
|
29 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
30 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
31 |
+
)
|
32 |
+
|
33 |
+
# Set the format of root handlers
|
34 |
+
if not logging.getLogger().handlers:
|
35 |
+
logging.basicConfig(level=logging.INFO)
|
36 |
+
logging.getLogger().handlers[0].setFormatter(formatter)
|
37 |
+
|
38 |
+
# Redirect stdout and stderr to loggers
|
39 |
+
stdout_logger = logging.getLogger("stdout")
|
40 |
+
stdout_logger.setLevel(logging.INFO)
|
41 |
+
sl = StreamToLogger(stdout_logger, logging.INFO)
|
42 |
+
sys.stdout = sl
|
43 |
+
|
44 |
+
stderr_logger = logging.getLogger("stderr")
|
45 |
+
stderr_logger.setLevel(logging.ERROR)
|
46 |
+
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
47 |
+
sys.stderr = sl
|
48 |
+
|
49 |
+
# Get logger
|
50 |
+
logger = logging.getLogger(logger_name)
|
51 |
+
logger.setLevel(logging.INFO)
|
52 |
+
|
53 |
+
# Add a file handler for all loggers
|
54 |
+
if handler is None:
|
55 |
+
os.makedirs(LOGDIR, exist_ok=True)
|
56 |
+
filename = os.path.join(LOGDIR, logger_filename)
|
57 |
+
handler = logging.handlers.TimedRotatingFileHandler(
|
58 |
+
filename, when='D', utc=True)
|
59 |
+
handler.setFormatter(formatter)
|
60 |
+
|
61 |
+
for name, item in logging.root.manager.loggerDict.items():
|
62 |
+
if isinstance(item, logging.Logger):
|
63 |
+
item.addHandler(handler)
|
64 |
+
|
65 |
+
return logger
|
66 |
+
|
67 |
+
|
68 |
+
class StreamToLogger(object):
|
69 |
+
"""
|
70 |
+
Fake file-like stream object that redirects writes to a logger instance.
|
71 |
+
"""
|
72 |
+
def __init__(self, logger, log_level=logging.INFO):
|
73 |
+
self.terminal = sys.stdout
|
74 |
+
self.logger = logger
|
75 |
+
self.log_level = log_level
|
76 |
+
self.linebuf = ''
|
77 |
+
|
78 |
+
def __getattr__(self, attr):
|
79 |
+
return getattr(self.terminal, attr)
|
80 |
+
|
81 |
+
def write(self, buf):
|
82 |
+
temp_linebuf = self.linebuf + buf
|
83 |
+
self.linebuf = ''
|
84 |
+
for line in temp_linebuf.splitlines(True):
|
85 |
+
# From the io.TextIOWrapper docs:
|
86 |
+
# On output, if newline is None, any '\n' characters written
|
87 |
+
# are translated to the system default line separator.
|
88 |
+
# By default sys.stdout.write() expects '\n' newlines and then
|
89 |
+
# translates them so this is still cross platform.
|
90 |
+
if line[-1] == '\n':
|
91 |
+
self.logger.log(self.log_level, line.rstrip())
|
92 |
+
else:
|
93 |
+
self.linebuf += line
|
94 |
+
|
95 |
+
def flush(self):
|
96 |
+
if self.linebuf != '':
|
97 |
+
self.logger.log(self.log_level, self.linebuf.rstrip())
|
98 |
+
self.linebuf = ''
|
99 |
+
|
100 |
+
|
101 |
+
def disable_torch_init():
|
102 |
+
"""
|
103 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
104 |
+
"""
|
105 |
+
import torch
|
106 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
107 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
108 |
+
|
109 |
+
|
110 |
+
def violates_moderation(text):
|
111 |
+
"""
|
112 |
+
Check whether the text violates OpenAI moderation API.
|
113 |
+
"""
|
114 |
+
url = "https://api.openai.com/v1/moderations"
|
115 |
+
headers = {"Content-Type": "application/json",
|
116 |
+
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
|
117 |
+
text = text.replace("\n", "")
|
118 |
+
data = "{" + '"input": ' + f'"{text}"' + "}"
|
119 |
+
data = data.encode("utf-8")
|
120 |
+
try:
|
121 |
+
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
122 |
+
flagged = ret.json()["results"][0]["flagged"]
|
123 |
+
except requests.exceptions.RequestException as e:
|
124 |
+
flagged = False
|
125 |
+
except KeyError as e:
|
126 |
+
flagged = False
|
127 |
+
|
128 |
+
return flagged
|
129 |
+
|
130 |
+
|
131 |
+
def pretty_print_semaphore(semaphore):
|
132 |
+
if semaphore is None:
|
133 |
+
return "None"
|
134 |
+
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|