boxinw@nvidia.com
commited on
Commit
·
b925209
1
Parent(s):
07a7f16
Add benchmark evaluation scripts
Browse files- README.md +16 -0
- eval/conversation.py +492 -0
- eval/eval_dataset.py +850 -0
- eval/full_eval.yaml +188 -0
- eval/mmmu_utils.py +663 -0
- eval/requirements.txt +3 -0
- eval/vqa_utils.py +317 -0
- run_eval.py +702 -0
README.md
CHANGED
@@ -318,6 +318,22 @@ response = model.chat(tokenizer, pixel_values, question, generation_config)
|
|
318 |
print(f'User: {question}\nAssistant: {response}')
|
319 |
```
|
320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
321 |
## Software Integration
|
322 |
**Runtime Engine(s)**
|
323 |
* PyTorch <br>
|
|
|
318 |
print(f'User: {question}\nAssistant: {response}')
|
319 |
```
|
320 |
|
321 |
+
### Benchmark Evaluation
|
322 |
+
|
323 |
+
To test our NVLM-1.0 model on the benchmark datasets, you can use the following code:
|
324 |
+
|
325 |
+
```bash
|
326 |
+
python run_eval.py --config-path eval/full_eval.yaml \
|
327 |
+
--result-save-path path/to/eval_results/ \
|
328 |
+
--zero-shot-eval-tasks chartqa coco_caption flickr30k_caption vqav2 mmmu textvqa mathvista mmbench chartqa docvqa realworldqa ocrbench ai2diagram ai2diagram_nomask mmmu_pro docvqa_test
|
329 |
+
```
|
330 |
+
|
331 |
+
Specifically,
|
332 |
+
- `--config-path eval/full_eval.yaml` file contains the evaluation configurations, including the evaluation prompt, the evaluation dataset paths, and generation hyper-parameters.
|
333 |
+
- `--result-save-path path/to/eval_results/` specifies the path to save the evaluation results.
|
334 |
+
- `--zero-shot-eval-tasks` specifies the tasks to evaluate on.
|
335 |
+
|
336 |
+
|
337 |
## Software Integration
|
338 |
**Runtime Engine(s)**
|
339 |
* PyTorch <br>
|
eval/conversation.py
ADDED
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# From https://github.com/haotian-liu/LLaVA/blob/main/llava/conversation.py
|
2 |
+
|
3 |
+
import dataclasses
|
4 |
+
from enum import auto, Enum
|
5 |
+
from typing import List, Tuple
|
6 |
+
|
7 |
+
|
8 |
+
class SeparatorStyle(Enum):
|
9 |
+
"""Different separator style."""
|
10 |
+
SINGLE = auto()
|
11 |
+
TWO = auto()
|
12 |
+
MPT = auto()
|
13 |
+
PLAIN = auto()
|
14 |
+
LLAMA_2 = auto()
|
15 |
+
|
16 |
+
|
17 |
+
@dataclasses.dataclass
|
18 |
+
class Conversation:
|
19 |
+
"""A class that keeps all conversation history."""
|
20 |
+
system: str
|
21 |
+
roles: List[str]
|
22 |
+
messages: List[List[str]]
|
23 |
+
offset: int
|
24 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
25 |
+
sep: str = "###"
|
26 |
+
sep2: str = None
|
27 |
+
real_sep2: str = None
|
28 |
+
version: str = "Unknown"
|
29 |
+
|
30 |
+
skip_next: bool = False
|
31 |
+
|
32 |
+
def get_prompt(self):
|
33 |
+
messages = self.messages
|
34 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
35 |
+
messages = self.messages.copy()
|
36 |
+
init_role, init_msg = messages[0].copy()
|
37 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
38 |
+
if 'mmtag' in self.version:
|
39 |
+
messages[0] = (init_role, init_msg)
|
40 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
41 |
+
messages.insert(1, (self.roles[1], "Received."))
|
42 |
+
else:
|
43 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
44 |
+
|
45 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
46 |
+
ret = self.system + self.sep
|
47 |
+
for role, message in messages:
|
48 |
+
if message:
|
49 |
+
if type(message) is tuple:
|
50 |
+
message, _, _ = message
|
51 |
+
ret += role + ": " + message + self.sep
|
52 |
+
else:
|
53 |
+
ret += role + ":"
|
54 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
55 |
+
seps = [self.sep, self.sep2]
|
56 |
+
ret = self.system + seps[0]
|
57 |
+
for i, (role, message) in enumerate(messages):
|
58 |
+
if message:
|
59 |
+
if type(message) is tuple:
|
60 |
+
message, _, _ = message
|
61 |
+
ret += role + ": " + message + seps[i % 2]
|
62 |
+
else:
|
63 |
+
ret += role + ":"
|
64 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
65 |
+
ret = self.system + self.sep
|
66 |
+
for role, message in messages:
|
67 |
+
if message:
|
68 |
+
if type(message) is tuple:
|
69 |
+
message, _, _ = message
|
70 |
+
ret += role + message + self.sep
|
71 |
+
else:
|
72 |
+
ret += role
|
73 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
74 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
75 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
76 |
+
ret = ""
|
77 |
+
|
78 |
+
for i, (role, message) in enumerate(messages):
|
79 |
+
if i == 0:
|
80 |
+
assert message, "first message should not be none"
|
81 |
+
assert role == self.roles[0], "first message should come from user"
|
82 |
+
if message:
|
83 |
+
if type(message) is tuple:
|
84 |
+
message, _, _ = message
|
85 |
+
if i == 0: message = wrap_sys(self.system) + message
|
86 |
+
if i % 2 == 0:
|
87 |
+
message = wrap_inst(message)
|
88 |
+
ret += self.sep + message
|
89 |
+
else:
|
90 |
+
ret += " " + message + " " + self.sep2
|
91 |
+
else:
|
92 |
+
ret += ""
|
93 |
+
ret = ret.lstrip(self.sep)
|
94 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
95 |
+
seps = [self.sep, self.sep2]
|
96 |
+
ret = self.system
|
97 |
+
for i, (role, message) in enumerate(messages):
|
98 |
+
if message:
|
99 |
+
if type(message) is tuple:
|
100 |
+
message, _, _ = message
|
101 |
+
ret += message + seps[i % 2]
|
102 |
+
else:
|
103 |
+
ret += ""
|
104 |
+
else:
|
105 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
106 |
+
|
107 |
+
return ret
|
108 |
+
|
109 |
+
def append_message(self, role, message):
|
110 |
+
self.messages.append([role, message])
|
111 |
+
|
112 |
+
def get_images(self, return_pil=False):
|
113 |
+
images = []
|
114 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
115 |
+
if i % 2 == 0:
|
116 |
+
if type(msg) is tuple:
|
117 |
+
import base64
|
118 |
+
from io import BytesIO
|
119 |
+
from PIL import Image
|
120 |
+
msg, image, image_process_mode = msg
|
121 |
+
if image_process_mode == "Pad":
|
122 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
123 |
+
width, height = pil_img.size
|
124 |
+
if width == height:
|
125 |
+
return pil_img
|
126 |
+
elif width > height:
|
127 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
128 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
129 |
+
return result
|
130 |
+
else:
|
131 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
132 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
133 |
+
return result
|
134 |
+
image = expand2square(image)
|
135 |
+
elif image_process_mode in ["Default", "Crop"]:
|
136 |
+
pass
|
137 |
+
elif image_process_mode == "Resize":
|
138 |
+
image = image.resize((336, 336))
|
139 |
+
else:
|
140 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
141 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
142 |
+
aspect_ratio = max_hw / min_hw
|
143 |
+
max_len, min_len = 800, 400
|
144 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
145 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
146 |
+
W, H = image.size
|
147 |
+
if longest_edge != max(image.size):
|
148 |
+
if H > W:
|
149 |
+
H, W = longest_edge, shortest_edge
|
150 |
+
else:
|
151 |
+
H, W = shortest_edge, longest_edge
|
152 |
+
image = image.resize((W, H))
|
153 |
+
if return_pil:
|
154 |
+
images.append(image)
|
155 |
+
else:
|
156 |
+
buffered = BytesIO()
|
157 |
+
image.save(buffered, format="PNG")
|
158 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
159 |
+
images.append(img_b64_str)
|
160 |
+
return images
|
161 |
+
|
162 |
+
def to_gradio_chatbot(self):
|
163 |
+
ret = []
|
164 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
165 |
+
if i % 2 == 0:
|
166 |
+
if type(msg) is tuple:
|
167 |
+
import base64
|
168 |
+
from io import BytesIO
|
169 |
+
msg, image, image_process_mode = msg
|
170 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
171 |
+
aspect_ratio = max_hw / min_hw
|
172 |
+
max_len, min_len = 800, 400
|
173 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
174 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
175 |
+
W, H = image.size
|
176 |
+
if H > W:
|
177 |
+
H, W = longest_edge, shortest_edge
|
178 |
+
else:
|
179 |
+
H, W = shortest_edge, longest_edge
|
180 |
+
image = image.resize((W, H))
|
181 |
+
buffered = BytesIO()
|
182 |
+
image.save(buffered, format="JPEG")
|
183 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
184 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
185 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
186 |
+
ret.append([msg, None])
|
187 |
+
else:
|
188 |
+
ret.append([msg, None])
|
189 |
+
else:
|
190 |
+
ret[-1][-1] = msg
|
191 |
+
return ret
|
192 |
+
|
193 |
+
def copy(self):
|
194 |
+
return Conversation(
|
195 |
+
system=self.system,
|
196 |
+
roles=self.roles,
|
197 |
+
messages=[[x, y] for x, y in self.messages],
|
198 |
+
offset=self.offset,
|
199 |
+
sep_style=self.sep_style,
|
200 |
+
sep=self.sep,
|
201 |
+
sep2=self.sep2,
|
202 |
+
real_sep2=self.real_sep2,
|
203 |
+
version=self.version)
|
204 |
+
|
205 |
+
def dict(self):
|
206 |
+
if len(self.get_images()) > 0:
|
207 |
+
return {
|
208 |
+
"system": self.system,
|
209 |
+
"roles": self.roles,
|
210 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
211 |
+
"offset": self.offset,
|
212 |
+
"sep": self.sep,
|
213 |
+
"sep2": self.sep2,
|
214 |
+
"real_sep2": self.real_sep2
|
215 |
+
}
|
216 |
+
return {
|
217 |
+
"system": self.system,
|
218 |
+
"roles": self.roles,
|
219 |
+
"messages": self.messages,
|
220 |
+
"offset": self.offset,
|
221 |
+
"sep": self.sep,
|
222 |
+
"sep2": self.sep2,
|
223 |
+
"real_sep2": self.real_sep2
|
224 |
+
}
|
225 |
+
|
226 |
+
|
227 |
+
conv_vicuna_v0 = Conversation(
|
228 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
229 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
230 |
+
roles=("Human", "Assistant"),
|
231 |
+
messages=(
|
232 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
233 |
+
("Assistant",
|
234 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
235 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
236 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
237 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
238 |
+
"renewable and non-renewable energy sources:\n"
|
239 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
240 |
+
"energy sources are finite and will eventually run out.\n"
|
241 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
242 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
243 |
+
"and other negative effects.\n"
|
244 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
245 |
+
"have lower operational costs than non-renewable sources.\n"
|
246 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
247 |
+
"locations than non-renewable sources.\n"
|
248 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
249 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
250 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
251 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
252 |
+
),
|
253 |
+
offset=2,
|
254 |
+
sep_style=SeparatorStyle.SINGLE,
|
255 |
+
sep="###",
|
256 |
+
)
|
257 |
+
|
258 |
+
### Used for llava-instruction-tuning stage
|
259 |
+
conv_vicuna_v1 = Conversation(
|
260 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
261 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
262 |
+
roles=("USER", "ASSISTANT"),
|
263 |
+
version="v1",
|
264 |
+
messages=(),
|
265 |
+
offset=0,
|
266 |
+
sep_style=SeparatorStyle.TWO,
|
267 |
+
sep=" ",
|
268 |
+
sep2="</s>",
|
269 |
+
)
|
270 |
+
|
271 |
+
conv_llama_2 = Conversation(
|
272 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
273 |
+
|
274 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
275 |
+
roles=("USER", "ASSISTANT"),
|
276 |
+
version="llama_v2",
|
277 |
+
messages=(),
|
278 |
+
offset=0,
|
279 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
280 |
+
sep="<s>",
|
281 |
+
sep2="</s>",
|
282 |
+
)
|
283 |
+
|
284 |
+
conv_llava_llama_2 = Conversation(
|
285 |
+
system="You are a helpful language and vision assistant. "
|
286 |
+
"You are able to understand the visual content that the user provides, "
|
287 |
+
"and assist the user with a variety of tasks using natural language.",
|
288 |
+
roles=("USER", "ASSISTANT"),
|
289 |
+
version="llama_v2",
|
290 |
+
messages=(),
|
291 |
+
offset=0,
|
292 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
293 |
+
sep="<s>",
|
294 |
+
sep2="</s>",
|
295 |
+
)
|
296 |
+
|
297 |
+
conv_mpt = Conversation(
|
298 |
+
system="""<|im_start|>system
|
299 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
300 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
301 |
+
version="mpt",
|
302 |
+
messages=(),
|
303 |
+
offset=0,
|
304 |
+
sep_style=SeparatorStyle.MPT,
|
305 |
+
sep="<|im_end|>",
|
306 |
+
)
|
307 |
+
|
308 |
+
|
309 |
+
|
310 |
+
### Used for llava-pretraining
|
311 |
+
conv_llava_plain = Conversation(
|
312 |
+
system="",
|
313 |
+
roles=("", ""),
|
314 |
+
messages=(
|
315 |
+
),
|
316 |
+
offset=0,
|
317 |
+
sep_style=SeparatorStyle.PLAIN,
|
318 |
+
sep="\n",
|
319 |
+
)
|
320 |
+
|
321 |
+
conv_llava_v0 = Conversation(
|
322 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
323 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
324 |
+
roles=("Human", "Assistant"),
|
325 |
+
messages=(
|
326 |
+
),
|
327 |
+
offset=0,
|
328 |
+
sep_style=SeparatorStyle.SINGLE,
|
329 |
+
sep="###",
|
330 |
+
)
|
331 |
+
|
332 |
+
conv_llava_v0_mmtag = Conversation(
|
333 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
334 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
335 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
336 |
+
roles=("Human", "Assistant"),
|
337 |
+
messages=(
|
338 |
+
),
|
339 |
+
offset=0,
|
340 |
+
sep_style=SeparatorStyle.SINGLE,
|
341 |
+
sep="###",
|
342 |
+
version="v0_mmtag",
|
343 |
+
)
|
344 |
+
|
345 |
+
conv_llava_v1 = Conversation(
|
346 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
347 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
348 |
+
roles=("USER", "ASSISTANT"),
|
349 |
+
version="v1",
|
350 |
+
messages=(),
|
351 |
+
offset=0,
|
352 |
+
sep_style=SeparatorStyle.TWO,
|
353 |
+
sep=" ",
|
354 |
+
sep2="</s>",
|
355 |
+
)
|
356 |
+
|
357 |
+
conv_llava_v1_mmtag = Conversation(
|
358 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
359 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
360 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
361 |
+
roles=("USER", "ASSISTANT"),
|
362 |
+
messages=(),
|
363 |
+
offset=0,
|
364 |
+
sep_style=SeparatorStyle.TWO,
|
365 |
+
sep=" ",
|
366 |
+
sep2="</s>",
|
367 |
+
version="v1_mmtag",
|
368 |
+
)
|
369 |
+
|
370 |
+
|
371 |
+
nvllm_8b_pretrain = Conversation(
|
372 |
+
system="",
|
373 |
+
roles=(),
|
374 |
+
version="nvllm_8b",
|
375 |
+
messages=(),
|
376 |
+
offset=0,
|
377 |
+
sep_style=SeparatorStyle.SINGLE,
|
378 |
+
sep="\n",
|
379 |
+
)
|
380 |
+
|
381 |
+
nvllm_8b_sft = Conversation(
|
382 |
+
system="System: This is a chat between a user and an artificial intelligence assistant. "
|
383 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
384 |
+
roles=("User", "Assistant"),
|
385 |
+
version="nvllm_8b",
|
386 |
+
messages=(),
|
387 |
+
offset=0,
|
388 |
+
sep_style=SeparatorStyle.TWO,
|
389 |
+
sep="\n\n",
|
390 |
+
sep2="\n\n\n",
|
391 |
+
real_sep2="\n\n"
|
392 |
+
)
|
393 |
+
|
394 |
+
chatqa_sft = Conversation(
|
395 |
+
system="System: This is a chat between a user and an artificial intelligence assistant. "
|
396 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
397 |
+
roles=("User", "Assistant"),
|
398 |
+
version="chatqa",
|
399 |
+
messages=(),
|
400 |
+
offset=0,
|
401 |
+
sep_style=SeparatorStyle.TWO,
|
402 |
+
sep="\n\n",
|
403 |
+
sep2="\n\n",
|
404 |
+
real_sep2="\n\n"
|
405 |
+
)
|
406 |
+
|
407 |
+
nvllm_8b_sft_noinstruction = Conversation(
|
408 |
+
system="",
|
409 |
+
roles=("User", "Assistant"),
|
410 |
+
version="nvllm_8b",
|
411 |
+
messages=(),
|
412 |
+
offset=0,
|
413 |
+
sep_style=SeparatorStyle.TWO,
|
414 |
+
sep="\n\n",
|
415 |
+
sep2="\n\n\n",
|
416 |
+
real_sep2="\n\n"
|
417 |
+
)
|
418 |
+
|
419 |
+
# conv_yi = Conversation(
|
420 |
+
# system="""<|im_start|>system
|
421 |
+
# A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
422 |
+
# roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
423 |
+
# version="mpt",
|
424 |
+
# messages=(),
|
425 |
+
# offset=0,
|
426 |
+
# sep_style=SeparatorStyle.MPT,
|
427 |
+
# sep="<|im_end|>",
|
428 |
+
# )
|
429 |
+
|
430 |
+
conv_chatml = Conversation(
|
431 |
+
system="""<|im_start|>system
|
432 |
+
Answer the questions.""",
|
433 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
434 |
+
version="mpt",
|
435 |
+
messages=(),
|
436 |
+
offset=0,
|
437 |
+
sep_style=SeparatorStyle.MPT,
|
438 |
+
sep="<|im_end|>",
|
439 |
+
)
|
440 |
+
|
441 |
+
llama3_instruct = Conversation(
|
442 |
+
system="<|start_header_id|>system<|end_header_id|>\n\nAnswer the questions.",
|
443 |
+
roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
|
444 |
+
version="mpt",
|
445 |
+
messages=(),
|
446 |
+
offset=0,
|
447 |
+
sep_style=SeparatorStyle.MPT,
|
448 |
+
sep="<|eot_id|>",
|
449 |
+
)
|
450 |
+
|
451 |
+
llama3_1_instruct = Conversation(
|
452 |
+
system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nAnswer the questions.",
|
453 |
+
roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
|
454 |
+
version="mpt",
|
455 |
+
messages=(),
|
456 |
+
offset=0,
|
457 |
+
sep_style=SeparatorStyle.MPT,
|
458 |
+
sep="<|eot_id|>",
|
459 |
+
)
|
460 |
+
|
461 |
+
# default_conversation = conv_vicuna_v0
|
462 |
+
default_conversation = nvllm_8b_sft
|
463 |
+
|
464 |
+
# original_llava_pretraining = conv_llava_plain
|
465 |
+
# original_llava_sft = conv_vicuna_v1
|
466 |
+
|
467 |
+
conv_templates = {
|
468 |
+
"default": conv_vicuna_v0,
|
469 |
+
"v0": conv_vicuna_v0,
|
470 |
+
"v1": conv_vicuna_v1,
|
471 |
+
"vicuna_v1": conv_vicuna_v1,
|
472 |
+
"llama_2": conv_llama_2,
|
473 |
+
|
474 |
+
"plain": conv_llava_plain,
|
475 |
+
"v0_plain": conv_llava_plain,
|
476 |
+
"llava_v0": conv_llava_v0,
|
477 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
478 |
+
"llava_v1": conv_llava_v1,
|
479 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
480 |
+
"llava_llama_2": conv_llava_llama_2,
|
481 |
+
|
482 |
+
"mpt": conv_mpt,
|
483 |
+
}
|
484 |
+
|
485 |
+
|
486 |
+
if __name__ == "__main__":
|
487 |
+
|
488 |
+
print(default_conversation)
|
489 |
+
print(default_conversation.roles[0])
|
490 |
+
|
491 |
+
|
492 |
+
# print(default_conversation.get_prompt())
|
eval/eval_dataset.py
ADDED
@@ -0,0 +1,850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import yaml
|
6 |
+
import spacy
|
7 |
+
import ast
|
8 |
+
from PIL import Image
|
9 |
+
from glob import glob
|
10 |
+
from tqdm import tqdm
|
11 |
+
from collections import defaultdict
|
12 |
+
import pandas as pd
|
13 |
+
from io import BytesIO
|
14 |
+
import base64
|
15 |
+
from anls import anls_score
|
16 |
+
import torch
|
17 |
+
from torch.utils.data import Dataset, DataLoader, DistributedSampler
|
18 |
+
import torchvision.transforms as T
|
19 |
+
from eval import conversation as conversation_lib
|
20 |
+
from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \
|
21 |
+
process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro
|
22 |
+
from eval.mmmu_utils import evaluate as evaluate_mmmu
|
23 |
+
from torchvision.transforms.functional import InterpolationMode
|
24 |
+
from datasets import load_dataset, concatenate_datasets
|
25 |
+
|
26 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
27 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
28 |
+
|
29 |
+
|
30 |
+
def build_transform(input_size):
|
31 |
+
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
32 |
+
transform = T.Compose([
|
33 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
34 |
+
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
35 |
+
T.ToTensor(),
|
36 |
+
T.Normalize(mean=MEAN, std=STD)
|
37 |
+
])
|
38 |
+
return transform
|
39 |
+
|
40 |
+
|
41 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
42 |
+
best_ratio_diff = float('inf')
|
43 |
+
best_ratio = (1, 1)
|
44 |
+
area = width * height
|
45 |
+
for ratio in target_ratios:
|
46 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
47 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
48 |
+
if ratio_diff < best_ratio_diff:
|
49 |
+
best_ratio_diff = ratio_diff
|
50 |
+
best_ratio = ratio
|
51 |
+
elif ratio_diff == best_ratio_diff:
|
52 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
53 |
+
best_ratio = ratio
|
54 |
+
return best_ratio
|
55 |
+
|
56 |
+
|
57 |
+
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
|
58 |
+
orig_width, orig_height = image.size
|
59 |
+
aspect_ratio = orig_width / orig_height
|
60 |
+
|
61 |
+
# calculate the existing image aspect ratio
|
62 |
+
target_ratios = set(
|
63 |
+
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
|
64 |
+
i * j <= max_num and i * j >= min_num)
|
65 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
66 |
+
|
67 |
+
# find the closest aspect ratio to the target
|
68 |
+
target_aspect_ratio = find_closest_aspect_ratio(
|
69 |
+
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
70 |
+
|
71 |
+
# calculate the target width and height
|
72 |
+
target_width = image_size * target_aspect_ratio[0]
|
73 |
+
target_height = image_size * target_aspect_ratio[1]
|
74 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
75 |
+
|
76 |
+
# resize the image
|
77 |
+
resized_img = image.resize((target_width, target_height))
|
78 |
+
processed_images = []
|
79 |
+
for i in range(blocks):
|
80 |
+
box = (
|
81 |
+
(i % (target_width // image_size)) * image_size,
|
82 |
+
(i // (target_width // image_size)) * image_size,
|
83 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
84 |
+
((i // (target_width // image_size)) + 1) * image_size
|
85 |
+
)
|
86 |
+
# split the image
|
87 |
+
split_img = resized_img.crop(box)
|
88 |
+
processed_images.append(split_img)
|
89 |
+
assert len(processed_images) == blocks
|
90 |
+
if use_thumbnail and len(processed_images) != 1:
|
91 |
+
thumbnail_img = image.resize((image_size, image_size))
|
92 |
+
processed_images.append(thumbnail_img)
|
93 |
+
return processed_images
|
94 |
+
|
95 |
+
|
96 |
+
def load_image(image, input_size=448, max_num=6, decoded=False):
|
97 |
+
if not decoded:
|
98 |
+
image = Image.open(image).convert('RGB')
|
99 |
+
transform = build_transform(input_size=input_size)
|
100 |
+
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
|
101 |
+
pixel_values = [transform(image) for image in images]
|
102 |
+
pixel_values = torch.stack(pixel_values)
|
103 |
+
return pixel_values
|
104 |
+
|
105 |
+
|
106 |
+
def levenshtein_distance(s1, s2):
|
107 |
+
if len(s1) > len(s2):
|
108 |
+
s1, s2 = s2, s1
|
109 |
+
|
110 |
+
distances = range(len(s1) + 1)
|
111 |
+
for i2, c2 in enumerate(s2):
|
112 |
+
distances_ = [i2 + 1]
|
113 |
+
for i1, c1 in enumerate(s1):
|
114 |
+
if c1 == c2:
|
115 |
+
distances_.append(distances[i1])
|
116 |
+
else:
|
117 |
+
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
|
118 |
+
distances = distances_
|
119 |
+
return distances[-1]
|
120 |
+
|
121 |
+
|
122 |
+
def get_anls_score(pred, gold_labels, threshold, llava_eval=False):
|
123 |
+
values = []
|
124 |
+
for answer in gold_labels:
|
125 |
+
# preprocess both the answers - gt and prediction
|
126 |
+
gt_answer = ' '.join(answer.strip().lower().split())
|
127 |
+
det_answer = ' '.join(pred.strip().lower().split())
|
128 |
+
|
129 |
+
dist = levenshtein_distance(gt_answer, det_answer)
|
130 |
+
length = max(len(answer.upper()), len(pred.upper()))
|
131 |
+
values.append(0.0 if length == 0 else float(dist) / float(length))
|
132 |
+
|
133 |
+
question_result = 1 - min(values)
|
134 |
+
|
135 |
+
if llava_eval:
|
136 |
+
question_result = 1.0 if question_result >= threshold else 0.0
|
137 |
+
else:
|
138 |
+
if (question_result < threshold):
|
139 |
+
question_result = 0
|
140 |
+
|
141 |
+
return question_result
|
142 |
+
|
143 |
+
|
144 |
+
def isNumber(n: str):
|
145 |
+
try:
|
146 |
+
float(n)
|
147 |
+
return True
|
148 |
+
except ValueError:
|
149 |
+
return False
|
150 |
+
|
151 |
+
|
152 |
+
class COCOEvalDataset(Dataset):
|
153 |
+
def __init__(self, args, img_dir, subset=None):
|
154 |
+
self.args = args
|
155 |
+
self.img_files = sorted(glob(os.path.join(img_dir, "*")))
|
156 |
+
|
157 |
+
if subset:
|
158 |
+
self.img_files = self.img_files[:subset]
|
159 |
+
|
160 |
+
self.image_ids = [int(img_file.split("_")[-1].split(".")[0]) for img_file in self.img_files]
|
161 |
+
|
162 |
+
def __len__(self):
|
163 |
+
return len(self.img_files)
|
164 |
+
|
165 |
+
def __getitem__(self, idx):
|
166 |
+
img_path = self.img_files[idx]
|
167 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
168 |
+
|
169 |
+
return self.image_ids[idx], img
|
170 |
+
|
171 |
+
|
172 |
+
class Flickr30KEvalDataset(Dataset):
|
173 |
+
def __init__(self, args, img_dir, subset=None):
|
174 |
+
self.args = args
|
175 |
+
self.img_dir = img_dir
|
176 |
+
self.test_samples = json.load(open(os.path.join(img_dir, "flickr30k_test.json"), encoding='utf-8'))
|
177 |
+
|
178 |
+
if subset:
|
179 |
+
self.test_samples = self.test_samples[:subset]
|
180 |
+
|
181 |
+
def __len__(self):
|
182 |
+
return len(self.test_samples)
|
183 |
+
|
184 |
+
def __getitem__(self, idx):
|
185 |
+
img_path = os.path.join(self.img_dir, self.test_samples[idx]["image"])
|
186 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
187 |
+
|
188 |
+
image_id = int(self.test_samples[idx]["image"].split("/")[-1].replace(".jpg", ""))
|
189 |
+
|
190 |
+
return image_id, img
|
191 |
+
|
192 |
+
|
193 |
+
class VQAv2EvalDataset(Dataset):
|
194 |
+
def __init__(self, args, img_dir, gt_path, subset=None):
|
195 |
+
self.args = args
|
196 |
+
self.img_dir = img_dir
|
197 |
+
self.gt = json.load(open(gt_path, encoding='utf-8'))
|
198 |
+
|
199 |
+
if subset:
|
200 |
+
self.gt = self.gt[:subset]
|
201 |
+
|
202 |
+
def __len__(self):
|
203 |
+
return len(self.gt)
|
204 |
+
|
205 |
+
def __getitem__(self, idx):
|
206 |
+
img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
|
207 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
208 |
+
|
209 |
+
question_id = self.gt[idx]["question_id"]
|
210 |
+
question = self.gt[idx]["question"]
|
211 |
+
answer = self.gt[idx]["answer"]
|
212 |
+
|
213 |
+
return img, question_id, question, answer
|
214 |
+
|
215 |
+
|
216 |
+
class TextVQAEvalDataset(Dataset):
|
217 |
+
def __init__(self, args, img_dir, gt_path, subset=None):
|
218 |
+
self.args = args
|
219 |
+
self.img_dir = img_dir
|
220 |
+
self.gt = json.load(open(gt_path, encoding='utf-8'))['data']
|
221 |
+
|
222 |
+
if subset:
|
223 |
+
self.gt = self.gt[:subset]
|
224 |
+
|
225 |
+
def __len__(self):
|
226 |
+
return len(self.gt)
|
227 |
+
|
228 |
+
def __getitem__(self, idx):
|
229 |
+
img_path = os.path.join(self.img_dir, self.gt[idx]["image_id"] + '.jpg')
|
230 |
+
if not os.path.exists(img_path):
|
231 |
+
img_path = img_path.replace('.jpg', '.png')
|
232 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
233 |
+
|
234 |
+
question_id = self.gt[idx]["question_id"]
|
235 |
+
question = self.gt[idx]["question"]
|
236 |
+
answer = self.gt[idx]["answers"]
|
237 |
+
|
238 |
+
return img, question_id, question, answer
|
239 |
+
|
240 |
+
|
241 |
+
class GQAEvalDataset(Dataset):
|
242 |
+
def __init__(self, args, img_dir, gt_path, subset=None):
|
243 |
+
self.args = args
|
244 |
+
self.img_dir = img_dir
|
245 |
+
self.gt = json.load(open(gt_path, encoding='utf-8'))
|
246 |
+
self.gt = [{
|
247 |
+
"question_id": int(k),
|
248 |
+
"image": v['imageId'] + ".jpg",
|
249 |
+
"question": v['question'],
|
250 |
+
"answer": v['answer']
|
251 |
+
} for k, v in self.gt.items()]
|
252 |
+
|
253 |
+
if subset:
|
254 |
+
self.gt = self.gt[:subset]
|
255 |
+
|
256 |
+
def __len__(self):
|
257 |
+
return len(self.gt)
|
258 |
+
|
259 |
+
def __getitem__(self, idx):
|
260 |
+
img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
|
261 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
262 |
+
|
263 |
+
question_id = self.gt[idx]["question_id"]
|
264 |
+
question = self.gt[idx]["question"]
|
265 |
+
answer = self.gt[idx]["answer"]
|
266 |
+
|
267 |
+
return img, question_id, question, [answer]
|
268 |
+
|
269 |
+
|
270 |
+
class ChartQAEvalDataset(Dataset):
|
271 |
+
def __init__(self, args, img_dir, gt_path, subset=None):
|
272 |
+
self.args = args
|
273 |
+
self.img_dir = img_dir
|
274 |
+
self.gt = json.load(open(gt_path, encoding='utf-8'))
|
275 |
+
for i in range(len(self.gt)):
|
276 |
+
self.gt[i]['question_id'] = i
|
277 |
+
|
278 |
+
if subset:
|
279 |
+
self.gt = self.gt[:subset]
|
280 |
+
|
281 |
+
def __len__(self):
|
282 |
+
return len(self.gt)
|
283 |
+
|
284 |
+
def __getitem__(self, idx):
|
285 |
+
img_path = os.path.join(self.img_dir, self.gt[idx]["imgname"])
|
286 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
287 |
+
|
288 |
+
question_id = self.gt[idx]["question_id"]
|
289 |
+
question = self.gt[idx]["query"]
|
290 |
+
answer = self.gt[idx]["label"]
|
291 |
+
|
292 |
+
return img, question_id, question, [answer]
|
293 |
+
|
294 |
+
|
295 |
+
class OKVQAEvalDataset(Dataset):
|
296 |
+
def __init__(self, args, img_dir, gt_path, question_path, subset=None):
|
297 |
+
self.args = args
|
298 |
+
self.img_dir = img_dir
|
299 |
+
self.gt = json.load(open(gt_path, encoding='utf-8'))['annotations']
|
300 |
+
self.questions = json.load(open(question_path, 'r'))['questions']
|
301 |
+
|
302 |
+
if subset:
|
303 |
+
self.gt = self.gt[:subset]
|
304 |
+
|
305 |
+
qid2q = {q['question_id']: q['question'] for q in self.questions}
|
306 |
+
|
307 |
+
for ann in self.gt:
|
308 |
+
ann['answers'] = [ans['answer'] for ans in ann['answers']]
|
309 |
+
ann['question'] = qid2q[ann['question_id']]
|
310 |
+
|
311 |
+
def __len__(self):
|
312 |
+
return len(self.gt)
|
313 |
+
|
314 |
+
def __getitem__(self, idx):
|
315 |
+
img_id = str(self.gt[idx]["image_id"])
|
316 |
+
img_id = '0' * (12 - len(img_id)) + img_id
|
317 |
+
img_file_name = f"COCO_val2014_{img_id}.jpg"
|
318 |
+
img_path = os.path.join(self.img_dir, img_file_name)
|
319 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
320 |
+
|
321 |
+
question_id = self.gt[idx]["question_id"]
|
322 |
+
question = self.gt[idx]["question"]
|
323 |
+
answer = self.gt[idx]["answers"]
|
324 |
+
|
325 |
+
return img, question_id, question, answer
|
326 |
+
|
327 |
+
|
328 |
+
class DocVQAEvalDataset(Dataset):
|
329 |
+
def __init__(self, args, img_dir, gt_path, split='val', subset=None):
|
330 |
+
self.args = args
|
331 |
+
self.img_dir = img_dir
|
332 |
+
self.gt = json.load(open(gt_path, encoding='utf-8'))['data']
|
333 |
+
|
334 |
+
if subset:
|
335 |
+
self.gt = self.gt[:subset]
|
336 |
+
|
337 |
+
self.split = split
|
338 |
+
|
339 |
+
def __len__(self):
|
340 |
+
return len(self.gt)
|
341 |
+
|
342 |
+
def __getitem__(self, idx):
|
343 |
+
img_path = os.path.join(self.img_dir, self.gt[idx]['image'].split('/')[-1])
|
344 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
345 |
+
|
346 |
+
question_id = self.gt[idx]["questionId"]
|
347 |
+
question = self.gt[idx]["question"]
|
348 |
+
|
349 |
+
if self.split == 'val':
|
350 |
+
answer = self.gt[idx]["answers"]
|
351 |
+
else:
|
352 |
+
answer = ['']
|
353 |
+
|
354 |
+
return img, question_id, question, answer
|
355 |
+
|
356 |
+
|
357 |
+
class OCRBenchEvalDataset(Dataset):
|
358 |
+
def __init__(self, args, img_dir, gt_path, subset=None):
|
359 |
+
self.args = args
|
360 |
+
self.img_dir = img_dir
|
361 |
+
self.gt = json.load(open(gt_path, encoding='utf-8'))
|
362 |
+
|
363 |
+
if subset:
|
364 |
+
self.gt = self.gt[:subset]
|
365 |
+
|
366 |
+
def __len__(self):
|
367 |
+
return len(self.gt)
|
368 |
+
|
369 |
+
def __getitem__(self, idx):
|
370 |
+
img_path = os.path.join(self.img_dir, self.gt[idx]['image_path'])
|
371 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
372 |
+
|
373 |
+
dataset_name = self.gt[idx]["dataset_name"]
|
374 |
+
question_id = f"{idx}"
|
375 |
+
question = self.gt[idx]["question"]
|
376 |
+
answer = self.gt[idx]["answers"]
|
377 |
+
data_type = self.gt[idx]["type"]
|
378 |
+
|
379 |
+
return img, question_id, question, answer, dataset_name, data_type
|
380 |
+
|
381 |
+
|
382 |
+
class AI2DiagramEvalDataset(Dataset):
|
383 |
+
def __init__(self, args, img_dir, gt_path, subset=None):
|
384 |
+
self.args = args
|
385 |
+
self.img_dir = img_dir
|
386 |
+
|
387 |
+
with open(gt_path, 'r') as json_file:
|
388 |
+
json_list = list(json_file)
|
389 |
+
self.gt = [json.loads(json_str) for json_str in json_list]
|
390 |
+
|
391 |
+
if subset:
|
392 |
+
self.gt = self.gt[:subset]
|
393 |
+
|
394 |
+
def __len__(self):
|
395 |
+
return len(self.gt)
|
396 |
+
|
397 |
+
def __getitem__(self, idx):
|
398 |
+
img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
|
399 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
400 |
+
|
401 |
+
question_id = self.gt[idx]["question_id"]
|
402 |
+
question = self.gt[idx]["question"]
|
403 |
+
answer = self.gt[idx]["answer"]
|
404 |
+
|
405 |
+
return img, question_id, question, answer
|
406 |
+
|
407 |
+
|
408 |
+
class AI2DiagramNoMaskEvalDataset(Dataset):
|
409 |
+
def __init__(self, args, img_dir, gt_path, subset=None):
|
410 |
+
self.args = args
|
411 |
+
self.img_dir = img_dir
|
412 |
+
|
413 |
+
with open(gt_path, 'r') as json_file:
|
414 |
+
json_list = list(json_file)
|
415 |
+
self.gt = [json.loads(json_str) for json_str in json_list]
|
416 |
+
|
417 |
+
if subset:
|
418 |
+
self.gt = self.gt[:subset]
|
419 |
+
|
420 |
+
def __len__(self):
|
421 |
+
return len(self.gt)
|
422 |
+
|
423 |
+
def __getitem__(self, idx):
|
424 |
+
img_file_name = self.gt[idx]['image'].replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES")
|
425 |
+
img_path = os.path.join(self.img_dir, img_file_name)
|
426 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
427 |
+
|
428 |
+
question_id = self.gt[idx]["question_id"]
|
429 |
+
question = self.gt[idx]["question"]
|
430 |
+
answer = self.gt[idx]["answer"]
|
431 |
+
|
432 |
+
return img, question_id, question, answer
|
433 |
+
|
434 |
+
|
435 |
+
class RealworldQAEvalDataset(Dataset):
|
436 |
+
def __init__(self, args, img_dir, gt_path, subset=None):
|
437 |
+
self.args = args
|
438 |
+
self.img_dir = img_dir
|
439 |
+
self.gt = json.load(open(gt_path, encoding='utf-8'))
|
440 |
+
|
441 |
+
if subset:
|
442 |
+
self.gt = self.gt[:subset]
|
443 |
+
|
444 |
+
def __len__(self):
|
445 |
+
return len(self.gt)
|
446 |
+
|
447 |
+
def __getitem__(self, idx):
|
448 |
+
img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
|
449 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
450 |
+
|
451 |
+
question_id = int(self.gt[idx]['image'].replace(".webp", ""))
|
452 |
+
question = self.gt[idx]["question"]
|
453 |
+
|
454 |
+
if self.gt[idx]['question_type'] == "multi-choice":
|
455 |
+
choices = self.gt[idx]["choices"]
|
456 |
+
start_chr = 'A'
|
457 |
+
choices_str = ''
|
458 |
+
index2ans = {}
|
459 |
+
all_choices = []
|
460 |
+
for choice in choices:
|
461 |
+
all_choices.append(start_chr)
|
462 |
+
index2ans[start_chr] = choice
|
463 |
+
choices_str += f"{start_chr}. {choice}\n"
|
464 |
+
start_chr = chr(ord(start_chr) + 1)
|
465 |
+
|
466 |
+
question = question + '\n' + choices_str
|
467 |
+
question = question + "Answer with the option's letter from the given choices directly."
|
468 |
+
answer = chr(ord('A') + self.gt[idx]['correct_choice_index'])
|
469 |
+
else:
|
470 |
+
question = question + "\nAnswer the question using a single word or phrase."
|
471 |
+
answer = self.gt[idx]['answer']
|
472 |
+
|
473 |
+
return img, question_id, question, [answer]
|
474 |
+
|
475 |
+
|
476 |
+
class MathVistaEvalDataset(Dataset):
|
477 |
+
def __init__(self, args, task_cfg, gt_path=None):
|
478 |
+
self.args = args
|
479 |
+
self.task_cfg = task_cfg
|
480 |
+
self.dataset = load_dataset("AI4Math/MathVista")['testmini']
|
481 |
+
|
482 |
+
def __len__(self):
|
483 |
+
return len(self.dataset)
|
484 |
+
|
485 |
+
def __getitem__(self, idx):
|
486 |
+
img = self.dataset[idx]['decoded_image']
|
487 |
+
img = load_image(img.convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
|
488 |
+
|
489 |
+
question_id = self.dataset[idx]["pid"]
|
490 |
+
question = self.dataset[idx]["question"]
|
491 |
+
question_type = self.dataset[idx]["question_type"] # free_form or multi_choice
|
492 |
+
query = self.dataset[idx]["query"]
|
493 |
+
choices = self.dataset[idx]["choices"]
|
494 |
+
answer = self.dataset[idx]["answer"]
|
495 |
+
|
496 |
+
if question_type == 'multi_choice':
|
497 |
+
start_chr = 'A'
|
498 |
+
choices_str = ''
|
499 |
+
index2ans = {}
|
500 |
+
all_choices = []
|
501 |
+
for choice in choices:
|
502 |
+
all_choices.append(start_chr)
|
503 |
+
index2ans[start_chr] = choice
|
504 |
+
choices_str += f"{start_chr}. {choice}\n"
|
505 |
+
start_chr = chr(ord(start_chr) + 1)
|
506 |
+
|
507 |
+
question = question + '\n' + choices_str
|
508 |
+
question = question + "Answer with the option's letter from the given choices directly."
|
509 |
+
answer = chr(ord('A') + choices.index(answer))
|
510 |
+
else:
|
511 |
+
question = query.replace("Hint: ", "")
|
512 |
+
index2ans = {}
|
513 |
+
all_choices = []
|
514 |
+
|
515 |
+
return img, question_id, question_type, question, answer, str(index2ans), str(all_choices)
|
516 |
+
|
517 |
+
|
518 |
+
def construct_prompt_for_fewshot(sample):
|
519 |
+
config = {
|
520 |
+
"task_instructions": "",
|
521 |
+
"multi_choice_example_format": "{}\n{}Answer with the option's letter from the given choices directly.",
|
522 |
+
"short_ans_example_format": "{}\nAnswer the question using a single word or phrase."
|
523 |
+
}
|
524 |
+
|
525 |
+
question = sample['question'].strip()
|
526 |
+
|
527 |
+
|
528 |
+
options = eval(sample['options'])
|
529 |
+
example = ""
|
530 |
+
if sample['question_type'] == 'multiple-choice':
|
531 |
+
start_chr = 'A'
|
532 |
+
prediction_range = []
|
533 |
+
index2ans = {}
|
534 |
+
for option in options:
|
535 |
+
prediction_range.append(start_chr)
|
536 |
+
example += f"({start_chr}) {option}\n"
|
537 |
+
index2ans[start_chr] = option
|
538 |
+
start_chr = chr(ord(start_chr) + 1)
|
539 |
+
empty_prompt_sample_structure = config['multi_choice_example_format']
|
540 |
+
empty_prompt = empty_prompt_sample_structure.format(question, example)
|
541 |
+
res_dict = {'type': 'multichoice'}
|
542 |
+
res_dict['index2ans'] = index2ans
|
543 |
+
res_dict['correct_choice'] = sample['answer']
|
544 |
+
res_dict['all_choices'] = prediction_range
|
545 |
+
res_dict['empty_prompt'] = empty_prompt
|
546 |
+
if config['task_instructions']:
|
547 |
+
res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
|
548 |
+
else:
|
549 |
+
res_dict['final_input_prompt'] = empty_prompt
|
550 |
+
|
551 |
+
res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
|
552 |
+
else:
|
553 |
+
empty_prompt_sample_structure = config['short_ans_example_format']
|
554 |
+
empty_prompt = empty_prompt_sample_structure.format(question)
|
555 |
+
res_dict = {'type': 'open'}
|
556 |
+
res_dict['empty_prompt'] = empty_prompt
|
557 |
+
if config['task_instructions']:
|
558 |
+
res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
|
559 |
+
else:
|
560 |
+
res_dict['final_input_prompt'] = empty_prompt
|
561 |
+
res_dict['gt_content'] = sample['answer']
|
562 |
+
|
563 |
+
res_dict.update(sample)
|
564 |
+
return res_dict
|
565 |
+
|
566 |
+
|
567 |
+
def process_image_tag(q):
|
568 |
+
q = q.strip()
|
569 |
+
|
570 |
+
# heuristic way of removing <image 1>
|
571 |
+
if q == '<image 1>':
|
572 |
+
q = 'Answer the question in the image.'
|
573 |
+
elif ':<image 1>' in q:
|
574 |
+
q = q.replace(':<image 1>', ' in the image. ')
|
575 |
+
q = q.strip()
|
576 |
+
elif ': <image 1>' in q:
|
577 |
+
q = q.replace(': <image 1>', ' in the image. ')
|
578 |
+
q = q.strip()
|
579 |
+
elif '.<image 1>' in q or '. <image 1>' in q:
|
580 |
+
q_list = q.split('<image 1>')
|
581 |
+
q_list = [part.strip() for part in q_list if part.strip() != '']
|
582 |
+
q = ' '.join(q_list)
|
583 |
+
elif q.startswith('<image 1> '):
|
584 |
+
if q[10].isupper():
|
585 |
+
q = q.replace('<image 1>', '')
|
586 |
+
else:
|
587 |
+
q = q.replace('<image 1>', 'The image')
|
588 |
+
q = q.strip()
|
589 |
+
elif q.startswith('<image 1>'):
|
590 |
+
q = q.replace('<image 1>', '')
|
591 |
+
elif q.endswith('<image 1>?'):
|
592 |
+
q = q.replace('<image 1>', 'the image')
|
593 |
+
elif q.endswith('?<image 1>') or q.endswith('? <image 1>') or q.endswith('\n<image 1>'):
|
594 |
+
q = q.replace('<image 1>', '')
|
595 |
+
q = q.strip()
|
596 |
+
elif ' <image 1> ' in q:
|
597 |
+
q = q.replace('<image 1>', 'the image')
|
598 |
+
elif ' <image 1>' in q:
|
599 |
+
q = q.replace('<image 1>', 'the image')
|
600 |
+
elif '()<image 1>' in q:
|
601 |
+
q = q.replace('()<image 1>', '')
|
602 |
+
elif '(<image 1>)' in q:
|
603 |
+
q = q.replace('(<image 1>)', '')
|
604 |
+
elif '<image 1>.' in q:
|
605 |
+
q = q.replace("<image 1>.", ". ")
|
606 |
+
else:
|
607 |
+
q = q.replace("<image 1>", ". ")
|
608 |
+
q = q.strip()
|
609 |
+
|
610 |
+
# remove <image 2> to <image 8>
|
611 |
+
for i in range(2, 8):
|
612 |
+
q = q.replace(f"<image {i}>", "")
|
613 |
+
|
614 |
+
return q
|
615 |
+
|
616 |
+
|
617 |
+
class MMMUProEvalDataset(Dataset):
|
618 |
+
def __init__(self, args, task_cfg, subset=None):
|
619 |
+
self.args = args
|
620 |
+
self.task_cfg = task_cfg
|
621 |
+
sub_dataset_list = []
|
622 |
+
# load_dataset will throw error if split is 'dev'
|
623 |
+
# 'dev' is part of the 'validation' and we need to manually split them
|
624 |
+
|
625 |
+
MMMU_path = "MMMU/MMMU_Pro"
|
626 |
+
|
627 |
+
_split = "test"
|
628 |
+
|
629 |
+
self.dataset = load_dataset(MMMU_path, "standard", split=_split)
|
630 |
+
if subset:
|
631 |
+
self.dataset = self.dataset[:subset]
|
632 |
+
|
633 |
+
def __len__(self):
|
634 |
+
return len(self.dataset)
|
635 |
+
|
636 |
+
def __getitem__(self, idx):
|
637 |
+
# ===== single-image =====
|
638 |
+
sample = self.dataset[idx]
|
639 |
+
sample = process_single_sample_pro(sample)
|
640 |
+
sample = construct_prompt_pro(sample, self.task_cfg)
|
641 |
+
img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
|
642 |
+
|
643 |
+
# img = img.reshape(-1, 3, self.args.img_h, self.args.img_w)
|
644 |
+
|
645 |
+
question_id = sample['id']
|
646 |
+
question = sample['final_input_prompt']
|
647 |
+
answer = sample['answer']
|
648 |
+
|
649 |
+
question = process_image_tag(question)
|
650 |
+
question = self.task_cfg['default_image_token'] + '\n' + question
|
651 |
+
|
652 |
+
if sample['question_type'] == 'multiple-choice':
|
653 |
+
index2ans = sample['index2ans']
|
654 |
+
all_choices = sample['all_choices']
|
655 |
+
else:
|
656 |
+
index2ans = {}
|
657 |
+
all_choices = []
|
658 |
+
|
659 |
+
return img, question_id, sample['subfield'], sample['question_type'], question, answer, str(index2ans), str \
|
660 |
+
(all_choices)
|
661 |
+
|
662 |
+
|
663 |
+
class MMMUEvalDataset(Dataset):
|
664 |
+
def __init__(self, args, task_cfg, subset=None, start_idx=None):
|
665 |
+
self.args = args
|
666 |
+
self.task_cfg = task_cfg
|
667 |
+
sub_dataset_list = []
|
668 |
+
# load_dataset will throw error if split is 'dev'
|
669 |
+
# 'dev' is part of the 'validation' and we need to manually split them
|
670 |
+
|
671 |
+
MMMU_path = "MMMU/MMMU"
|
672 |
+
|
673 |
+
_split = "test" if task_cfg["split"] == "test" else "validation"
|
674 |
+
for subject in CAT_SHORT2LONG.values():
|
675 |
+
sub_dataset = load_dataset(
|
676 |
+
MMMU_path, subject,
|
677 |
+
split=_split,
|
678 |
+
)
|
679 |
+
sub_dataset_list.append(sub_dataset)
|
680 |
+
|
681 |
+
dataset = concatenate_datasets(sub_dataset_list)
|
682 |
+
|
683 |
+
if task_cfg["split"] != "test":
|
684 |
+
dataset = [s for s in dataset if s['id'].startswith(task_cfg["split"])]
|
685 |
+
|
686 |
+
# dataset = [s for s in dataset if s['image_2'] is not None][1:]
|
687 |
+
|
688 |
+
self.dataset = dataset
|
689 |
+
|
690 |
+
if subset:
|
691 |
+
self.dataset = [dataset[i] for i in range(start_idx, min(start_idx + subset, len(dataset)))]
|
692 |
+
print(f"Evaluating a subset of dataset: {len(self.dataset)} from {start_idx} to {start_idx + subset}")
|
693 |
+
|
694 |
+
def __len__(self):
|
695 |
+
return len(self.dataset)
|
696 |
+
|
697 |
+
def __getitem__(self, idx):
|
698 |
+
# ===== single-image =====
|
699 |
+
sample = self.dataset[idx]
|
700 |
+
sample = process_single_sample(sample)
|
701 |
+
sample = construct_prompt(sample, self.task_cfg)
|
702 |
+
|
703 |
+
img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
|
704 |
+
|
705 |
+
question_id = sample['id']
|
706 |
+
question = sample['final_input_prompt']
|
707 |
+
answer = sample['answer']
|
708 |
+
|
709 |
+
question = process_image_tag(question)
|
710 |
+
question = self.task_cfg['default_image_token'] + '\n' + question
|
711 |
+
|
712 |
+
|
713 |
+
if sample['question_type'] == 'multiple-choice':
|
714 |
+
index2ans = sample['index2ans']
|
715 |
+
all_choices = sample['all_choices']
|
716 |
+
else:
|
717 |
+
index2ans = {}
|
718 |
+
all_choices = []
|
719 |
+
|
720 |
+
return img, question_id, sample['subfield'], sample['question_type'], question, answer, str(index2ans), str \
|
721 |
+
(all_choices)
|
722 |
+
|
723 |
+
|
724 |
+
|
725 |
+
class VizWizEvalDataset(Dataset):
|
726 |
+
def __init__(self, args, img_dir, question_path, subset=None):
|
727 |
+
self.args = args
|
728 |
+
self.img_dir = img_dir
|
729 |
+
self.questions = json.load(open(question_path, encoding='utf-8'))
|
730 |
+
|
731 |
+
def __len__(self):
|
732 |
+
return len(self.questions)
|
733 |
+
|
734 |
+
def __getitem__(self, idx):
|
735 |
+
img_path = os.path.join(self.img_dir, self.questions[idx]["image"])
|
736 |
+
img = load_image(img_path, max_num=6).to(torch.bfloat16)
|
737 |
+
question = self.questions[idx]["question"]
|
738 |
+
question_id = self.questions[idx]["image"]
|
739 |
+
|
740 |
+
return img, question_id, question
|
741 |
+
|
742 |
+
|
743 |
+
class MMBenchEvalDataset(Dataset):
|
744 |
+
def __init__(self, args, gt_path, subset=None):
|
745 |
+
self.args = args
|
746 |
+
df = pd.read_csv(gt_path, sep='\t')
|
747 |
+
self.dataset = []
|
748 |
+
for i, row in df.iterrows():
|
749 |
+
choices = []
|
750 |
+
for choice in ['A', 'B', 'C', 'D']:
|
751 |
+
if str(row[choice]) != 'nan':
|
752 |
+
choices.append(row[choice])
|
753 |
+
|
754 |
+
this_sample = {
|
755 |
+
'index': row['index'],
|
756 |
+
'question': row['question'],
|
757 |
+
'hint': row['hint'],
|
758 |
+
'category': row['category'],
|
759 |
+
'image': Image.open(BytesIO(base64.b64decode(row['image']))),
|
760 |
+
'choices': choices
|
761 |
+
}
|
762 |
+
|
763 |
+
# Only dev set gives the ground truth answer
|
764 |
+
if 'answer' in row.keys():
|
765 |
+
this_sample['answer'] = row['answer']
|
766 |
+
else:
|
767 |
+
this_sample['answer'] = ''
|
768 |
+
|
769 |
+
self.dataset.append(this_sample)
|
770 |
+
|
771 |
+
def __len__(self):
|
772 |
+
return len(self.dataset)
|
773 |
+
|
774 |
+
def __getitem__(self, idx):
|
775 |
+
img = load_image(self.dataset[idx]["image"].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
|
776 |
+
|
777 |
+
question = self.dataset[idx]["question"]
|
778 |
+
hint = self.dataset[idx]["hint"]
|
779 |
+
question_id = self.dataset[idx]["index"]
|
780 |
+
choices = self.dataset[idx]["choices"]
|
781 |
+
answer = self.dataset[idx]["answer"]
|
782 |
+
|
783 |
+
start_chr = 'A'
|
784 |
+
choices_str = ''
|
785 |
+
index2ans = {}
|
786 |
+
all_choices = []
|
787 |
+
for choice in choices:
|
788 |
+
all_choices.append(start_chr)
|
789 |
+
index2ans[start_chr] = choice
|
790 |
+
choices_str += f"{start_chr}. {choice}\n"
|
791 |
+
start_chr = chr(ord(start_chr) + 1)
|
792 |
+
|
793 |
+
question = question + '\n' + choices_str
|
794 |
+
|
795 |
+
return img, question_id, question, answer, str(index2ans), str(all_choices), self.dataset[idx]["question"]
|
796 |
+
|
797 |
+
|
798 |
+
def get_task_dataloader(task_name, task_cfg, args):
|
799 |
+
if "subset" in task_cfg.keys():
|
800 |
+
subset = task_cfg["subset"]
|
801 |
+
else:
|
802 |
+
subset = None
|
803 |
+
|
804 |
+
if task_name == "coco_caption":
|
805 |
+
dataset = COCOEvalDataset(args, task_cfg["image_dir"], subset)
|
806 |
+
elif task_name == "flickr30k_caption":
|
807 |
+
dataset = Flickr30KEvalDataset(args, task_cfg["image_dir"], subset)
|
808 |
+
elif task_name == "vqav2":
|
809 |
+
dataset = VQAv2EvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
|
810 |
+
elif task_name == "textvqa":
|
811 |
+
dataset = TextVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
|
812 |
+
elif task_name == "gqa":
|
813 |
+
dataset = GQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
|
814 |
+
elif task_name == "chartqa":
|
815 |
+
dataset = ChartQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
|
816 |
+
elif task_name == "okvqa":
|
817 |
+
dataset = OKVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], task_cfg["question_path"], subset)
|
818 |
+
elif task_name == "vizwiz":
|
819 |
+
dataset = VizWizEvalDataset(args, task_cfg["image_dir"], task_cfg["question_path"], subset)
|
820 |
+
elif task_name == "docvqa":
|
821 |
+
dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='val', subset=subset)
|
822 |
+
elif task_name == "docvqa_test":
|
823 |
+
dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='test', subset=subset)
|
824 |
+
elif task_name == "realworldqa":
|
825 |
+
dataset = RealworldQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
|
826 |
+
elif task_name == "mmmu":
|
827 |
+
dataset = MMMUEvalDataset(args, task_cfg, subset=args.subset, start_idx=args.start_idx)
|
828 |
+
elif task_name == "mmmu_pro":
|
829 |
+
dataset = MMMUProEvalDataset(args, task_cfg)
|
830 |
+
elif task_name == "mathvista":
|
831 |
+
dataset = MathVistaEvalDataset(args, task_cfg)
|
832 |
+
elif task_name == "mmbench":
|
833 |
+
dataset = MMBenchEvalDataset(args, task_cfg["gt_path"])
|
834 |
+
elif task_name == 'ocrbench':
|
835 |
+
dataset = OCRBenchEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
|
836 |
+
elif task_name == 'ai2diagram':
|
837 |
+
dataset = AI2DiagramEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
|
838 |
+
elif task_name == 'ai2diagram_nomask':
|
839 |
+
dataset = AI2DiagramNoMaskEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
|
840 |
+
else:
|
841 |
+
raise NotImplementedError(f"Task {task_name} is not supported yet.")
|
842 |
+
|
843 |
+
dataloader = DataLoader(
|
844 |
+
dataset,
|
845 |
+
batch_size=1,
|
846 |
+
shuffle=False,
|
847 |
+
pin_memory=True,
|
848 |
+
)
|
849 |
+
|
850 |
+
return dataloader
|
eval/full_eval.yaml
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets:
|
2 |
+
coco_caption:
|
3 |
+
image_dir: "path/to/image"
|
4 |
+
gt_path: "path/to/ground_truth"
|
5 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\nGive a brief description of this image in one sentence.<|im_end|><|im_start|>assistant\n"
|
6 |
+
beam_search: True
|
7 |
+
beam_size: 1
|
8 |
+
output_max_len: 30
|
9 |
+
top_k: 3
|
10 |
+
temperature: 1.0
|
11 |
+
|
12 |
+
flickr30k_caption:
|
13 |
+
image_dir: "path/to/image"
|
14 |
+
gt_path: "path/to/ground_truth"
|
15 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\nGive a brief description of this image in one sentence.<|im_end|><|im_start|>assistant\n"
|
16 |
+
beam_search: True
|
17 |
+
beam_size: 1
|
18 |
+
output_max_len: 30
|
19 |
+
top_k: 3
|
20 |
+
temperature: 1.0
|
21 |
+
|
22 |
+
vqav2:
|
23 |
+
image_dir: "path/to/image"
|
24 |
+
gt_path: "path/to/ground_truth"
|
25 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word or phrase.<|im_end|><|im_start|>assistant\n"
|
26 |
+
beam_search: True
|
27 |
+
beam_size: 1
|
28 |
+
top_k: 1
|
29 |
+
top_p: 0.0
|
30 |
+
output_max_len: 8
|
31 |
+
temperature: 1.0
|
32 |
+
|
33 |
+
mmmu:
|
34 |
+
split: "validation"
|
35 |
+
beam_search: True
|
36 |
+
beam_size: 1
|
37 |
+
top_k: 1
|
38 |
+
top_p: 0.0
|
39 |
+
output_max_len: 1024
|
40 |
+
temperature: 1.0
|
41 |
+
apply_lemmatizer: False
|
42 |
+
task_instructions: ""
|
43 |
+
multi_choice_example_format: "{}\n{}\nAnswer with the option's letter from the given choices directly."
|
44 |
+
short_ans_example_format: "{}\nAnswer the question using a single word or phrase."
|
45 |
+
use_chat_format: True
|
46 |
+
conv_format: "yi_nous_sft"
|
47 |
+
default_image_token: "<image>"
|
48 |
+
prompt_offset: 4
|
49 |
+
answer_dict: "path/to/answer_dict_val.json"
|
50 |
+
|
51 |
+
textvqa:
|
52 |
+
split: "val"
|
53 |
+
image_dir: "path/to/image"
|
54 |
+
gt_path: "path/to/ground_truth"
|
55 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
|
56 |
+
beam_search: True
|
57 |
+
beam_size: 1
|
58 |
+
top_k: 1
|
59 |
+
top_p: 0.0
|
60 |
+
output_max_len: 10
|
61 |
+
temperature: 1.0
|
62 |
+
|
63 |
+
mathvista:
|
64 |
+
split: "testmini"
|
65 |
+
prompt: "<|im_start|>system\nYou are math expert. Use your math knowledge to calculate the answer.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
|
66 |
+
beam_search: True
|
67 |
+
beam_size: 1
|
68 |
+
top_k: 1
|
69 |
+
top_p: 0.0
|
70 |
+
output_max_len: 1024
|
71 |
+
temperature: 1.0
|
72 |
+
|
73 |
+
mmbench:
|
74 |
+
split: "dev"
|
75 |
+
gt_path: "path/to/ground_truth"
|
76 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}Answer with the option's letter from the given choices directly.<|im_end|><|im_start|>assistant\n"
|
77 |
+
beam_search: True
|
78 |
+
beam_size: 1
|
79 |
+
top_k: 1
|
80 |
+
top_p: 0.0
|
81 |
+
output_max_len: 10
|
82 |
+
temperature: 1.0
|
83 |
+
submission: False
|
84 |
+
|
85 |
+
chartqa:
|
86 |
+
split: "test"
|
87 |
+
image_dir: "path/to/image"
|
88 |
+
gt_path: "path/to/ground_truth"
|
89 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
|
90 |
+
|
91 |
+
beam_search: True
|
92 |
+
beam_size: 1
|
93 |
+
top_k: 1
|
94 |
+
top_p: 0.0
|
95 |
+
output_max_len: 20
|
96 |
+
temperature: 1.0
|
97 |
+
|
98 |
+
docvqa:
|
99 |
+
split: "val"
|
100 |
+
image_dir: "path/to/image"
|
101 |
+
gt_path: "path/to/ground_truth"
|
102 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
|
103 |
+
beam_search: True
|
104 |
+
beam_size: 1
|
105 |
+
top_k: 1
|
106 |
+
top_p: 0.0
|
107 |
+
output_max_len: 20
|
108 |
+
temperature: 1.0
|
109 |
+
|
110 |
+
realworldqa:
|
111 |
+
split: "test"
|
112 |
+
image_dir: "path/to/image"
|
113 |
+
gt_path: "path/to/ground_truth"
|
114 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
|
115 |
+
beam_search: True
|
116 |
+
beam_size: 1
|
117 |
+
top_k: 1
|
118 |
+
top_p: 0.0
|
119 |
+
output_max_len: 20
|
120 |
+
temperature: 1.0
|
121 |
+
submission: False
|
122 |
+
|
123 |
+
ocrbench:
|
124 |
+
split: "test"
|
125 |
+
image_dir: "path/to/image"
|
126 |
+
gt_path: "path/to/ground_truth"
|
127 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
|
128 |
+
beam_search: True
|
129 |
+
beam_size: 1
|
130 |
+
top_k: 1
|
131 |
+
top_p: 0.0
|
132 |
+
output_max_len: 70
|
133 |
+
temperature: 1.0
|
134 |
+
submission: False
|
135 |
+
|
136 |
+
ai2diagram:
|
137 |
+
split: "test"
|
138 |
+
image_dir: "path/to/image"
|
139 |
+
gt_path: "path/to/ground_truth"
|
140 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
|
141 |
+
beam_search: True
|
142 |
+
beam_size: 1
|
143 |
+
top_k: 1
|
144 |
+
top_p: 0.0
|
145 |
+
output_max_len: 20
|
146 |
+
temperature: 1.0
|
147 |
+
|
148 |
+
ai2diagram_nomask:
|
149 |
+
split: "test"
|
150 |
+
image_dir: "path/to/image"
|
151 |
+
gt_path: "path/to/ground_truth"
|
152 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
|
153 |
+
beam_search: True
|
154 |
+
beam_size: 1
|
155 |
+
top_k: 1
|
156 |
+
top_p: 0.0
|
157 |
+
output_max_len: 20
|
158 |
+
temperature: 1.0
|
159 |
+
|
160 |
+
mmmu_pro:
|
161 |
+
split: "validation"
|
162 |
+
beam_search: True
|
163 |
+
beam_size: 1
|
164 |
+
top_k: 1
|
165 |
+
top_p: 0.0
|
166 |
+
output_max_len: 10
|
167 |
+
temperature: 1.0
|
168 |
+
apply_lemmatizer: False
|
169 |
+
task_instructions: ""
|
170 |
+
multi_choice_example_format: "{}\n{}\nAnswer with the option's letter from the given choices directly."
|
171 |
+
short_ans_example_format: "{}\nAnswer the question using a single word or phrase."
|
172 |
+
use_chat_format: True
|
173 |
+
conv_format: "yi_nous_sft"
|
174 |
+
default_image_token: "<image>"
|
175 |
+
prompt_offset: 4
|
176 |
+
answer_dict: "path/to/answer_dict.json"
|
177 |
+
|
178 |
+
docvqa_test:
|
179 |
+
split: "test"
|
180 |
+
image_dir: "path/to/image"
|
181 |
+
gt_path: "path/to/ground_truth"
|
182 |
+
prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|>\n<|im_start|>user\n<image>\n{}\nAnswer this question using the text in the image directly.<|im_end|>\n<|im_start|>assistant\n"
|
183 |
+
beam_search: True
|
184 |
+
beam_size: 1
|
185 |
+
top_k: 1
|
186 |
+
top_p: 0.0
|
187 |
+
output_max_len: 20
|
188 |
+
temperature: 1.0
|
eval/mmmu_utils.py
ADDED
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/MMMU-Benchmark/MMMU/blob/main/mmmu/utils/data_utils.py
|
2 |
+
|
3 |
+
"""Utils for data load, save, and process (e.g., prompt construction)"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
import yaml
|
8 |
+
import re
|
9 |
+
|
10 |
+
DOMAIN_CAT2SUB_CAT = {
|
11 |
+
'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'],
|
12 |
+
'Business': ['Accounting', 'Economics', 'Finance', 'Manage', 'Marketing'],
|
13 |
+
'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics', ],
|
14 |
+
'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine',
|
15 |
+
'Pharmacy', 'Public_Health'],
|
16 |
+
'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'],
|
17 |
+
'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics',
|
18 |
+
'Energy_and_Power', 'Materials', 'Mechanical_Engineering'],
|
19 |
+
}
|
20 |
+
|
21 |
+
CAT_SHORT2LONG = {
|
22 |
+
'acc': 'Accounting',
|
23 |
+
'agri': 'Agriculture',
|
24 |
+
'arch': 'Architecture_and_Engineering',
|
25 |
+
'art': 'Art',
|
26 |
+
'art_theory': 'Art_Theory',
|
27 |
+
'bas_med': 'Basic_Medical_Science',
|
28 |
+
'bio': 'Biology',
|
29 |
+
'chem': 'Chemistry',
|
30 |
+
'cli_med': 'Clinical_Medicine',
|
31 |
+
'cs': 'Computer_Science',
|
32 |
+
'design': 'Design',
|
33 |
+
'diag_med': 'Diagnostics_and_Laboratory_Medicine',
|
34 |
+
'econ': 'Economics',
|
35 |
+
'elec': 'Electronics',
|
36 |
+
'ep': 'Energy_and_Power',
|
37 |
+
'fin': 'Finance',
|
38 |
+
'geo': 'Geography',
|
39 |
+
'his': 'History',
|
40 |
+
'liter': 'Literature',
|
41 |
+
'manage': 'Manage',
|
42 |
+
'mark': 'Marketing',
|
43 |
+
'mate': 'Materials',
|
44 |
+
'math': 'Math',
|
45 |
+
'mech': 'Mechanical_Engineering',
|
46 |
+
'music': 'Music',
|
47 |
+
'phar': 'Pharmacy',
|
48 |
+
'phys': 'Physics',
|
49 |
+
'psy': 'Psychology',
|
50 |
+
'pub_health': 'Public_Health',
|
51 |
+
'socio': 'Sociology'
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
# DATA SAVING
|
56 |
+
def save_json(filename, ds):
|
57 |
+
with open(filename, 'w') as f:
|
58 |
+
json.dump(ds, f, indent=4)
|
59 |
+
|
60 |
+
|
61 |
+
def get_multi_choice_info(options):
|
62 |
+
"""
|
63 |
+
Given the list of options for multiple choice question
|
64 |
+
Return the index2ans and all_choices
|
65 |
+
"""
|
66 |
+
|
67 |
+
start_chr = 'A'
|
68 |
+
all_choices = []
|
69 |
+
index2ans = {}
|
70 |
+
for i, option in enumerate(options):
|
71 |
+
index2ans[chr(ord(start_chr) + i)] = option
|
72 |
+
all_choices.append(chr(ord(start_chr) + i))
|
73 |
+
|
74 |
+
return index2ans, all_choices
|
75 |
+
|
76 |
+
|
77 |
+
def load_yaml(file_path):
|
78 |
+
with open(file_path, 'r') as stream:
|
79 |
+
try:
|
80 |
+
yaml_dict = yaml.safe_load(stream)
|
81 |
+
except yaml.YAMLError as exc:
|
82 |
+
print(exc)
|
83 |
+
|
84 |
+
return yaml_dict
|
85 |
+
|
86 |
+
|
87 |
+
def parse_img_path(text):
|
88 |
+
matches = re.findall("<img='(.*?)'>", text)
|
89 |
+
return matches
|
90 |
+
|
91 |
+
|
92 |
+
def process_single_sample(data):
|
93 |
+
question = data['question']
|
94 |
+
o_imgs_paths = []
|
95 |
+
for option in data['options']:
|
96 |
+
current_o_imgs_paths = parse_img_path(option)
|
97 |
+
for img_path in current_o_imgs_paths:
|
98 |
+
o_imgs_paths.append(img_path)
|
99 |
+
|
100 |
+
if len(o_imgs_paths) > 1: # multiple images in options, used for random selection
|
101 |
+
return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
|
102 |
+
'image': None, 'question_type': data['question_type'], 'subfield': data['subfield']}
|
103 |
+
else:
|
104 |
+
return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
|
105 |
+
'image': data['image_1'], 'question_type': data['question_type'], 'subfield': data['subfield']}
|
106 |
+
|
107 |
+
|
108 |
+
def process_single_sample_pro(data):
|
109 |
+
question = data['question']
|
110 |
+
o_imgs_paths = []
|
111 |
+
for option in data['options']:
|
112 |
+
current_o_imgs_paths = parse_img_path(option)
|
113 |
+
for img_path in current_o_imgs_paths:
|
114 |
+
o_imgs_paths.append(img_path)
|
115 |
+
|
116 |
+
if len(o_imgs_paths) > 1: # multiple images in options, used for random selection
|
117 |
+
return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
|
118 |
+
'image': None, 'question_type': 'multiple-choice', 'subfield': data['subject']}
|
119 |
+
else:
|
120 |
+
return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
|
121 |
+
'image': data['image_1'], 'question_type': 'multiple-choice', 'subfield': data['subject']}
|
122 |
+
|
123 |
+
|
124 |
+
# DATA SAVING
|
125 |
+
def save_json(filename, ds):
|
126 |
+
with open(filename, 'w') as f:
|
127 |
+
json.dump(ds, f, indent=4)
|
128 |
+
|
129 |
+
|
130 |
+
def save_jsonl(filename, data):
|
131 |
+
"""
|
132 |
+
Save a dictionary of data to a JSON Lines file with the filename as key and caption as value.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
filename (str): The path to the file where the data should be saved.
|
136 |
+
data (dict): The dictionary containing the data to save where key is the image path and value is the caption.
|
137 |
+
"""
|
138 |
+
with open(filename, 'w', encoding='utf-8') as f:
|
139 |
+
for img_path, caption in data.items():
|
140 |
+
# Extract the base filename without the extension
|
141 |
+
base_filename = os.path.basename(img_path)
|
142 |
+
# Create a JSON object with the filename as the key and caption as the value
|
143 |
+
json_record = json.dumps({base_filename: caption}, ensure_ascii=False)
|
144 |
+
# Write the JSON object to the file, one per line
|
145 |
+
f.write(json_record + '\n')
|
146 |
+
|
147 |
+
|
148 |
+
def save_args(args, path_dir):
|
149 |
+
argsDict = args.__dict__
|
150 |
+
with open(path_dir + 'setting.txt', 'w') as f:
|
151 |
+
f.writelines('------------------ start ------------------' + '\n')
|
152 |
+
for eachArg, value in argsDict.items():
|
153 |
+
f.writelines(eachArg + ' : ' + str(value) + '\n')
|
154 |
+
f.writelines('------------------- end -------------------')
|
155 |
+
|
156 |
+
|
157 |
+
# DATA PROCESSING
|
158 |
+
def construct_prompt(sample, config):
|
159 |
+
question = sample['question'].strip()
|
160 |
+
|
161 |
+
# for i in range(8):
|
162 |
+
# question = question.replace(f" <image {i}> ", " ")
|
163 |
+
# question = question.replace(f" <image {i}>", " ")
|
164 |
+
# question = question.replace(f"<image {i}> ", " ")
|
165 |
+
# question = question.replace(f"<image {i}>", " ")
|
166 |
+
# question = question.strip()
|
167 |
+
|
168 |
+
options = eval(sample['options'])
|
169 |
+
example = ""
|
170 |
+
if sample['question_type'] == 'multiple-choice':
|
171 |
+
start_chr = 'A'
|
172 |
+
prediction_range = []
|
173 |
+
index2ans = {}
|
174 |
+
for option in options:
|
175 |
+
prediction_range.append(start_chr)
|
176 |
+
example += f"({start_chr}) {option}\n"
|
177 |
+
# example += f"{start_chr}. {option}\n"
|
178 |
+
index2ans[start_chr] = option
|
179 |
+
start_chr = chr(ord(start_chr) + 1)
|
180 |
+
# example = example.rstrip()
|
181 |
+
empty_prompt_sample_structure = config['multi_choice_example_format']
|
182 |
+
empty_prompt = empty_prompt_sample_structure.format(question, example)
|
183 |
+
res_dict = {'type': 'multichoice'}
|
184 |
+
res_dict['index2ans'] = index2ans
|
185 |
+
res_dict['correct_choice'] = sample['answer']
|
186 |
+
res_dict['all_choices'] = prediction_range
|
187 |
+
res_dict['empty_prompt'] = empty_prompt
|
188 |
+
if config['task_instructions']:
|
189 |
+
res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
|
190 |
+
else:
|
191 |
+
res_dict['final_input_prompt'] = empty_prompt
|
192 |
+
|
193 |
+
res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
|
194 |
+
else:
|
195 |
+
empty_prompt_sample_structure = config['short_ans_example_format']
|
196 |
+
empty_prompt = empty_prompt_sample_structure.format(question)
|
197 |
+
res_dict = {'type': 'open'}
|
198 |
+
res_dict['empty_prompt'] = empty_prompt
|
199 |
+
if config['task_instructions']:
|
200 |
+
res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
|
201 |
+
else:
|
202 |
+
res_dict['final_input_prompt'] = empty_prompt
|
203 |
+
res_dict['gt_content'] = sample['answer']
|
204 |
+
|
205 |
+
res_dict.update(sample)
|
206 |
+
return res_dict
|
207 |
+
|
208 |
+
|
209 |
+
def construct_prompt_pro(sample, config):
|
210 |
+
question = sample['question'].strip()
|
211 |
+
|
212 |
+
# for i in range(8):
|
213 |
+
# question = question.replace(f" <image {i}> ", " ")
|
214 |
+
# question = question.replace(f" <image {i}>", " ")
|
215 |
+
# question = question.replace(f"<image {i}> ", " ")
|
216 |
+
# question = question.replace(f"<image {i}>", " ")
|
217 |
+
# question = question.strip()
|
218 |
+
|
219 |
+
options = eval(sample['options'])
|
220 |
+
|
221 |
+
if len(options) == 1:
|
222 |
+
print("This is wrongly formated. We correct to options[0].")
|
223 |
+
options = options[0]
|
224 |
+
|
225 |
+
example = ""
|
226 |
+
if sample['question_type'] == 'multiple-choice':
|
227 |
+
start_chr = 'A'
|
228 |
+
prediction_range = []
|
229 |
+
index2ans = {}
|
230 |
+
for option in options:
|
231 |
+
prediction_range.append(start_chr)
|
232 |
+
example += f"({start_chr}) {option}\n"
|
233 |
+
# example += f"{start_chr}. {option}\n"
|
234 |
+
index2ans[start_chr] = option
|
235 |
+
start_chr = chr(ord(start_chr) + 1)
|
236 |
+
# example = example.rstrip()
|
237 |
+
empty_prompt_sample_structure = config['multi_choice_example_format']
|
238 |
+
empty_prompt = empty_prompt_sample_structure.format(question, example)
|
239 |
+
res_dict = {'type': 'multichoice'}
|
240 |
+
res_dict['index2ans'] = index2ans
|
241 |
+
res_dict['correct_choice'] = sample['answer']
|
242 |
+
res_dict['all_choices'] = prediction_range
|
243 |
+
res_dict['empty_prompt'] = empty_prompt
|
244 |
+
if config['task_instructions']:
|
245 |
+
res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
|
246 |
+
else:
|
247 |
+
res_dict['final_input_prompt'] = empty_prompt
|
248 |
+
|
249 |
+
res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
|
250 |
+
else:
|
251 |
+
empty_prompt_sample_structure = config['short_ans_example_format']
|
252 |
+
empty_prompt = empty_prompt_sample_structure.format(question)
|
253 |
+
res_dict = {'type': 'open'}
|
254 |
+
res_dict['empty_prompt'] = empty_prompt
|
255 |
+
if config['task_instructions']:
|
256 |
+
res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
|
257 |
+
else:
|
258 |
+
res_dict['final_input_prompt'] = empty_prompt
|
259 |
+
res_dict['gt_content'] = sample['answer']
|
260 |
+
|
261 |
+
res_dict.update(sample)
|
262 |
+
return res_dict
|
263 |
+
|
264 |
+
"""Response Parsing and Evaluation for various models"""
|
265 |
+
from typing import Dict
|
266 |
+
|
267 |
+
import re
|
268 |
+
import random
|
269 |
+
|
270 |
+
import numpy as np
|
271 |
+
|
272 |
+
|
273 |
+
# ----------- Process Multi-choice -------------
|
274 |
+
def parse_multi_choice_response(response, all_choices, index2ans):
|
275 |
+
"""
|
276 |
+
Parse the prediction from the generated response.
|
277 |
+
Return the predicted index e.g., A, B, C, D.
|
278 |
+
"""
|
279 |
+
for char in [',', '.', '!', '?', ';', ':', "'"]:
|
280 |
+
response = response.strip(char)
|
281 |
+
response = " " + response + " " # add space to avoid partial match
|
282 |
+
|
283 |
+
index_ans = True
|
284 |
+
ans_with_brack = False
|
285 |
+
candidates = []
|
286 |
+
for choice in all_choices: # e.g., (A) (B) (C) (D) A) B) C) D)
|
287 |
+
if f'({choice})' in response or f'{choice})' in response:
|
288 |
+
candidates.append(choice)
|
289 |
+
ans_with_brack = True
|
290 |
+
|
291 |
+
if len(candidates) == 0:
|
292 |
+
for choice in all_choices: # e.g., A B C D
|
293 |
+
if f' {choice} ' in response:
|
294 |
+
candidates.append(choice)
|
295 |
+
|
296 |
+
# if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
|
297 |
+
if len(candidates) == 0 and len(response.split()) > 5:
|
298 |
+
for index, ans in index2ans.items():
|
299 |
+
if ans.lower() in response.lower():
|
300 |
+
candidates.append(index)
|
301 |
+
index_ans = False # it's content ans.
|
302 |
+
|
303 |
+
if len(candidates) == 0: # still not get answer, randomly choose one.
|
304 |
+
pred_index = all_choices[0]
|
305 |
+
elif len(candidates) > 1:
|
306 |
+
start_indexes = []
|
307 |
+
if index_ans:
|
308 |
+
if ans_with_brack:
|
309 |
+
for can in candidates:
|
310 |
+
index = response.rfind(f'({can})')
|
311 |
+
start_indexes.append(index) # -1 will be ignored anyway
|
312 |
+
# start_indexes = [generated_response.index(f'({can})') for can in candidates]
|
313 |
+
else:
|
314 |
+
for can in candidates:
|
315 |
+
index = response.rfind(f" {can} ")
|
316 |
+
start_indexes.append(index)
|
317 |
+
else:
|
318 |
+
for can in candidates:
|
319 |
+
index = response.lower().rfind(index2ans[can].lower())
|
320 |
+
start_indexes.append(index)
|
321 |
+
# get the last one
|
322 |
+
pred_index = candidates[np.argmax(start_indexes)]
|
323 |
+
else: # if only one candidate, use it.
|
324 |
+
pred_index = candidates[0]
|
325 |
+
|
326 |
+
return pred_index
|
327 |
+
|
328 |
+
|
329 |
+
# ----------- Process Open -------------
|
330 |
+
def check_is_number(string):
|
331 |
+
"""
|
332 |
+
Check if the given string a number.
|
333 |
+
"""
|
334 |
+
try:
|
335 |
+
float(string.replace(',', ''))
|
336 |
+
return True
|
337 |
+
except ValueError:
|
338 |
+
# check if there's comma inside
|
339 |
+
return False
|
340 |
+
|
341 |
+
|
342 |
+
def normalize_str(string):
|
343 |
+
"""
|
344 |
+
Normalize the str to lower case and make them float numbers if possible.
|
345 |
+
"""
|
346 |
+
# check if characters in the string
|
347 |
+
|
348 |
+
# if number, numerize it.
|
349 |
+
string = string.strip()
|
350 |
+
|
351 |
+
is_number = check_is_number(string)
|
352 |
+
|
353 |
+
if is_number:
|
354 |
+
string = string.replace(',', '')
|
355 |
+
string = float(string)
|
356 |
+
# leave 2 decimal
|
357 |
+
string = round(string, 2)
|
358 |
+
return [string]
|
359 |
+
else: # it's likely to be a string
|
360 |
+
# lower it
|
361 |
+
string = string.lower()
|
362 |
+
if len(string) == 1:
|
363 |
+
return [" " + string, string + " "] # avoid trivial matches
|
364 |
+
return [string]
|
365 |
+
|
366 |
+
|
367 |
+
def extract_numbers(string):
|
368 |
+
"""
|
369 |
+
Exact all forms of numbers from a string with regex.
|
370 |
+
"""
|
371 |
+
# Pattern for numbers with commas
|
372 |
+
pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b'
|
373 |
+
# Pattern for scientific notation
|
374 |
+
pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
|
375 |
+
# Pattern for simple numbers without commas
|
376 |
+
pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])'
|
377 |
+
|
378 |
+
# Extract numbers with commas
|
379 |
+
numbers_with_commas = re.findall(pattern_commas, string)
|
380 |
+
# Extract numbers in scientific notation
|
381 |
+
numbers_scientific = re.findall(pattern_scientific, string)
|
382 |
+
# Extract simple numbers without commas
|
383 |
+
numbers_simple = re.findall(pattern_simple, string)
|
384 |
+
|
385 |
+
# Combine all extracted numbers
|
386 |
+
all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
|
387 |
+
return all_numbers
|
388 |
+
|
389 |
+
|
390 |
+
def parse_open_response(response):
|
391 |
+
"""
|
392 |
+
Parse the prediction from the generated response.
|
393 |
+
Return a list of predicted strings or numbers.
|
394 |
+
"""
|
395 |
+
|
396 |
+
# content = content.strip("\n").strip(".").strip(" ")
|
397 |
+
def get_key_subresponses(response):
|
398 |
+
key_responses = []
|
399 |
+
response = response.strip().strip(".").lower()
|
400 |
+
sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response)
|
401 |
+
indicators_of_keys = ['could be ', 'so ', 'is ',
|
402 |
+
'thus ', 'therefore ', 'final ', 'answer ', 'result ']
|
403 |
+
key_responses = []
|
404 |
+
for index, resp in enumerate(sub_responses):
|
405 |
+
# if last one, accept it's an equation (the entire response can be just one sentence with equation)
|
406 |
+
if index == len(sub_responses) - 1:
|
407 |
+
indicators_of_keys.extend(['='])
|
408 |
+
shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
|
409 |
+
for indicator in indicators_of_keys:
|
410 |
+
if indicator in resp:
|
411 |
+
if not shortest_key_response:
|
412 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
413 |
+
else:
|
414 |
+
if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
|
415 |
+
shortest_key_response = resp.split(indicator)[-1].strip()
|
416 |
+
# key_responses.append(resp.split(indicator)[1].strip())
|
417 |
+
|
418 |
+
if shortest_key_response:
|
419 |
+
# and it's not trivial
|
420 |
+
if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
|
421 |
+
key_responses.append(shortest_key_response)
|
422 |
+
if len(key_responses) == 0: # did not found any
|
423 |
+
return [response]
|
424 |
+
return key_responses
|
425 |
+
|
426 |
+
# pdb.set_trace()
|
427 |
+
key_responses = get_key_subresponses(response)
|
428 |
+
|
429 |
+
pred_list = key_responses.copy() # keep the original string response
|
430 |
+
for resp in key_responses:
|
431 |
+
pred_list.extend(extract_numbers(resp))
|
432 |
+
|
433 |
+
tmp_pred_list = []
|
434 |
+
for i in range(len(pred_list)):
|
435 |
+
tmp_pred_list.extend(normalize_str(pred_list[i]))
|
436 |
+
pred_list = tmp_pred_list
|
437 |
+
|
438 |
+
# remove duplicates
|
439 |
+
pred_list = list(set(pred_list))
|
440 |
+
|
441 |
+
return pred_list
|
442 |
+
|
443 |
+
|
444 |
+
# ----------- Evaluation -------------
|
445 |
+
|
446 |
+
def eval_multi_choice(gold_i, pred_i):
|
447 |
+
"""
|
448 |
+
Evaluate a multiple choice instance.
|
449 |
+
"""
|
450 |
+
correct = False
|
451 |
+
# only they are exactly the same, we consider it as correct
|
452 |
+
if isinstance(gold_i, list):
|
453 |
+
for answer in gold_i:
|
454 |
+
if answer == pred_i:
|
455 |
+
correct = True
|
456 |
+
break
|
457 |
+
else: # gold_i is a string
|
458 |
+
if gold_i == pred_i:
|
459 |
+
correct = True
|
460 |
+
return correct
|
461 |
+
|
462 |
+
|
463 |
+
def eval_open(gold_i, pred_i):
|
464 |
+
"""
|
465 |
+
Evaluate an open question instance
|
466 |
+
"""
|
467 |
+
correct = False
|
468 |
+
if isinstance(gold_i, list):
|
469 |
+
# use float to avoid trivial matches
|
470 |
+
norm_answers = []
|
471 |
+
for answer in gold_i:
|
472 |
+
norm_answers.extend(normalize_str(answer))
|
473 |
+
else:
|
474 |
+
norm_answers = normalize_str(gold_i)
|
475 |
+
for pred in pred_i: # pred is already normalized in parse response phase
|
476 |
+
if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
|
477 |
+
for norm_ans in norm_answers:
|
478 |
+
# only see if the string answer in the string pred
|
479 |
+
if isinstance(norm_ans, str) and norm_ans in pred:
|
480 |
+
if not correct:
|
481 |
+
correct = True
|
482 |
+
break
|
483 |
+
else: # it's a float number
|
484 |
+
if pred in norm_answers:
|
485 |
+
if not correct:
|
486 |
+
correct = True
|
487 |
+
break
|
488 |
+
return correct
|
489 |
+
|
490 |
+
|
491 |
+
# ----------- Batch Evaluation -------------
|
492 |
+
def evaluate(samples):
|
493 |
+
"""
|
494 |
+
Batch evaluation for multiple choice and open questions.
|
495 |
+
"""
|
496 |
+
pred_correct = 0
|
497 |
+
judge_dict = dict()
|
498 |
+
for sample in samples:
|
499 |
+
gold_i = sample['answer']
|
500 |
+
pred_i = sample['parsed_pred']
|
501 |
+
if sample['question_type'] == 'multiple-choice':
|
502 |
+
correct = eval_multi_choice(gold_i, pred_i)
|
503 |
+
else: # open question
|
504 |
+
correct = eval_open(gold_i, pred_i)
|
505 |
+
|
506 |
+
if correct:
|
507 |
+
judge_dict[sample['id']] = 'Correct'
|
508 |
+
pred_correct += 1
|
509 |
+
else:
|
510 |
+
judge_dict[sample['id']] = 'Wrong'
|
511 |
+
|
512 |
+
if len(samples) == 0:
|
513 |
+
return {'acc': 0}
|
514 |
+
return judge_dict, {'acc': pred_correct / len(samples)}
|
515 |
+
|
516 |
+
|
517 |
+
# ----------- Calculate Accuracy -------------
|
518 |
+
def calculate_ins_level_acc(results: Dict):
|
519 |
+
"""Calculate the instruction level accuracy for given Subject results"""
|
520 |
+
acc = 0
|
521 |
+
ins_num = 0
|
522 |
+
for cat_results in results.values():
|
523 |
+
acc += cat_results['acc'] * cat_results['num_example']
|
524 |
+
ins_num += cat_results['num_example']
|
525 |
+
if ins_num == 0:
|
526 |
+
return 0
|
527 |
+
return acc / ins_num
|
528 |
+
|
529 |
+
|
530 |
+
def mmmu_main_eval(output_dict, task_cfg):
|
531 |
+
answer_dict = json.load(open(task_cfg["answer_dict"]))
|
532 |
+
|
533 |
+
# group by category
|
534 |
+
output_dict_w_cat = {}
|
535 |
+
for data_id, parsed_pred in output_dict.items():
|
536 |
+
category = "_".join(data_id.split("_")[1:-1])
|
537 |
+
if category not in output_dict_w_cat:
|
538 |
+
output_dict_w_cat.update({category: {}})
|
539 |
+
output_dict_w_cat[category].update({data_id: parsed_pred})
|
540 |
+
|
541 |
+
# group by category
|
542 |
+
answer_dict_w_cat = {}
|
543 |
+
for data_id, parsed_pred in answer_dict.items():
|
544 |
+
category = "_".join(data_id.split("_")[1:-1])
|
545 |
+
if category not in answer_dict_w_cat:
|
546 |
+
answer_dict_w_cat.update({category: {}})
|
547 |
+
answer_dict_w_cat[category].update({data_id: parsed_pred})
|
548 |
+
|
549 |
+
evaluation_result = {}
|
550 |
+
|
551 |
+
for category in CAT_SHORT2LONG.values():
|
552 |
+
# print("Evaluating: {}".format(category))
|
553 |
+
# get cat_outputs and cat_answers
|
554 |
+
try:
|
555 |
+
cat_outputs = output_dict_w_cat[category]
|
556 |
+
cat_answers = answer_dict_w_cat[category]
|
557 |
+
except KeyError:
|
558 |
+
print("Skipping {} for not found".format(category))
|
559 |
+
continue
|
560 |
+
|
561 |
+
exampels_to_eval = []
|
562 |
+
for data_id, parsed_pred in cat_outputs.items():
|
563 |
+
question_type = cat_answers[data_id]['question_type']
|
564 |
+
if question_type != 'multiple-choice':
|
565 |
+
parsed_pred = parse_open_response(parsed_pred) # mainly for type consistency (make it number, etc.)
|
566 |
+
else:
|
567 |
+
parsed_pred = parsed_pred
|
568 |
+
|
569 |
+
exampels_to_eval.append({
|
570 |
+
"id": data_id,
|
571 |
+
"question_type": question_type,
|
572 |
+
"answer": cat_answers[data_id]['ground_truth'],
|
573 |
+
"parsed_pred": parsed_pred
|
574 |
+
})
|
575 |
+
|
576 |
+
judge_dict, metric_dict = evaluate(exampels_to_eval)
|
577 |
+
metric_dict.update({"num_example": len(exampels_to_eval)})
|
578 |
+
|
579 |
+
evaluation_result[category] = metric_dict
|
580 |
+
|
581 |
+
printable_results = {}
|
582 |
+
# pdb.set_trace()
|
583 |
+
# add domain Subject
|
584 |
+
for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
|
585 |
+
in_domain_cat_results = {}
|
586 |
+
for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT
|
587 |
+
if cat_name in evaluation_result.keys():
|
588 |
+
in_domain_cat_results[cat_name] = evaluation_result[cat_name]
|
589 |
+
else:
|
590 |
+
pass
|
591 |
+
in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
|
592 |
+
in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()])
|
593 |
+
printable_results['Overall-' + domain] = {"num": int(in_domain_data_num),
|
594 |
+
"acc": round(in_domain_ins_acc, 4)
|
595 |
+
}
|
596 |
+
# add sub category
|
597 |
+
for cat_name, cat_results in in_domain_cat_results.items():
|
598 |
+
printable_results[cat_name] = {"num": int(cat_results['num_example']),
|
599 |
+
"acc": round(cat_results['acc'], 4)
|
600 |
+
}
|
601 |
+
|
602 |
+
# table.append(["-----------------------------", "-----", "----"])
|
603 |
+
all_ins_acc = calculate_ins_level_acc(evaluation_result)
|
604 |
+
printable_results['Overall'] = {
|
605 |
+
"num": sum([cat_results['num_example'] for cat_results in evaluation_result.values()]),
|
606 |
+
"acc": round(all_ins_acc, 4)
|
607 |
+
}
|
608 |
+
|
609 |
+
# print(printable_results)
|
610 |
+
return printable_results
|
611 |
+
|
612 |
+
|
613 |
+
if __name__ == '__main__':
|
614 |
+
# tasks = yaml.safe_load(open("eval_config/eval_mmmu_yi_oci.yaml"))['datasets']
|
615 |
+
tasks = yaml.safe_load(open("eval_config/eval_mmmu_yi.yaml"))['datasets']
|
616 |
+
print(tasks)
|
617 |
+
|
618 |
+
# with open("/lustre/fs4/portfolios/adlr/users/boxinw/llava-megatron-gen/checkpoints/test/eval_mmmu_iter500_merged.4node.json") as f:
|
619 |
+
with open("/lustre/fsw/portfolios/llmservice/users/boxinw/eval_mmmu_iter6000_merged.0.53.json") as f:
|
620 |
+
merged_results = json.load(f)
|
621 |
+
|
622 |
+
eval_samples = []
|
623 |
+
eval_output_dict = {}
|
624 |
+
for res in merged_results:
|
625 |
+
pred_ans = res["answer"].upper()
|
626 |
+
gt_ans = res['gt_answer']
|
627 |
+
if res['question_type'] == 'multiple-choice':
|
628 |
+
parsed_pred = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
|
629 |
+
if pred_ans != parsed_pred:
|
630 |
+
print(f"MC: Original: {pred_ans}, Parsed: {parsed_pred}")
|
631 |
+
eval_samples.append(
|
632 |
+
{
|
633 |
+
'id': res['question_id'],
|
634 |
+
'question_type': res['question_type'],
|
635 |
+
'answer': res['gt_answer'], # the content in option, not answer index.
|
636 |
+
'response': pred_ans,
|
637 |
+
'parsed_pred': parsed_pred,
|
638 |
+
'index2ans': res['index2ans'],
|
639 |
+
}
|
640 |
+
)
|
641 |
+
eval_output_dict[res['question_id']] = parsed_pred
|
642 |
+
else:
|
643 |
+
parsed_pred = parse_open_response(pred_ans)
|
644 |
+
if pred_ans != parsed_pred:
|
645 |
+
print(f"Open: Original: {pred_ans}, Parsed: {parsed_pred}")
|
646 |
+
eval_samples.append(
|
647 |
+
{
|
648 |
+
'id': res['question_id'],
|
649 |
+
'question_type': res['question_type'],
|
650 |
+
'answer': res['gt_answer'],
|
651 |
+
'response': pred_ans,
|
652 |
+
'parsed_pred': parsed_pred,
|
653 |
+
}
|
654 |
+
)
|
655 |
+
eval_output_dict[res['question_id']] = pred_ans
|
656 |
+
|
657 |
+
json.dump(eval_output_dict, open("validation_mmmu_iter6000_merged.0.53.sorted.json", "w"), indent=4, sort_keys=True)
|
658 |
+
|
659 |
+
|
660 |
+
x = mmmu_main_eval(eval_output_dict,
|
661 |
+
task_cfg=tasks['mmmu'])
|
662 |
+
|
663 |
+
print(x)
|
eval/requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
anls
|
2 |
+
datasets
|
3 |
+
pycocoevalcap
|
eval/vqa_utils.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
|
3 |
+
__author__ = "aagrawal"
|
4 |
+
|
5 |
+
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
6 |
+
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
|
7 |
+
import sys
|
8 |
+
import re
|
9 |
+
|
10 |
+
|
11 |
+
class VQAEval:
|
12 |
+
def __init__(self, vqa=None, vqaRes=None, n=2):
|
13 |
+
self.n = n
|
14 |
+
self.accuracy = {}
|
15 |
+
self.evalQA = {}
|
16 |
+
self.evalQuesType = {}
|
17 |
+
self.evalAnsType = {}
|
18 |
+
self.vqa = vqa
|
19 |
+
self.vqaRes = vqaRes
|
20 |
+
if vqa is not None:
|
21 |
+
self.params = {"question_id": vqa.getQuesIds()}
|
22 |
+
self.contractions = {
|
23 |
+
"aint": "ain't",
|
24 |
+
"arent": "aren't",
|
25 |
+
"cant": "can't",
|
26 |
+
"couldve": "could've",
|
27 |
+
"couldnt": "couldn't",
|
28 |
+
"couldn'tve": "couldn't've",
|
29 |
+
"couldnt've": "couldn't've",
|
30 |
+
"didnt": "didn't",
|
31 |
+
"doesnt": "doesn't",
|
32 |
+
"dont": "don't",
|
33 |
+
"hadnt": "hadn't",
|
34 |
+
"hadnt've": "hadn't've",
|
35 |
+
"hadn'tve": "hadn't've",
|
36 |
+
"hasnt": "hasn't",
|
37 |
+
"havent": "haven't",
|
38 |
+
"hed": "he'd",
|
39 |
+
"hed've": "he'd've",
|
40 |
+
"he'dve": "he'd've",
|
41 |
+
"hes": "he's",
|
42 |
+
"howd": "how'd",
|
43 |
+
"howll": "how'll",
|
44 |
+
"hows": "how's",
|
45 |
+
"Id've": "I'd've",
|
46 |
+
"I'dve": "I'd've",
|
47 |
+
"Im": "I'm",
|
48 |
+
"Ive": "I've",
|
49 |
+
"isnt": "isn't",
|
50 |
+
"itd": "it'd",
|
51 |
+
"itd've": "it'd've",
|
52 |
+
"it'dve": "it'd've",
|
53 |
+
"itll": "it'll",
|
54 |
+
"let's": "let's",
|
55 |
+
"maam": "ma'am",
|
56 |
+
"mightnt": "mightn't",
|
57 |
+
"mightnt've": "mightn't've",
|
58 |
+
"mightn'tve": "mightn't've",
|
59 |
+
"mightve": "might've",
|
60 |
+
"mustnt": "mustn't",
|
61 |
+
"mustve": "must've",
|
62 |
+
"neednt": "needn't",
|
63 |
+
"notve": "not've",
|
64 |
+
"oclock": "o'clock",
|
65 |
+
"oughtnt": "oughtn't",
|
66 |
+
"ow's'at": "'ow's'at",
|
67 |
+
"'ows'at": "'ow's'at",
|
68 |
+
"'ow'sat": "'ow's'at",
|
69 |
+
"shant": "shan't",
|
70 |
+
"shed've": "she'd've",
|
71 |
+
"she'dve": "she'd've",
|
72 |
+
"she's": "she's",
|
73 |
+
"shouldve": "should've",
|
74 |
+
"shouldnt": "shouldn't",
|
75 |
+
"shouldnt've": "shouldn't've",
|
76 |
+
"shouldn'tve": "shouldn't've",
|
77 |
+
"somebody'd": "somebodyd",
|
78 |
+
"somebodyd've": "somebody'd've",
|
79 |
+
"somebody'dve": "somebody'd've",
|
80 |
+
"somebodyll": "somebody'll",
|
81 |
+
"somebodys": "somebody's",
|
82 |
+
"someoned": "someone'd",
|
83 |
+
"someoned've": "someone'd've",
|
84 |
+
"someone'dve": "someone'd've",
|
85 |
+
"someonell": "someone'll",
|
86 |
+
"someones": "someone's",
|
87 |
+
"somethingd": "something'd",
|
88 |
+
"somethingd've": "something'd've",
|
89 |
+
"something'dve": "something'd've",
|
90 |
+
"somethingll": "something'll",
|
91 |
+
"thats": "that's",
|
92 |
+
"thered": "there'd",
|
93 |
+
"thered've": "there'd've",
|
94 |
+
"there'dve": "there'd've",
|
95 |
+
"therere": "there're",
|
96 |
+
"theres": "there's",
|
97 |
+
"theyd": "they'd",
|
98 |
+
"theyd've": "they'd've",
|
99 |
+
"they'dve": "they'd've",
|
100 |
+
"theyll": "they'll",
|
101 |
+
"theyre": "they're",
|
102 |
+
"theyve": "they've",
|
103 |
+
"twas": "'twas",
|
104 |
+
"wasnt": "wasn't",
|
105 |
+
"wed've": "we'd've",
|
106 |
+
"we'dve": "we'd've",
|
107 |
+
"weve": "we've",
|
108 |
+
"werent": "weren't",
|
109 |
+
"whatll": "what'll",
|
110 |
+
"whatre": "what're",
|
111 |
+
"whats": "what's",
|
112 |
+
"whatve": "what've",
|
113 |
+
"whens": "when's",
|
114 |
+
"whered": "where'd",
|
115 |
+
"wheres": "where's",
|
116 |
+
"whereve": "where've",
|
117 |
+
"whod": "who'd",
|
118 |
+
"whod've": "who'd've",
|
119 |
+
"who'dve": "who'd've",
|
120 |
+
"wholl": "who'll",
|
121 |
+
"whos": "who's",
|
122 |
+
"whove": "who've",
|
123 |
+
"whyll": "why'll",
|
124 |
+
"whyre": "why're",
|
125 |
+
"whys": "why's",
|
126 |
+
"wont": "won't",
|
127 |
+
"wouldve": "would've",
|
128 |
+
"wouldnt": "wouldn't",
|
129 |
+
"wouldnt've": "wouldn't've",
|
130 |
+
"wouldn'tve": "wouldn't've",
|
131 |
+
"yall": "y'all",
|
132 |
+
"yall'll": "y'all'll",
|
133 |
+
"y'allll": "y'all'll",
|
134 |
+
"yall'd've": "y'all'd've",
|
135 |
+
"y'alld've": "y'all'd've",
|
136 |
+
"y'all'dve": "y'all'd've",
|
137 |
+
"youd": "you'd",
|
138 |
+
"youd've": "you'd've",
|
139 |
+
"you'dve": "you'd've",
|
140 |
+
"youll": "you'll",
|
141 |
+
"youre": "you're",
|
142 |
+
"youve": "you've",
|
143 |
+
}
|
144 |
+
self.manualMap = {
|
145 |
+
"none": "0",
|
146 |
+
"zero": "0",
|
147 |
+
"one": "1",
|
148 |
+
"two": "2",
|
149 |
+
"three": "3",
|
150 |
+
"four": "4",
|
151 |
+
"five": "5",
|
152 |
+
"six": "6",
|
153 |
+
"seven": "7",
|
154 |
+
"eight": "8",
|
155 |
+
"nine": "9",
|
156 |
+
"ten": "10",
|
157 |
+
}
|
158 |
+
self.articles = ["a", "an", "the"]
|
159 |
+
|
160 |
+
self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
|
161 |
+
self.commaStrip = re.compile("(\d)(,)(\d)")
|
162 |
+
self.punct = [
|
163 |
+
";",
|
164 |
+
r"/",
|
165 |
+
"[",
|
166 |
+
"]",
|
167 |
+
'"',
|
168 |
+
"{",
|
169 |
+
"}",
|
170 |
+
"(",
|
171 |
+
")",
|
172 |
+
"=",
|
173 |
+
"+",
|
174 |
+
"\\",
|
175 |
+
"_",
|
176 |
+
"-",
|
177 |
+
">",
|
178 |
+
"<",
|
179 |
+
"@",
|
180 |
+
"`",
|
181 |
+
",",
|
182 |
+
"?",
|
183 |
+
"!",
|
184 |
+
]
|
185 |
+
|
186 |
+
def evaluate(self, quesIds=None):
|
187 |
+
if quesIds == None:
|
188 |
+
quesIds = [quesId for quesId in self.params["question_id"]]
|
189 |
+
gts = {}
|
190 |
+
res = {}
|
191 |
+
for quesId in quesIds:
|
192 |
+
gts[quesId] = self.vqa.qa[quesId]
|
193 |
+
res[quesId] = self.vqaRes.qa[quesId]
|
194 |
+
|
195 |
+
# =================================================
|
196 |
+
# Compute accuracy
|
197 |
+
# =================================================
|
198 |
+
accQA = []
|
199 |
+
accQuesType = {}
|
200 |
+
accAnsType = {}
|
201 |
+
print("computing accuracy")
|
202 |
+
step = 0
|
203 |
+
for quesId in quesIds:
|
204 |
+
resAns = res[quesId]["answer"]
|
205 |
+
resAns = resAns.replace("\n", " ")
|
206 |
+
resAns = resAns.replace("\t", " ")
|
207 |
+
resAns = resAns.strip()
|
208 |
+
resAns = self.processPunctuation(resAns)
|
209 |
+
resAns = self.processDigitArticle(resAns)
|
210 |
+
gtAcc = []
|
211 |
+
gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
|
212 |
+
if len(set(gtAnswers)) > 1:
|
213 |
+
for ansDic in gts[quesId]["answers"]:
|
214 |
+
ansDic["answer"] = self.processPunctuation(ansDic["answer"])
|
215 |
+
for gtAnsDatum in gts[quesId]["answers"]:
|
216 |
+
otherGTAns = [
|
217 |
+
item for item in gts[quesId]["answers"] if item != gtAnsDatum
|
218 |
+
]
|
219 |
+
matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
|
220 |
+
acc = min(1, float(len(matchingAns)) / 3)
|
221 |
+
gtAcc.append(acc)
|
222 |
+
quesType = gts[quesId]["question_type"]
|
223 |
+
ansType = gts[quesId]["answer_type"]
|
224 |
+
avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
|
225 |
+
accQA.append(avgGTAcc)
|
226 |
+
if quesType not in accQuesType:
|
227 |
+
accQuesType[quesType] = []
|
228 |
+
accQuesType[quesType].append(avgGTAcc)
|
229 |
+
if ansType not in accAnsType:
|
230 |
+
accAnsType[ansType] = []
|
231 |
+
accAnsType[ansType].append(avgGTAcc)
|
232 |
+
self.setEvalQA(quesId, avgGTAcc)
|
233 |
+
self.setEvalQuesType(quesId, quesType, avgGTAcc)
|
234 |
+
self.setEvalAnsType(quesId, ansType, avgGTAcc)
|
235 |
+
if step % 100 == 0:
|
236 |
+
self.updateProgress(step / float(len(quesIds)))
|
237 |
+
step = step + 1
|
238 |
+
|
239 |
+
self.setAccuracy(accQA, accQuesType, accAnsType)
|
240 |
+
print("Done computing accuracy")
|
241 |
+
|
242 |
+
def processPunctuation(self, inText):
|
243 |
+
outText = inText
|
244 |
+
for p in self.punct:
|
245 |
+
if (p + " " in inText or " " + p in inText) or (
|
246 |
+
re.search(self.commaStrip, inText) != None
|
247 |
+
):
|
248 |
+
outText = outText.replace(p, "")
|
249 |
+
else:
|
250 |
+
outText = outText.replace(p, " ")
|
251 |
+
outText = self.periodStrip.sub("", outText, re.UNICODE)
|
252 |
+
return outText
|
253 |
+
|
254 |
+
def processDigitArticle(self, inText):
|
255 |
+
outText = []
|
256 |
+
tempText = inText.lower().split()
|
257 |
+
for word in tempText:
|
258 |
+
word = self.manualMap.setdefault(word, word)
|
259 |
+
if word not in self.articles:
|
260 |
+
outText.append(word)
|
261 |
+
else:
|
262 |
+
pass
|
263 |
+
for wordId, word in enumerate(outText):
|
264 |
+
if word in self.contractions:
|
265 |
+
outText[wordId] = self.contractions[word]
|
266 |
+
outText = " ".join(outText)
|
267 |
+
return outText
|
268 |
+
|
269 |
+
def setAccuracy(self, accQA, accQuesType, accAnsType):
|
270 |
+
self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
|
271 |
+
self.accuracy["perQuestionType"] = {
|
272 |
+
quesType: round(
|
273 |
+
100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
|
274 |
+
self.n,
|
275 |
+
)
|
276 |
+
for quesType in accQuesType
|
277 |
+
}
|
278 |
+
self.accuracy["perAnswerType"] = {
|
279 |
+
ansType: round(
|
280 |
+
100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
|
281 |
+
)
|
282 |
+
for ansType in accAnsType
|
283 |
+
}
|
284 |
+
|
285 |
+
def setEvalQA(self, quesId, acc):
|
286 |
+
self.evalQA[quesId] = round(100 * acc, self.n)
|
287 |
+
|
288 |
+
def setEvalQuesType(self, quesId, quesType, acc):
|
289 |
+
if quesType not in self.evalQuesType:
|
290 |
+
self.evalQuesType[quesType] = {}
|
291 |
+
self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
|
292 |
+
|
293 |
+
def setEvalAnsType(self, quesId, ansType, acc):
|
294 |
+
if ansType not in self.evalAnsType:
|
295 |
+
self.evalAnsType[ansType] = {}
|
296 |
+
self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
|
297 |
+
|
298 |
+
def updateProgress(self, progress):
|
299 |
+
barLength = 20
|
300 |
+
status = ""
|
301 |
+
if isinstance(progress, int):
|
302 |
+
progress = float(progress)
|
303 |
+
if not isinstance(progress, float):
|
304 |
+
progress = 0
|
305 |
+
status = "error: progress var must be float\r\n"
|
306 |
+
if progress < 0:
|
307 |
+
progress = 0
|
308 |
+
status = "Halt...\r\n"
|
309 |
+
if progress >= 1:
|
310 |
+
progress = 1
|
311 |
+
status = "Done...\r\n"
|
312 |
+
block = int(round(barLength * progress))
|
313 |
+
text = "\rFinshed Percent: [{0}] {1}% {2}".format(
|
314 |
+
"#" * block + "-" * (barLength - block), int(progress * 100), status
|
315 |
+
)
|
316 |
+
sys.stdout.write(text)
|
317 |
+
sys.stdout.flush()
|
run_eval.py
ADDED
@@ -0,0 +1,702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import json
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from transformers import AutoTokenizer, AutoModel
|
6 |
+
import math
|
7 |
+
import time
|
8 |
+
import argparse
|
9 |
+
import yaml
|
10 |
+
import os
|
11 |
+
|
12 |
+
from eval.eval_dataset import get_task_dataloader, isNumber, get_anls_score
|
13 |
+
from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \
|
14 |
+
process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro
|
15 |
+
from eval.mmmu_utils import evaluate as evaluate_mmmu
|
16 |
+
from pycocotools.coco import COCO
|
17 |
+
from pycocoevalcap.eval import COCOEvalCap
|
18 |
+
from anls import anls_score
|
19 |
+
import pandas as pd
|
20 |
+
from eval.vqa_utils import VQAEval
|
21 |
+
|
22 |
+
|
23 |
+
def split_model():
|
24 |
+
device_map = {}
|
25 |
+
world_size = torch.cuda.device_count()
|
26 |
+
num_layers = 80
|
27 |
+
# Since the first GPU will be used for ViT, treat it as half a GPU.
|
28 |
+
num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
|
29 |
+
num_layers_per_gpu = [num_layers_per_gpu] * world_size
|
30 |
+
num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
|
31 |
+
layer_cnt = 0
|
32 |
+
for i, num_layer in enumerate(num_layers_per_gpu):
|
33 |
+
for j in range(num_layer):
|
34 |
+
device_map[f'language_model.model.layers.{layer_cnt}'] = i
|
35 |
+
layer_cnt += 1
|
36 |
+
device_map['vision_model'] = 0
|
37 |
+
device_map['mlp1'] = 0
|
38 |
+
device_map['language_model.model.tok_embeddings'] = 0
|
39 |
+
device_map['language_model.model.embed_tokens'] = 0
|
40 |
+
device_map['language_model.output'] = 0
|
41 |
+
device_map['language_model.model.norm'] = 0
|
42 |
+
device_map['language_model.lm_head'] = 0
|
43 |
+
device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
|
44 |
+
|
45 |
+
return device_map
|
46 |
+
|
47 |
+
|
48 |
+
def generate_task_results(
|
49 |
+
task_name,
|
50 |
+
task_cfg,
|
51 |
+
args,
|
52 |
+
model,
|
53 |
+
tokenizer,
|
54 |
+
generation_config,
|
55 |
+
dataloader,
|
56 |
+
results_save_dir,
|
57 |
+
):
|
58 |
+
results = []
|
59 |
+
|
60 |
+
if "prompt" in task_cfg:
|
61 |
+
prompt = task_cfg["prompt"]
|
62 |
+
|
63 |
+
if task_name == "mmmu" or task_name == 'mmmu_pro':
|
64 |
+
for i, (img, question_id, subfield, question_type, question, answer, index2ans, all_choices) in enumerate(
|
65 |
+
dataloader):
|
66 |
+
if img.dim() == 5:
|
67 |
+
img = img.squeeze(0)
|
68 |
+
question_id = question_id[0]
|
69 |
+
question_type = question_type[0]
|
70 |
+
question = question[0]
|
71 |
+
answer = answer[0]
|
72 |
+
subfield = subfield[0]
|
73 |
+
index2ans = ast.literal_eval(index2ans[0])
|
74 |
+
all_choices = ast.literal_eval(all_choices[0])
|
75 |
+
|
76 |
+
this_prompt = question
|
77 |
+
|
78 |
+
if task_name == 'mmmu':
|
79 |
+
categories = list(CAT_SHORT2LONG.values())
|
80 |
+
new_sys_msg = "You are expert in {} ({}). Read the image and use your knowledge in {} ({}) to answer the question."
|
81 |
+
for c in categories:
|
82 |
+
if c in question_id:
|
83 |
+
cat = c.lower().replace('_', ' ')
|
84 |
+
new_sys_msg = new_sys_msg.format(cat, subfield, cat, subfield)
|
85 |
+
break
|
86 |
+
else:
|
87 |
+
new_sys_msg = f"You are expert in {subfield}. Read the image and use your knowledge in {subfield} to answer the question."
|
88 |
+
model.system_message = new_sys_msg
|
89 |
+
|
90 |
+
generated_answer = model.chat(tokenizer, img, question, generation_config)
|
91 |
+
print("generated_answer", generated_answer)
|
92 |
+
|
93 |
+
results.append({
|
94 |
+
"question_id": question_id,
|
95 |
+
"question_type": question_type,
|
96 |
+
"answer": generated_answer,
|
97 |
+
"gt_answer": answer,
|
98 |
+
"index2ans": index2ans,
|
99 |
+
"all_choices": all_choices
|
100 |
+
})
|
101 |
+
print("idx:", i, results[-1])
|
102 |
+
|
103 |
+
|
104 |
+
elif task_name == "coco_caption" or task_name == "flickr30k_caption":
|
105 |
+
|
106 |
+
for i, (image_id, img,) in enumerate(dataloader):
|
107 |
+
|
108 |
+
if img.dim() == 5:
|
109 |
+
img = img.squeeze(0)
|
110 |
+
|
111 |
+
generated_answer = model.chat_without_chat_prompt(tokenizer, img, prompt, generation_config)
|
112 |
+
print("generated_answer", generated_answer)
|
113 |
+
|
114 |
+
results.append({"image_id": image_id.item(), "caption": generated_answer})
|
115 |
+
print("idx:", i, results[-1])
|
116 |
+
|
117 |
+
elif task_name in ["vqav2", "gqa", "okvqa", "textvqa", "chartqa", "docvqa", "realworldqa",
|
118 |
+
"ai2diagram", "ai2diagram_nomask", "docvqa_test"]:
|
119 |
+
for i, (img, question_id, question, answer,) in enumerate(dataloader):
|
120 |
+
if img.dim() == 5:
|
121 |
+
img = img.squeeze(0)
|
122 |
+
|
123 |
+
# Only works for batch size = 1
|
124 |
+
# need to implement collate function when we use bs > 1 in the future
|
125 |
+
question = question[0]
|
126 |
+
if task_name == 'ai2diagram' or task_name == 'ai2diagram_nomask':
|
127 |
+
question_id = question_id[0]
|
128 |
+
else:
|
129 |
+
question_id = question_id[0].item()
|
130 |
+
|
131 |
+
if type(answer) == list:
|
132 |
+
answer = [ans[0] for ans in answer]
|
133 |
+
else:
|
134 |
+
answer = [answer[0]]
|
135 |
+
|
136 |
+
# Need to change if using batch size > 1 in the future
|
137 |
+
this_prompt = prompt.format(question)
|
138 |
+
|
139 |
+
generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
|
140 |
+
print("generated_answer", generated_answer)
|
141 |
+
|
142 |
+
results.append({"question_id": question_id, "answer": generated_answer, "gt_answer": answer})
|
143 |
+
print("idx:", i, results[-1])
|
144 |
+
|
145 |
+
|
146 |
+
elif task_name == "ocrbench":
|
147 |
+
for i, (img, question_id, question, answer, dataset_name, data_type,) in enumerate(
|
148 |
+
dataloader):
|
149 |
+
if img.dim() == 5:
|
150 |
+
img = img.squeeze(0)
|
151 |
+
|
152 |
+
question_id = question_id[0]
|
153 |
+
question = question[0]
|
154 |
+
answer = answer[0]
|
155 |
+
dataset_name = dataset_name[0]
|
156 |
+
data_type = data_type[0]
|
157 |
+
|
158 |
+
this_prompt = prompt.format(question)
|
159 |
+
|
160 |
+
generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
|
161 |
+
print("generated_answer", generated_answer)
|
162 |
+
|
163 |
+
results.append({"question_id": question_id, "answer": generated_answer, "gt_answer": answer,
|
164 |
+
"dataset_name": dataset_name, "type": data_type})
|
165 |
+
print("idx:", i, results[-1])
|
166 |
+
|
167 |
+
|
168 |
+
elif task_name == "mathvista":
|
169 |
+
for i, (
|
170 |
+
img, question_id, question_type, question, answer, index2ans, all_choices,) in enumerate(
|
171 |
+
dataloader):
|
172 |
+
if img.dim() == 5:
|
173 |
+
img = img.squeeze(0)
|
174 |
+
|
175 |
+
question_id = question_id[0]
|
176 |
+
question = question[0]
|
177 |
+
question_type = question_type[0]
|
178 |
+
answer = answer[0]
|
179 |
+
|
180 |
+
index2ans = ast.literal_eval(index2ans[0])
|
181 |
+
all_choices = ast.literal_eval(all_choices[0])
|
182 |
+
|
183 |
+
this_prompt = prompt.format(question)
|
184 |
+
|
185 |
+
generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
|
186 |
+
print("generated_answer", generated_answer)
|
187 |
+
|
188 |
+
results.append({
|
189 |
+
"question_id": question_id,
|
190 |
+
"question_type": question_type,
|
191 |
+
"answer": generated_answer,
|
192 |
+
"gt_answer": answer,
|
193 |
+
"index2ans": index2ans,
|
194 |
+
"all_choices": all_choices
|
195 |
+
})
|
196 |
+
print("idx:", i, results[-1])
|
197 |
+
|
198 |
+
|
199 |
+
elif task_name == "mmbench":
|
200 |
+
for i, (
|
201 |
+
img, question_id, question, answer, index2ans, all_choices, original_q,) in enumerate(
|
202 |
+
dataloader):
|
203 |
+
if img.dim() == 5:
|
204 |
+
img = img.squeeze(0)
|
205 |
+
|
206 |
+
question_id = question_id[0]
|
207 |
+
question = question[0]
|
208 |
+
answer = answer[0]
|
209 |
+
original_q = original_q[0]
|
210 |
+
|
211 |
+
index2ans = ast.literal_eval(index2ans[0])
|
212 |
+
all_choices = ast.literal_eval(all_choices[0])
|
213 |
+
|
214 |
+
this_prompt = prompt.format(question)
|
215 |
+
|
216 |
+
generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
|
217 |
+
print("generated_answer", generated_answer)
|
218 |
+
|
219 |
+
results.append({
|
220 |
+
"question_id": question_id.item(),
|
221 |
+
"question": original_q,
|
222 |
+
"answer": generated_answer,
|
223 |
+
"gt_answer": answer,
|
224 |
+
"index2ans": index2ans,
|
225 |
+
"all_choices": all_choices
|
226 |
+
})
|
227 |
+
print("idx:", i, results[-1])
|
228 |
+
|
229 |
+
|
230 |
+
elif task_name == "vizwiz":
|
231 |
+
for i, (img, question_id, question,) in enumerate(dataloader):
|
232 |
+
if img.dim() == 5:
|
233 |
+
img = img.squeeze(0)
|
234 |
+
|
235 |
+
question = question[0]
|
236 |
+
question_id = question_id[0]
|
237 |
+
|
238 |
+
this_prompt = prompt.format(question)
|
239 |
+
|
240 |
+
generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
|
241 |
+
print("generated_answer", generated_answer)
|
242 |
+
|
243 |
+
results.append({"image": question_id, "answer": generated_answer})
|
244 |
+
print("idx:", i, results[-1])
|
245 |
+
|
246 |
+
else:
|
247 |
+
raise NotImplementedError(f"Task {task_name} is not supported yet.")
|
248 |
+
|
249 |
+
os.makedirs(results_save_dir, exist_ok=True)
|
250 |
+
if args.subset is not None:
|
251 |
+
results_save_path = os.path.join(results_save_dir,
|
252 |
+
f"eval_{task_name}_subset_{args.subset}_start_{args.start_idx}.json")
|
253 |
+
else:
|
254 |
+
results_save_path = os.path.join(results_save_dir,
|
255 |
+
f"eval_{task_name}.json")
|
256 |
+
print("Saving to ", results_save_path)
|
257 |
+
json.dump(results, open(results_save_path, 'w'))
|
258 |
+
|
259 |
+
|
260 |
+
def calc_task_metrics(args, task_name, task_cfg, results_save_dir):
|
261 |
+
if args.subset is not None:
|
262 |
+
merged_results_path = os.path.join(results_save_dir,
|
263 |
+
f"eval_{task_name}_subset_{args.subset}_start_{args.start_idx}.json")
|
264 |
+
else:
|
265 |
+
merged_results_path = os.path.join(results_save_dir,
|
266 |
+
f"eval_{task_name}.json")
|
267 |
+
merged_results = json.load(open(merged_results_path, "r"))
|
268 |
+
|
269 |
+
if task_name == "coco_caption" or task_name == "flickr30k_caption":
|
270 |
+
# Calculate scores
|
271 |
+
coco = COCO(task_cfg["gt_path"])
|
272 |
+
coco_result = coco.loadRes(merged_results_path)
|
273 |
+
coco_eval = COCOEvalCap(coco, coco_result)
|
274 |
+
coco_eval.params["image_id"] = coco_result.getImgIds()
|
275 |
+
coco_eval.evaluate()
|
276 |
+
|
277 |
+
# Print and save scores
|
278 |
+
print(f"====== {task_name} scores ======")
|
279 |
+
with open(os.path.join(results_save_dir, "scores.txt"), "a") as f:
|
280 |
+
f.write(f"{task_name} scores:\n")
|
281 |
+
for k, v in coco_eval.eval.items():
|
282 |
+
msg = f"{k} = {v * 100:.3f}"
|
283 |
+
print(msg)
|
284 |
+
f.write(msg + "\n")
|
285 |
+
f.write("\n")
|
286 |
+
return coco_eval.eval["CIDEr"]
|
287 |
+
elif task_name == "vqav2" or task_name == "okvqa":
|
288 |
+
vqa_tool = VQAEval()
|
289 |
+
all_acc = []
|
290 |
+
for res in merged_results:
|
291 |
+
pred_ans = res["answer"]
|
292 |
+
pred_ans = vqa_tool.processPunctuation(pred_ans)
|
293 |
+
pred_ans = vqa_tool.processDigitArticle(pred_ans)
|
294 |
+
|
295 |
+
gt_ans = res["gt_answer"]
|
296 |
+
gt_ans = [vqa_tool.processPunctuation(ans) for ans in gt_ans]
|
297 |
+
gt_ans = [vqa_tool.processDigitArticle(ans) for ans in gt_ans]
|
298 |
+
|
299 |
+
num_match = sum([pred_ans == ans for ans in gt_ans])
|
300 |
+
acc = min(1.0, num_match / 3.0)
|
301 |
+
all_acc.append(acc)
|
302 |
+
acc_avg = sum(all_acc) / len(all_acc) * 100
|
303 |
+
print(f"===== {task_name} Accuracy {acc_avg:.2f}% =====")
|
304 |
+
with open(os.path.join(results_save_dir, "scores.txt"), "a") as f:
|
305 |
+
f.write(f"{task_name} Accuracy = {acc_avg:.2f}%\n\n")
|
306 |
+
return acc_avg
|
307 |
+
|
308 |
+
elif task_name == "textvqa":
|
309 |
+
vqa_tool = VQAEval()
|
310 |
+
all_acc = []
|
311 |
+
for res in merged_results:
|
312 |
+
pred_ans = res["answer"]
|
313 |
+
pred_ans = vqa_tool.processPunctuation(pred_ans)
|
314 |
+
pred_ans = vqa_tool.processDigitArticle(pred_ans)
|
315 |
+
|
316 |
+
gt_ans = res["gt_answer"]
|
317 |
+
gt_ans = [vqa_tool.processPunctuation(ans) for ans in gt_ans]
|
318 |
+
gt_ans = [vqa_tool.processDigitArticle(ans) for ans in gt_ans]
|
319 |
+
|
320 |
+
num_match = sum([pred_ans == ans for ans in gt_ans])
|
321 |
+
acc = min(1.0, num_match / 3.0)
|
322 |
+
all_acc.append(acc)
|
323 |
+
acc_avg = sum(all_acc) / len(all_acc) * 100
|
324 |
+
print(
|
325 |
+
f"===== {task_name} Accuracy {acc_avg:.2f}% (need to submit to EvalAI for the accurate accuracy) =====")
|
326 |
+
with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
|
327 |
+
f.write(
|
328 |
+
f"{task_name} Accuracy = {acc_avg:.2f}% (need to submit to EvalAI for the accurate accuracy)\n\n")
|
329 |
+
return acc_avg
|
330 |
+
|
331 |
+
elif task_name in ["gqa", "realworldqa", "ai2diagram", "ai2diagram_nomask"]:
|
332 |
+
vqa_tool = VQAEval()
|
333 |
+
acc = 0
|
334 |
+
for res in merged_results:
|
335 |
+
pred_ans = res["answer"]
|
336 |
+
pred_ans = vqa_tool.processPunctuation(pred_ans)
|
337 |
+
pred_ans = vqa_tool.processDigitArticle(pred_ans)
|
338 |
+
|
339 |
+
gt_ans = res["gt_answer"][0]
|
340 |
+
gt_ans = vqa_tool.processPunctuation(gt_ans)
|
341 |
+
gt_ans = vqa_tool.processDigitArticle(gt_ans)
|
342 |
+
|
343 |
+
if pred_ans == gt_ans:
|
344 |
+
acc += 1
|
345 |
+
acc = acc / len(merged_results) * 100
|
346 |
+
print(f"===== {task_name} Accuracy {acc:.2f}% =====")
|
347 |
+
with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
|
348 |
+
f.write(f"{task_name} Accuracy = {acc:.2f}%\n\n")
|
349 |
+
return acc
|
350 |
+
|
351 |
+
elif task_name == "docvqa" or task_name == "docvqa_test":
|
352 |
+
vqa_tool = VQAEval()
|
353 |
+
anls = 0
|
354 |
+
for res in merged_results:
|
355 |
+
pred_ans = res["answer"]
|
356 |
+
gt_ans = res["gt_answer"]
|
357 |
+
anls += get_anls_score(pred=pred_ans, gold_labels=gt_ans, threshold=0.5)
|
358 |
+
anls = anls / len(merged_results) * 100
|
359 |
+
print(f"===== {task_name} ANLS {anls:.2f}% =====")
|
360 |
+
with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
|
361 |
+
f.write(f" {task_name} ANLS = {anls:.2f}%\n\n")
|
362 |
+
return anls
|
363 |
+
|
364 |
+
elif task_name == "chartqa":
|
365 |
+
vqa_tool = VQAEval()
|
366 |
+
acc = 0
|
367 |
+
for res in merged_results:
|
368 |
+
pred_ans = res["answer"]
|
369 |
+
pred_ans = vqa_tool.processPunctuation(pred_ans)
|
370 |
+
pred_ans = vqa_tool.processDigitArticle(pred_ans)
|
371 |
+
|
372 |
+
gt_ans = res["gt_answer"][0]
|
373 |
+
gt_ans = vqa_tool.processPunctuation(gt_ans)
|
374 |
+
gt_ans = vqa_tool.processDigitArticle(gt_ans)
|
375 |
+
|
376 |
+
# ChartQA uses relaxed accuracy:
|
377 |
+
# "We consider an answer to be correct if it is within 5% of the gold answer.
|
378 |
+
# For non-numeric answers, we still need an exact match to consider an answer to be correct."
|
379 |
+
if isNumber(pred_ans) and isNumber(gt_ans):
|
380 |
+
pred_ans = float(pred_ans)
|
381 |
+
gt_ans = float(gt_ans)
|
382 |
+
if pred_ans >= (gt_ans * 0.95) and pred_ans <= (gt_ans * 1.05):
|
383 |
+
acc += 1
|
384 |
+
elif pred_ans == gt_ans:
|
385 |
+
acc += 1
|
386 |
+
acc = acc / len(merged_results) * 100
|
387 |
+
print(f"===== {task_name} Accuracy {acc:.2f}% =====")
|
388 |
+
with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
|
389 |
+
f.write(f"{task_name} Accuracy = {acc:.2f}%\n\n")
|
390 |
+
return acc
|
391 |
+
|
392 |
+
elif task_name == 'ocrbench':
|
393 |
+
|
394 |
+
from collections import defaultdict
|
395 |
+
OCRBench_score = {"Regular Text Recognition": 0, "Irregular Text Recognition": 0,
|
396 |
+
"Artistic Text Recognition": 0, "Handwriting Recognition": 0,
|
397 |
+
"Digit String Recognition": 0, "Non-Semantic Text Recognition": 0,
|
398 |
+
"Scene Text-centric VQA": 0, "Doc-oriented VQA": 0, "Doc-oriented VQA": 0,
|
399 |
+
"Key Information Extraction": 0, "Handwritten Mathematical Expression Recognition": 0}
|
400 |
+
|
401 |
+
for res in merged_results:
|
402 |
+
predict = res["answer"]
|
403 |
+
answers = res["gt_answer"]
|
404 |
+
|
405 |
+
dataset_name = res["dataset_name"]
|
406 |
+
ocr_type = res["type"]
|
407 |
+
|
408 |
+
# data[i]['result'] = 0
|
409 |
+
if dataset_name == "HME100k":
|
410 |
+
if type(answers) == list:
|
411 |
+
for j in range(len(answers)):
|
412 |
+
answer = answers[j].strip().replace("\n", " ").replace(" ", "")
|
413 |
+
predict = predict.strip().replace("\n", " ").replace(" ", "")
|
414 |
+
if answer in predict:
|
415 |
+
OCRBench_score[ocr_type] += 1
|
416 |
+
else:
|
417 |
+
answers = answers.strip().replace("\n", " ").replace(" ", "")
|
418 |
+
predict = predict.strip().replace("\n", " ").replace(" ", "")
|
419 |
+
if answers in predict:
|
420 |
+
OCRBench_score[ocr_type] += 1
|
421 |
+
else:
|
422 |
+
if type(answers) == list:
|
423 |
+
for j in range(len(answers)):
|
424 |
+
answer = answers[j].lower().strip().replace("\n", " ")
|
425 |
+
predict = predict.lower().strip().replace("\n", " ")
|
426 |
+
if answer in predict:
|
427 |
+
OCRBench_score[ocr_type] += 1
|
428 |
+
else:
|
429 |
+
answers = answers.lower().strip().replace("\n", " ")
|
430 |
+
predict = predict.lower().strip().replace("\n", " ")
|
431 |
+
if answers in predict:
|
432 |
+
OCRBench_score[ocr_type] += 1
|
433 |
+
|
434 |
+
recognition_score = OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition'] + \
|
435 |
+
OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition'] + \
|
436 |
+
OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition']
|
437 |
+
Final_score = recognition_score + OCRBench_score['Scene Text-centric VQA'] + OCRBench_score[
|
438 |
+
'Doc-oriented VQA'] + OCRBench_score['Key Information Extraction'] + OCRBench_score[
|
439 |
+
'Handwritten Mathematical Expression Recognition']
|
440 |
+
result_log = f"\n###########################OCRBench##############################\n\
|
441 |
+
Text Recognition(Total 300):{recognition_score}\n\
|
442 |
+
------------------Details of Recognition Score-------------------\n\
|
443 |
+
Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}\n\
|
444 |
+
Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}\n\
|
445 |
+
Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}\n\
|
446 |
+
Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}\n\
|
447 |
+
Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}\n\
|
448 |
+
Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}\n\
|
449 |
+
----------------------------------------------------------------\n\
|
450 |
+
Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}\n\
|
451 |
+
----------------------------------------------------------------\n\
|
452 |
+
Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}\n\
|
453 |
+
----------------------------------------------------------------\n\
|
454 |
+
Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}\n\
|
455 |
+
----------------------------------------------------------------\n\
|
456 |
+
Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}\n\
|
457 |
+
----------------------Final Score-------------------------------\n\
|
458 |
+
Final Score(Total 1000): {Final_score}\n"
|
459 |
+
|
460 |
+
print(f"===== {task_name} Final_score(Total 1000): {Final_score:.2f}% ===== {result_log}")
|
461 |
+
with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
|
462 |
+
f.write(f" {task_name} Accuracy = {Final_score:.2f}% {result_log}\n\n")
|
463 |
+
return Final_score
|
464 |
+
|
465 |
+
|
466 |
+
elif task_name == "mathvista":
|
467 |
+
import re
|
468 |
+
def extra_processing(text):
|
469 |
+
|
470 |
+
## max decimal point capped to 2 decimal point
|
471 |
+
regex = re.compile(r'^\d+\.\d+$')
|
472 |
+
decimal = regex.findall(text)
|
473 |
+
|
474 |
+
if len(decimal) > 0:
|
475 |
+
non_decimal = len(decimal[0].split(".")[0])
|
476 |
+
|
477 |
+
# if decimal values are all 0, trim them
|
478 |
+
|
479 |
+
decimal_digits = [int(d) for d in decimal[0].split(".")[1]]
|
480 |
+
if sum(decimal_digits) == 0:
|
481 |
+
text = decimal[0][:non_decimal]
|
482 |
+
else:
|
483 |
+
text = decimal[0][:non_decimal + 3]
|
484 |
+
|
485 |
+
## remove %
|
486 |
+
text = text.replace("%", "")
|
487 |
+
try:
|
488 |
+
if text[-1] == ".":
|
489 |
+
text = text[:-1]
|
490 |
+
except:
|
491 |
+
print(text)
|
492 |
+
|
493 |
+
return text
|
494 |
+
|
495 |
+
def extract_answer(text):
|
496 |
+
|
497 |
+
alphabet = re.findall(r'[a-zA-Z]+', text)
|
498 |
+
if len(alphabet) > 0 and "e+" not in text:
|
499 |
+
|
500 |
+
template_1 = re.findall(r'answer is -*\d+\.*\d*', text)
|
501 |
+
if len(template_1) > 0:
|
502 |
+
text = template_1[0]
|
503 |
+
|
504 |
+
numbers = re.findall(r'-*\d+\.*\d*', text)
|
505 |
+
text = numbers[0] if len(numbers) > 0 else text
|
506 |
+
|
507 |
+
return text
|
508 |
+
|
509 |
+
vqa_tool = VQAEval()
|
510 |
+
acc = 0
|
511 |
+
for res in merged_results:
|
512 |
+
pred_ans = res["answer"]
|
513 |
+
if res['question_type'] == 'multi_choice':
|
514 |
+
pred_ans = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
|
515 |
+
else:
|
516 |
+
pred_ans = vqa_tool.processPunctuation(pred_ans)
|
517 |
+
pred_ans = vqa_tool.processDigitArticle(pred_ans)
|
518 |
+
|
519 |
+
pred_ans = extra_processing(pred_ans)
|
520 |
+
pred_ans = extract_answer(pred_ans)
|
521 |
+
|
522 |
+
gt_ans = res["gt_answer"]
|
523 |
+
if res['question_type'] != 'multi_choice':
|
524 |
+
gt_ans = vqa_tool.processPunctuation(gt_ans)
|
525 |
+
gt_ans = vqa_tool.processDigitArticle(gt_ans)
|
526 |
+
gt_ans = extra_processing(gt_ans)
|
527 |
+
|
528 |
+
if pred_ans == gt_ans:
|
529 |
+
acc += 1
|
530 |
+
acc = acc / len(merged_results) * 100
|
531 |
+
print(f"===== {task_name} Accuracy {acc:.2f}% =====")
|
532 |
+
with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
|
533 |
+
f.write(f" {task_name} Accuracy = {acc:.2f}%\n\n")
|
534 |
+
return acc
|
535 |
+
elif task_name == "mmbench":
|
536 |
+
if task_cfg['split'] == 'dev':
|
537 |
+
acc = 0
|
538 |
+
for res in merged_results:
|
539 |
+
gt_ans = res["gt_answer"]
|
540 |
+
pred_ans = res["answer"]
|
541 |
+
pred_ans = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
|
542 |
+
|
543 |
+
if pred_ans == gt_ans:
|
544 |
+
acc += 1
|
545 |
+
acc = acc / len(merged_results) * 100
|
546 |
+
print(f"===== {task_name} Accuracy {acc:.2f}% =====")
|
547 |
+
with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
|
548 |
+
f.write(f"{task_name} Accuracy = {acc:.2f}%\n\n")
|
549 |
+
|
550 |
+
return acc
|
551 |
+
|
552 |
+
# Generate submission.xlsx for test set online evaluation, "https://mmbench.opencompass.org.cn/mmbench-submission"
|
553 |
+
if task_cfg['submission']:
|
554 |
+
res_df = pd.DataFrame(
|
555 |
+
{
|
556 |
+
"index": [r['question_id'] for r in merged_results],
|
557 |
+
"question": [r['question'] for r in merged_results],
|
558 |
+
"A": [r['index2ans']['A'] if 'A' in r['index2ans'] else None for r in merged_results],
|
559 |
+
"B": [r['index2ans']['B'] if 'B' in r['index2ans'] else None for r in merged_results],
|
560 |
+
"C": [r['index2ans']['C'] if 'C' in r['index2ans'] else None for r in merged_results],
|
561 |
+
"D": [r['index2ans']['D'] if 'D' in r['index2ans'] else None for r in merged_results],
|
562 |
+
"prediction": [parse_multi_choice_response(r['answer'], r['all_choices'], r['index2ans']) for r in
|
563 |
+
merged_results],
|
564 |
+
},
|
565 |
+
columns=['index', 'question', 'A', 'B', 'C', 'D', 'prediction']
|
566 |
+
)
|
567 |
+
res_df.to_excel(os.path.join(results_save_dir, "submission.xlsx"))
|
568 |
+
|
569 |
+
|
570 |
+
elif task_name == "vizwiz":
|
571 |
+
print(
|
572 |
+
f"VizWiz result file is saved at: {merged_results_path}\n"
|
573 |
+
f"Upload manually or use this CLI `python evaluation/upload_vizwiz.py --result {merged_results_path} --token <your_eval_user_token>`."
|
574 |
+
)
|
575 |
+
|
576 |
+
elif task_name == "mmmu" or task_name == 'mmmu_pro':
|
577 |
+
def extract_answer(text):
|
578 |
+
import re
|
579 |
+
# Regular expression to find content inside \answer{xxx}
|
580 |
+
match = re.search(r'\\answer\{(.*?)\}', text)
|
581 |
+
if match:
|
582 |
+
return match.group(1) # Return the content inside the braces
|
583 |
+
return text # Return the original string if no match is found
|
584 |
+
|
585 |
+
eval_samples = []
|
586 |
+
eval_output_dict = {}
|
587 |
+
for res in merged_results:
|
588 |
+
pred_ans = res["answer"]
|
589 |
+
gt_ans = res['gt_answer']
|
590 |
+
if res['question_type'] == 'multiple-choice':
|
591 |
+
parsed_pred = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
|
592 |
+
eval_samples.append(
|
593 |
+
{
|
594 |
+
'id': res['question_id'],
|
595 |
+
'question_type': res['question_type'],
|
596 |
+
'answer': res['gt_answer'], # the content in option, not answer index.
|
597 |
+
'response': pred_ans,
|
598 |
+
'parsed_pred': parsed_pred,
|
599 |
+
'index2ans': res['index2ans'],
|
600 |
+
}
|
601 |
+
)
|
602 |
+
eval_output_dict[res['question_id']] = parsed_pred
|
603 |
+
else:
|
604 |
+
pred_ans = extract_answer(pred_ans) # for sft v9prov3, we observe answers are within \answer{xxx}
|
605 |
+
parsed_pred = parse_open_response(pred_ans)
|
606 |
+
eval_samples.append(
|
607 |
+
{
|
608 |
+
'id': res['question_id'],
|
609 |
+
'question_type': res['question_type'],
|
610 |
+
'answer': res['gt_answer'],
|
611 |
+
'response': pred_ans,
|
612 |
+
'parsed_pred': parsed_pred,
|
613 |
+
}
|
614 |
+
)
|
615 |
+
eval_output_dict[res['question_id']] = pred_ans
|
616 |
+
if args.subset is not None:
|
617 |
+
eval_output_dict_path = os.path.join(results_save_dir,
|
618 |
+
f"eval_output_dict_{task_name}_subset_{args.subset}_start_{args.start_idx}.json")
|
619 |
+
else:
|
620 |
+
eval_output_dict_path = os.path.join(results_save_dir, f"eval_output_dict_{task_name}.json")
|
621 |
+
json.dump(eval_output_dict, open(eval_output_dict_path, "w"), indent=4, sort_keys=True)
|
622 |
+
|
623 |
+
mmmu_results = mmmu_main_eval(eval_output_dict, task_cfg)
|
624 |
+
with open(os.path.join(results_save_dir, "scores.txt"), "a") as f:
|
625 |
+
f.write(f"{task_name} {task_cfg['split']}:\n")
|
626 |
+
for cat, cat_val in mmmu_results.items():
|
627 |
+
if 'Overall' in cat:
|
628 |
+
cat = cat.replace("Overall-", "")
|
629 |
+
print(f'{cat}: {cat_val["acc"] * 100:.2f}')
|
630 |
+
f.write(f'{cat}: {cat_val["acc"] * 100:.2f}\n')
|
631 |
+
return mmmu_results['Overall']['acc']
|
632 |
+
|
633 |
+
else:
|
634 |
+
raise NotImplementedError(f"Task {task_name} is not supported yet.")
|
635 |
+
|
636 |
+
|
637 |
+
def evaluate(args, model, tokenizer, tasks):
|
638 |
+
start_time = time.time()
|
639 |
+
|
640 |
+
results_save_dir = os.path.join(args.result_save_path)
|
641 |
+
if not os.path.exists(results_save_dir):
|
642 |
+
os.makedirs(results_save_dir)
|
643 |
+
|
644 |
+
task_names = list(tasks.keys())
|
645 |
+
|
646 |
+
print(f"Evaluating tasks: {task_names}")
|
647 |
+
|
648 |
+
with torch.no_grad():
|
649 |
+
# for task_name in task_names:
|
650 |
+
for task_name in args.zero_shot_eval_tasks:
|
651 |
+
task_cfg = tasks[task_name]
|
652 |
+
print(f"Preparing dataloader for {task_name}...")
|
653 |
+
|
654 |
+
dataloader = get_task_dataloader(task_name, task_cfg, args)
|
655 |
+
|
656 |
+
generation_config = dict(max_new_tokens=1024, do_sample=False)
|
657 |
+
|
658 |
+
print("Start generating...")
|
659 |
+
generate_task_results(task_name, task_cfg, args, model, tokenizer, generation_config,
|
660 |
+
dataloader, results_save_dir)
|
661 |
+
|
662 |
+
print("Start calculating task metric...")
|
663 |
+
calc_task_metrics(args, task_name, task_cfg, results_save_dir)
|
664 |
+
|
665 |
+
end_time = time.time()
|
666 |
+
|
667 |
+
print(f"Evaluation takes {(end_time - start_time) / 60:.1f} minutes in total!")
|
668 |
+
|
669 |
+
|
670 |
+
if __name__ == '__main__':
|
671 |
+
path = "nvidia/NVLM-D-72B"
|
672 |
+
print("Loading model... from", path)
|
673 |
+
device_map = split_model()
|
674 |
+
|
675 |
+
start = time.time()
|
676 |
+
model = AutoModel.from_pretrained(
|
677 |
+
path,
|
678 |
+
torch_dtype=torch.bfloat16,
|
679 |
+
low_cpu_mem_usage=True,
|
680 |
+
use_flash_attn=True,
|
681 |
+
device_map=device_map,
|
682 |
+
trust_remote_code=True).eval()
|
683 |
+
end = time.time()
|
684 |
+
print("loading model takes:", end - start)
|
685 |
+
|
686 |
+
print(model)
|
687 |
+
|
688 |
+
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
|
689 |
+
|
690 |
+
parser = argparse.ArgumentParser()
|
691 |
+
parser.add_argument('--config-path', type=str, required=True,
|
692 |
+
help='YAML file configuring evaluation datasets')
|
693 |
+
parser.add_argument('--result-save-path', type=str, default=os.path.join(path, '/eval_results'))
|
694 |
+
parser.add_argument('--zero-shot-eval-tasks', nargs='+', type=str, default=['mmmu'])
|
695 |
+
parser.add_argument('--start-idx', type=int, default=0)
|
696 |
+
parser.add_argument('--subset', type=int, default=None)
|
697 |
+
|
698 |
+
args = parser.parse_args()
|
699 |
+
|
700 |
+
tasks = yaml.safe_load(open(args.config_path))['datasets']
|
701 |
+
|
702 |
+
evaluate(args, model, tokenizer, tasks)
|