boxinw@nvidia.com committed
Commit b925209 · 1 Parent(s): 07a7f16

Add benchmark evaluation scripts

README.md CHANGED
@@ -318,6 +318,22 @@ response = model.chat(tokenizer, pixel_values, question, generation_config)
  print(f'User: {question}\nAssistant: {response}')
  ```

+ ### Benchmark Evaluation
+
+ To evaluate our NVLM-1.0 model on the benchmark datasets, you can use the following command:
+
+ ```bash
+ python run_eval.py --config-path eval/full_eval.yaml \
+     --result-save-path path/to/eval_results/ \
+     --zero-shot-eval-tasks chartqa coco_caption flickr30k_caption vqav2 mmmu textvqa mathvista mmbench chartqa docvqa realworldqa ocrbench ai2diagram ai2diagram_nomask mmmu_pro docvqa_test
+ ```
+
+ Specifically,
+ - `--config-path eval/full_eval.yaml` points to the evaluation configuration file, which contains the evaluation prompts, the evaluation dataset paths, and the generation hyper-parameters.
+ - `--result-save-path path/to/eval_results/` specifies the directory in which to save the evaluation results.
+ - `--zero-shot-eval-tasks` specifies the tasks to evaluate on.
+
+
  ## Software Integration
  **Runtime Engine(s)**
  * PyTorch <br>
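For reference, the task names passed to `--zero-shot-eval-tasks` match the keys under `datasets:` in `eval/full_eval.yaml`, which is added below. A minimal sketch (not part of this commit) for listing them, assuming PyYAML is installed and the command is run from the repository root:

```python
# Illustrative only: print the task names defined in eval/full_eval.yaml.
import yaml

with open("eval/full_eval.yaml") as f:
    cfg = yaml.safe_load(f)

print(sorted(cfg["datasets"].keys()))
```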
eval/conversation.py ADDED
@@ -0,0 +1,492 @@
1
+ # From https://github.com/haotian-liu/LLaVA/blob/main/llava/conversation.py
2
+
3
+ import dataclasses
4
+ from enum import auto, Enum
5
+ from typing import List, Tuple
6
+
7
+
8
+ class SeparatorStyle(Enum):
9
+ """Different separator style."""
10
+ SINGLE = auto()
11
+ TWO = auto()
12
+ MPT = auto()
13
+ PLAIN = auto()
14
+ LLAMA_2 = auto()
15
+
16
+
17
+ @dataclasses.dataclass
18
+ class Conversation:
19
+ """A class that keeps all conversation history."""
20
+ system: str
21
+ roles: List[str]
22
+ messages: List[List[str]]
23
+ offset: int
24
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
25
+ sep: str = "###"
26
+ sep2: str = None
27
+ real_sep2: str = None
28
+ version: str = "Unknown"
29
+
30
+ skip_next: bool = False
31
+
32
+ def get_prompt(self):
33
+ messages = self.messages
34
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
35
+ messages = self.messages.copy()
36
+ init_role, init_msg = messages[0].copy()
37
+ init_msg = init_msg[0].replace("<image>", "").strip()
38
+ if 'mmtag' in self.version:
39
+ messages[0] = (init_role, init_msg)
40
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
41
+ messages.insert(1, (self.roles[1], "Received."))
42
+ else:
43
+ messages[0] = (init_role, "<image>\n" + init_msg)
44
+
45
+ if self.sep_style == SeparatorStyle.SINGLE:
46
+ ret = self.system + self.sep
47
+ for role, message in messages:
48
+ if message:
49
+ if type(message) is tuple:
50
+ message, _, _ = message
51
+ ret += role + ": " + message + self.sep
52
+ else:
53
+ ret += role + ":"
54
+ elif self.sep_style == SeparatorStyle.TWO:
55
+ seps = [self.sep, self.sep2]
56
+ ret = self.system + seps[0]
57
+ for i, (role, message) in enumerate(messages):
58
+ if message:
59
+ if type(message) is tuple:
60
+ message, _, _ = message
61
+ ret += role + ": " + message + seps[i % 2]
62
+ else:
63
+ ret += role + ":"
64
+ elif self.sep_style == SeparatorStyle.MPT:
65
+ ret = self.system + self.sep
66
+ for role, message in messages:
67
+ if message:
68
+ if type(message) is tuple:
69
+ message, _, _ = message
70
+ ret += role + message + self.sep
71
+ else:
72
+ ret += role
73
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
74
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
75
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
76
+ ret = ""
77
+
78
+ for i, (role, message) in enumerate(messages):
79
+ if i == 0:
80
+ assert message, "first message should not be none"
81
+ assert role == self.roles[0], "first message should come from user"
82
+ if message:
83
+ if type(message) is tuple:
84
+ message, _, _ = message
85
+ if i == 0: message = wrap_sys(self.system) + message
86
+ if i % 2 == 0:
87
+ message = wrap_inst(message)
88
+ ret += self.sep + message
89
+ else:
90
+ ret += " " + message + " " + self.sep2
91
+ else:
92
+ ret += ""
93
+ ret = ret.lstrip(self.sep)
94
+ elif self.sep_style == SeparatorStyle.PLAIN:
95
+ seps = [self.sep, self.sep2]
96
+ ret = self.system
97
+ for i, (role, message) in enumerate(messages):
98
+ if message:
99
+ if type(message) is tuple:
100
+ message, _, _ = message
101
+ ret += message + seps[i % 2]
102
+ else:
103
+ ret += ""
104
+ else:
105
+ raise ValueError(f"Invalid style: {self.sep_style}")
106
+
107
+ return ret
108
+
109
+ def append_message(self, role, message):
110
+ self.messages.append([role, message])
111
+
112
+ def get_images(self, return_pil=False):
113
+ images = []
114
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
115
+ if i % 2 == 0:
116
+ if type(msg) is tuple:
117
+ import base64
118
+ from io import BytesIO
119
+ from PIL import Image
120
+ msg, image, image_process_mode = msg
121
+ if image_process_mode == "Pad":
122
+ def expand2square(pil_img, background_color=(122, 116, 104)):
123
+ width, height = pil_img.size
124
+ if width == height:
125
+ return pil_img
126
+ elif width > height:
127
+ result = Image.new(pil_img.mode, (width, width), background_color)
128
+ result.paste(pil_img, (0, (width - height) // 2))
129
+ return result
130
+ else:
131
+ result = Image.new(pil_img.mode, (height, height), background_color)
132
+ result.paste(pil_img, ((height - width) // 2, 0))
133
+ return result
134
+ image = expand2square(image)
135
+ elif image_process_mode in ["Default", "Crop"]:
136
+ pass
137
+ elif image_process_mode == "Resize":
138
+ image = image.resize((336, 336))
139
+ else:
140
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
141
+ max_hw, min_hw = max(image.size), min(image.size)
142
+ aspect_ratio = max_hw / min_hw
143
+ max_len, min_len = 800, 400
144
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
145
+ longest_edge = int(shortest_edge * aspect_ratio)
146
+ W, H = image.size
147
+ if longest_edge != max(image.size):
148
+ if H > W:
149
+ H, W = longest_edge, shortest_edge
150
+ else:
151
+ H, W = shortest_edge, longest_edge
152
+ image = image.resize((W, H))
153
+ if return_pil:
154
+ images.append(image)
155
+ else:
156
+ buffered = BytesIO()
157
+ image.save(buffered, format="PNG")
158
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
159
+ images.append(img_b64_str)
160
+ return images
161
+
162
+ def to_gradio_chatbot(self):
163
+ ret = []
164
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
165
+ if i % 2 == 0:
166
+ if type(msg) is tuple:
167
+ import base64
168
+ from io import BytesIO
169
+ msg, image, image_process_mode = msg
170
+ max_hw, min_hw = max(image.size), min(image.size)
171
+ aspect_ratio = max_hw / min_hw
172
+ max_len, min_len = 800, 400
173
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
174
+ longest_edge = int(shortest_edge * aspect_ratio)
175
+ W, H = image.size
176
+ if H > W:
177
+ H, W = longest_edge, shortest_edge
178
+ else:
179
+ H, W = shortest_edge, longest_edge
180
+ image = image.resize((W, H))
181
+ buffered = BytesIO()
182
+ image.save(buffered, format="JPEG")
183
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
184
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
185
+ msg = img_str + msg.replace('<image>', '').strip()
186
+ ret.append([msg, None])
187
+ else:
188
+ ret.append([msg, None])
189
+ else:
190
+ ret[-1][-1] = msg
191
+ return ret
192
+
193
+ def copy(self):
194
+ return Conversation(
195
+ system=self.system,
196
+ roles=self.roles,
197
+ messages=[[x, y] for x, y in self.messages],
198
+ offset=self.offset,
199
+ sep_style=self.sep_style,
200
+ sep=self.sep,
201
+ sep2=self.sep2,
202
+ real_sep2=self.real_sep2,
203
+ version=self.version)
204
+
205
+ def dict(self):
206
+ if len(self.get_images()) > 0:
207
+ return {
208
+ "system": self.system,
209
+ "roles": self.roles,
210
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
211
+ "offset": self.offset,
212
+ "sep": self.sep,
213
+ "sep2": self.sep2,
214
+ "real_sep2": self.real_sep2
215
+ }
216
+ return {
217
+ "system": self.system,
218
+ "roles": self.roles,
219
+ "messages": self.messages,
220
+ "offset": self.offset,
221
+ "sep": self.sep,
222
+ "sep2": self.sep2,
223
+ "real_sep2": self.real_sep2
224
+ }
225
+
226
+
227
+ conv_vicuna_v0 = Conversation(
228
+ system="A chat between a curious human and an artificial intelligence assistant. "
229
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
230
+ roles=("Human", "Assistant"),
231
+ messages=(
232
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
233
+ ("Assistant",
234
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
235
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
236
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
237
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
238
+ "renewable and non-renewable energy sources:\n"
239
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
240
+ "energy sources are finite and will eventually run out.\n"
241
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
242
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
243
+ "and other negative effects.\n"
244
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
245
+ "have lower operational costs than non-renewable sources.\n"
246
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
247
+ "locations than non-renewable sources.\n"
248
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
249
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
250
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
251
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
252
+ ),
253
+ offset=2,
254
+ sep_style=SeparatorStyle.SINGLE,
255
+ sep="###",
256
+ )
257
+
258
+ ### Used for llava-instruction-tuning stage
259
+ conv_vicuna_v1 = Conversation(
260
+ system="A chat between a curious user and an artificial intelligence assistant. "
261
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
262
+ roles=("USER", "ASSISTANT"),
263
+ version="v1",
264
+ messages=(),
265
+ offset=0,
266
+ sep_style=SeparatorStyle.TWO,
267
+ sep=" ",
268
+ sep2="</s>",
269
+ )
270
+
271
+ conv_llama_2 = Conversation(
272
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
273
+
274
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
275
+ roles=("USER", "ASSISTANT"),
276
+ version="llama_v2",
277
+ messages=(),
278
+ offset=0,
279
+ sep_style=SeparatorStyle.LLAMA_2,
280
+ sep="<s>",
281
+ sep2="</s>",
282
+ )
283
+
284
+ conv_llava_llama_2 = Conversation(
285
+ system="You are a helpful language and vision assistant. "
286
+ "You are able to understand the visual content that the user provides, "
287
+ "and assist the user with a variety of tasks using natural language.",
288
+ roles=("USER", "ASSISTANT"),
289
+ version="llama_v2",
290
+ messages=(),
291
+ offset=0,
292
+ sep_style=SeparatorStyle.LLAMA_2,
293
+ sep="<s>",
294
+ sep2="</s>",
295
+ )
296
+
297
+ conv_mpt = Conversation(
298
+ system="""<|im_start|>system
299
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
300
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
301
+ version="mpt",
302
+ messages=(),
303
+ offset=0,
304
+ sep_style=SeparatorStyle.MPT,
305
+ sep="<|im_end|>",
306
+ )
307
+
308
+
309
+
310
+ ### Used for llava-pretraining
311
+ conv_llava_plain = Conversation(
312
+ system="",
313
+ roles=("", ""),
314
+ messages=(
315
+ ),
316
+ offset=0,
317
+ sep_style=SeparatorStyle.PLAIN,
318
+ sep="\n",
319
+ )
320
+
321
+ conv_llava_v0 = Conversation(
322
+ system="A chat between a curious human and an artificial intelligence assistant. "
323
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
324
+ roles=("Human", "Assistant"),
325
+ messages=(
326
+ ),
327
+ offset=0,
328
+ sep_style=SeparatorStyle.SINGLE,
329
+ sep="###",
330
+ )
331
+
332
+ conv_llava_v0_mmtag = Conversation(
333
+ system="A chat between a curious user and an artificial intelligence assistant. "
334
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
335
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
336
+ roles=("Human", "Assistant"),
337
+ messages=(
338
+ ),
339
+ offset=0,
340
+ sep_style=SeparatorStyle.SINGLE,
341
+ sep="###",
342
+ version="v0_mmtag",
343
+ )
344
+
345
+ conv_llava_v1 = Conversation(
346
+ system="A chat between a curious human and an artificial intelligence assistant. "
347
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
348
+ roles=("USER", "ASSISTANT"),
349
+ version="v1",
350
+ messages=(),
351
+ offset=0,
352
+ sep_style=SeparatorStyle.TWO,
353
+ sep=" ",
354
+ sep2="</s>",
355
+ )
356
+
357
+ conv_llava_v1_mmtag = Conversation(
358
+ system="A chat between a curious user and an artificial intelligence assistant. "
359
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
360
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
361
+ roles=("USER", "ASSISTANT"),
362
+ messages=(),
363
+ offset=0,
364
+ sep_style=SeparatorStyle.TWO,
365
+ sep=" ",
366
+ sep2="</s>",
367
+ version="v1_mmtag",
368
+ )
369
+
370
+
371
+ nvllm_8b_pretrain = Conversation(
372
+ system="",
373
+ roles=(),
374
+ version="nvllm_8b",
375
+ messages=(),
376
+ offset=0,
377
+ sep_style=SeparatorStyle.SINGLE,
378
+ sep="\n",
379
+ )
380
+
381
+ nvllm_8b_sft = Conversation(
382
+ system="System: This is a chat between a user and an artificial intelligence assistant. "
383
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
384
+ roles=("User", "Assistant"),
385
+ version="nvllm_8b",
386
+ messages=(),
387
+ offset=0,
388
+ sep_style=SeparatorStyle.TWO,
389
+ sep="\n\n",
390
+ sep2="\n\n\n",
391
+ real_sep2="\n\n"
392
+ )
393
+
394
+ chatqa_sft = Conversation(
395
+ system="System: This is a chat between a user and an artificial intelligence assistant. "
396
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
397
+ roles=("User", "Assistant"),
398
+ version="chatqa",
399
+ messages=(),
400
+ offset=0,
401
+ sep_style=SeparatorStyle.TWO,
402
+ sep="\n\n",
403
+ sep2="\n\n",
404
+ real_sep2="\n\n"
405
+ )
406
+
407
+ nvllm_8b_sft_noinstruction = Conversation(
408
+ system="",
409
+ roles=("User", "Assistant"),
410
+ version="nvllm_8b",
411
+ messages=(),
412
+ offset=0,
413
+ sep_style=SeparatorStyle.TWO,
414
+ sep="\n\n",
415
+ sep2="\n\n\n",
416
+ real_sep2="\n\n"
417
+ )
418
+
419
+ # conv_yi = Conversation(
420
+ # system="""<|im_start|>system
421
+ # A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
422
+ # roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
423
+ # version="mpt",
424
+ # messages=(),
425
+ # offset=0,
426
+ # sep_style=SeparatorStyle.MPT,
427
+ # sep="<|im_end|>",
428
+ # )
429
+
430
+ conv_chatml = Conversation(
431
+ system="""<|im_start|>system
432
+ Answer the questions.""",
433
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
434
+ version="mpt",
435
+ messages=(),
436
+ offset=0,
437
+ sep_style=SeparatorStyle.MPT,
438
+ sep="<|im_end|>",
439
+ )
440
+
441
+ llama3_instruct = Conversation(
442
+ system="<|start_header_id|>system<|end_header_id|>\n\nAnswer the questions.",
443
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
444
+ version="mpt",
445
+ messages=(),
446
+ offset=0,
447
+ sep_style=SeparatorStyle.MPT,
448
+ sep="<|eot_id|>",
449
+ )
450
+
451
+ llama3_1_instruct = Conversation(
452
+ system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nAnswer the questions.",
453
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
454
+ version="mpt",
455
+ messages=(),
456
+ offset=0,
457
+ sep_style=SeparatorStyle.MPT,
458
+ sep="<|eot_id|>",
459
+ )
460
+
461
+ # default_conversation = conv_vicuna_v0
462
+ default_conversation = nvllm_8b_sft
463
+
464
+ # original_llava_pretraining = conv_llava_plain
465
+ # original_llava_sft = conv_vicuna_v1
466
+
467
+ conv_templates = {
468
+ "default": conv_vicuna_v0,
469
+ "v0": conv_vicuna_v0,
470
+ "v1": conv_vicuna_v1,
471
+ "vicuna_v1": conv_vicuna_v1,
472
+ "llama_2": conv_llama_2,
473
+
474
+ "plain": conv_llava_plain,
475
+ "v0_plain": conv_llava_plain,
476
+ "llava_v0": conv_llava_v0,
477
+ "v0_mmtag": conv_llava_v0_mmtag,
478
+ "llava_v1": conv_llava_v1,
479
+ "v1_mmtag": conv_llava_v1_mmtag,
480
+ "llava_llama_2": conv_llava_llama_2,
481
+
482
+ "mpt": conv_mpt,
483
+ }
484
+
485
+
486
+ if __name__ == "__main__":
487
+
488
+ print(default_conversation)
489
+ print(default_conversation.roles[0])
490
+
491
+
492
+ # print(default_conversation.get_prompt())
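`eval/conversation.py` is adapted from LLaVA; its `default_conversation` is set to the `nvllm_8b_sft` template. A minimal usage sketch (illustrative, not part of the committed file; it assumes the `eval` package is importable from the repository root):

```python
# Illustrative only: build a chat prompt with the default nvllm_8b_sft template.
from eval import conversation as conversation_lib

conv = conversation_lib.default_conversation.copy()  # nvllm_8b_sft, SeparatorStyle.TWO
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")  # "User" turn
conv.append_message(conv.roles[1], None)  # leave the "Assistant" turn open
prompt = conv.get_prompt()
print(prompt)  # system text, then "User: ...", ending with "Assistant:"
```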
eval/eval_dataset.py ADDED
@@ -0,0 +1,850 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ import time
5
+ import yaml
6
+ import spacy
7
+ import ast
8
+ from PIL import Image
9
+ from glob import glob
10
+ from tqdm import tqdm
11
+ from collections import defaultdict
12
+ import pandas as pd
13
+ from io import BytesIO
14
+ import base64
15
+ from anls import anls_score
16
+ import torch
17
+ from torch.utils.data import Dataset, DataLoader, DistributedSampler
18
+ import torchvision.transforms as T
19
+ from eval import conversation as conversation_lib
20
+ from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \
21
+ process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro
22
+ from eval.mmmu_utils import evaluate as evaluate_mmmu
23
+ from torchvision.transforms.functional import InterpolationMode
24
+ from datasets import load_dataset, concatenate_datasets
25
+
26
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
27
+ IMAGENET_STD = (0.229, 0.224, 0.225)
28
+
29
+
30
+ def build_transform(input_size):
31
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
32
+ transform = T.Compose([
33
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
34
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
35
+ T.ToTensor(),
36
+ T.Normalize(mean=MEAN, std=STD)
37
+ ])
38
+ return transform
39
+
40
+
41
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
42
+ best_ratio_diff = float('inf')
43
+ best_ratio = (1, 1)
44
+ area = width * height
45
+ for ratio in target_ratios:
46
+ target_aspect_ratio = ratio[0] / ratio[1]
47
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
48
+ if ratio_diff < best_ratio_diff:
49
+ best_ratio_diff = ratio_diff
50
+ best_ratio = ratio
51
+ elif ratio_diff == best_ratio_diff:
52
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
53
+ best_ratio = ratio
54
+ return best_ratio
55
+
56
+
57
+ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
58
+ orig_width, orig_height = image.size
59
+ aspect_ratio = orig_width / orig_height
60
+
61
+ # calculate the existing image aspect ratio
62
+ target_ratios = set(
63
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
64
+ i * j <= max_num and i * j >= min_num)
65
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
66
+
67
+ # find the closest aspect ratio to the target
68
+ target_aspect_ratio = find_closest_aspect_ratio(
69
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
70
+
71
+ # calculate the target width and height
72
+ target_width = image_size * target_aspect_ratio[0]
73
+ target_height = image_size * target_aspect_ratio[1]
74
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
75
+
76
+ # resize the image
77
+ resized_img = image.resize((target_width, target_height))
78
+ processed_images = []
79
+ for i in range(blocks):
80
+ box = (
81
+ (i % (target_width // image_size)) * image_size,
82
+ (i // (target_width // image_size)) * image_size,
83
+ ((i % (target_width // image_size)) + 1) * image_size,
84
+ ((i // (target_width // image_size)) + 1) * image_size
85
+ )
86
+ # split the image
87
+ split_img = resized_img.crop(box)
88
+ processed_images.append(split_img)
89
+ assert len(processed_images) == blocks
90
+ if use_thumbnail and len(processed_images) != 1:
91
+ thumbnail_img = image.resize((image_size, image_size))
92
+ processed_images.append(thumbnail_img)
93
+ return processed_images
94
+
95
+
96
+ def load_image(image, input_size=448, max_num=6, decoded=False):
97
+ if not decoded:
98
+ image = Image.open(image).convert('RGB')
99
+ transform = build_transform(input_size=input_size)
100
+ images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
101
+ pixel_values = [transform(image) for image in images]
102
+ pixel_values = torch.stack(pixel_values)
103
+ return pixel_values
104
+
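`load_image` above performs the dynamic tiling used by the dataset classes below: the image is resized to the closest supported aspect ratio, cut into up to `max_num` tiles of `input_size` x `input_size`, a thumbnail of the whole image is appended when more than one tile is produced, and all tiles are normalized with ImageNet statistics. A small sketch of calling it (illustrative only; `sample.jpg` is a placeholder path):

```python
# Illustrative only: tile an image and inspect the resulting tensor shape.
pixel_values = load_image("sample.jpg", input_size=448, max_num=6)
# (num_tiles, 3, 448, 448); num_tiles <= max_num + 1 because of the thumbnail tile.
print(pixel_values.shape)
```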
105
+
106
+ def levenshtein_distance(s1, s2):
107
+ if len(s1) > len(s2):
108
+ s1, s2 = s2, s1
109
+
110
+ distances = range(len(s1) + 1)
111
+ for i2, c2 in enumerate(s2):
112
+ distances_ = [i2 + 1]
113
+ for i1, c1 in enumerate(s1):
114
+ if c1 == c2:
115
+ distances_.append(distances[i1])
116
+ else:
117
+ distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
118
+ distances = distances_
119
+ return distances[-1]
120
+
121
+
122
+ def get_anls_score(pred, gold_labels, threshold, llava_eval=False):
123
+ values = []
124
+ for answer in gold_labels:
125
+ # preprocess both the answers - gt and prediction
126
+ gt_answer = ' '.join(answer.strip().lower().split())
127
+ det_answer = ' '.join(pred.strip().lower().split())
128
+
129
+ dist = levenshtein_distance(gt_answer, det_answer)
130
+ length = max(len(answer.upper()), len(pred.upper()))
131
+ values.append(0.0 if length == 0 else float(dist) / float(length))
132
+
133
+ question_result = 1 - min(values)
134
+
135
+ if llava_eval:
136
+ question_result = 1.0 if question_result >= threshold else 0.0
137
+ else:
138
+ if (question_result < threshold):
139
+ question_result = 0
140
+
141
+ return question_result
142
+
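`get_anls_score` above computes ANLS (Average Normalized Levenshtein Similarity), used for DocVQA-style answers: the lowest normalized edit distance against any gold answer is turned into a similarity, which is zeroed out when it falls below the threshold. A small worked example (illustrative only; the threshold is supplied by the caller elsewhere in the evaluation code):

```python
# Illustrative only: ANLS for one prediction against two gold answers.
score = get_anls_score(pred="nvidia", gold_labels=["NVIDIA Corp.", "nvidia"], threshold=0.5)
print(score)  # 1.0 -- the second gold answer matches exactly after normalization
```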
143
+
144
+ def isNumber(n: str):
145
+ try:
146
+ float(n)
147
+ return True
148
+ except ValueError:
149
+ return False
150
+
151
+
152
+ class COCOEvalDataset(Dataset):
153
+ def __init__(self, args, img_dir, subset=None):
154
+ self.args = args
155
+ self.img_files = sorted(glob(os.path.join(img_dir, "*")))
156
+
157
+ if subset:
158
+ self.img_files = self.img_files[:subset]
159
+
160
+ self.image_ids = [int(img_file.split("_")[-1].split(".")[0]) for img_file in self.img_files]
161
+
162
+ def __len__(self):
163
+ return len(self.img_files)
164
+
165
+ def __getitem__(self, idx):
166
+ img_path = self.img_files[idx]
167
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
168
+
169
+ return self.image_ids[idx], img
170
+
171
+
172
+ class Flickr30KEvalDataset(Dataset):
173
+ def __init__(self, args, img_dir, subset=None):
174
+ self.args = args
175
+ self.img_dir = img_dir
176
+ self.test_samples = json.load(open(os.path.join(img_dir, "flickr30k_test.json"), encoding='utf-8'))
177
+
178
+ if subset:
179
+ self.test_samples = self.test_samples[:subset]
180
+
181
+ def __len__(self):
182
+ return len(self.test_samples)
183
+
184
+ def __getitem__(self, idx):
185
+ img_path = os.path.join(self.img_dir, self.test_samples[idx]["image"])
186
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
187
+
188
+ image_id = int(self.test_samples[idx]["image"].split("/")[-1].replace(".jpg", ""))
189
+
190
+ return image_id, img
191
+
192
+
193
+ class VQAv2EvalDataset(Dataset):
194
+ def __init__(self, args, img_dir, gt_path, subset=None):
195
+ self.args = args
196
+ self.img_dir = img_dir
197
+ self.gt = json.load(open(gt_path, encoding='utf-8'))
198
+
199
+ if subset:
200
+ self.gt = self.gt[:subset]
201
+
202
+ def __len__(self):
203
+ return len(self.gt)
204
+
205
+ def __getitem__(self, idx):
206
+ img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
207
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
208
+
209
+ question_id = self.gt[idx]["question_id"]
210
+ question = self.gt[idx]["question"]
211
+ answer = self.gt[idx]["answer"]
212
+
213
+ return img, question_id, question, answer
214
+
215
+
216
+ class TextVQAEvalDataset(Dataset):
217
+ def __init__(self, args, img_dir, gt_path, subset=None):
218
+ self.args = args
219
+ self.img_dir = img_dir
220
+ self.gt = json.load(open(gt_path, encoding='utf-8'))['data']
221
+
222
+ if subset:
223
+ self.gt = self.gt[:subset]
224
+
225
+ def __len__(self):
226
+ return len(self.gt)
227
+
228
+ def __getitem__(self, idx):
229
+ img_path = os.path.join(self.img_dir, self.gt[idx]["image_id"] + '.jpg')
230
+ if not os.path.exists(img_path):
231
+ img_path = img_path.replace('.jpg', '.png')
232
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
233
+
234
+ question_id = self.gt[idx]["question_id"]
235
+ question = self.gt[idx]["question"]
236
+ answer = self.gt[idx]["answers"]
237
+
238
+ return img, question_id, question, answer
239
+
240
+
241
+ class GQAEvalDataset(Dataset):
242
+ def __init__(self, args, img_dir, gt_path, subset=None):
243
+ self.args = args
244
+ self.img_dir = img_dir
245
+ self.gt = json.load(open(gt_path, encoding='utf-8'))
246
+ self.gt = [{
247
+ "question_id": int(k),
248
+ "image": v['imageId'] + ".jpg",
249
+ "question": v['question'],
250
+ "answer": v['answer']
251
+ } for k, v in self.gt.items()]
252
+
253
+ if subset:
254
+ self.gt = self.gt[:subset]
255
+
256
+ def __len__(self):
257
+ return len(self.gt)
258
+
259
+ def __getitem__(self, idx):
260
+ img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
261
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
262
+
263
+ question_id = self.gt[idx]["question_id"]
264
+ question = self.gt[idx]["question"]
265
+ answer = self.gt[idx]["answer"]
266
+
267
+ return img, question_id, question, [answer]
268
+
269
+
270
+ class ChartQAEvalDataset(Dataset):
271
+ def __init__(self, args, img_dir, gt_path, subset=None):
272
+ self.args = args
273
+ self.img_dir = img_dir
274
+ self.gt = json.load(open(gt_path, encoding='utf-8'))
275
+ for i in range(len(self.gt)):
276
+ self.gt[i]['question_id'] = i
277
+
278
+ if subset:
279
+ self.gt = self.gt[:subset]
280
+
281
+ def __len__(self):
282
+ return len(self.gt)
283
+
284
+ def __getitem__(self, idx):
285
+ img_path = os.path.join(self.img_dir, self.gt[idx]["imgname"])
286
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
287
+
288
+ question_id = self.gt[idx]["question_id"]
289
+ question = self.gt[idx]["query"]
290
+ answer = self.gt[idx]["label"]
291
+
292
+ return img, question_id, question, [answer]
293
+
294
+
295
+ class OKVQAEvalDataset(Dataset):
296
+ def __init__(self, args, img_dir, gt_path, question_path, subset=None):
297
+ self.args = args
298
+ self.img_dir = img_dir
299
+ self.gt = json.load(open(gt_path, encoding='utf-8'))['annotations']
300
+ self.questions = json.load(open(question_path, 'r'))['questions']
301
+
302
+ if subset:
303
+ self.gt = self.gt[:subset]
304
+
305
+ qid2q = {q['question_id']: q['question'] for q in self.questions}
306
+
307
+ for ann in self.gt:
308
+ ann['answers'] = [ans['answer'] for ans in ann['answers']]
309
+ ann['question'] = qid2q[ann['question_id']]
310
+
311
+ def __len__(self):
312
+ return len(self.gt)
313
+
314
+ def __getitem__(self, idx):
315
+ img_id = str(self.gt[idx]["image_id"])
316
+ img_id = '0' * (12 - len(img_id)) + img_id
317
+ img_file_name = f"COCO_val2014_{img_id}.jpg"
318
+ img_path = os.path.join(self.img_dir, img_file_name)
319
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
320
+
321
+ question_id = self.gt[idx]["question_id"]
322
+ question = self.gt[idx]["question"]
323
+ answer = self.gt[idx]["answers"]
324
+
325
+ return img, question_id, question, answer
326
+
327
+
328
+ class DocVQAEvalDataset(Dataset):
329
+ def __init__(self, args, img_dir, gt_path, split='val', subset=None):
330
+ self.args = args
331
+ self.img_dir = img_dir
332
+ self.gt = json.load(open(gt_path, encoding='utf-8'))['data']
333
+
334
+ if subset:
335
+ self.gt = self.gt[:subset]
336
+
337
+ self.split = split
338
+
339
+ def __len__(self):
340
+ return len(self.gt)
341
+
342
+ def __getitem__(self, idx):
343
+ img_path = os.path.join(self.img_dir, self.gt[idx]['image'].split('/')[-1])
344
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
345
+
346
+ question_id = self.gt[idx]["questionId"]
347
+ question = self.gt[idx]["question"]
348
+
349
+ if self.split == 'val':
350
+ answer = self.gt[idx]["answers"]
351
+ else:
352
+ answer = ['']
353
+
354
+ return img, question_id, question, answer
355
+
356
+
357
+ class OCRBenchEvalDataset(Dataset):
358
+ def __init__(self, args, img_dir, gt_path, subset=None):
359
+ self.args = args
360
+ self.img_dir = img_dir
361
+ self.gt = json.load(open(gt_path, encoding='utf-8'))
362
+
363
+ if subset:
364
+ self.gt = self.gt[:subset]
365
+
366
+ def __len__(self):
367
+ return len(self.gt)
368
+
369
+ def __getitem__(self, idx):
370
+ img_path = os.path.join(self.img_dir, self.gt[idx]['image_path'])
371
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
372
+
373
+ dataset_name = self.gt[idx]["dataset_name"]
374
+ question_id = f"{idx}"
375
+ question = self.gt[idx]["question"]
376
+ answer = self.gt[idx]["answers"]
377
+ data_type = self.gt[idx]["type"]
378
+
379
+ return img, question_id, question, answer, dataset_name, data_type
380
+
381
+
382
+ class AI2DiagramEvalDataset(Dataset):
383
+ def __init__(self, args, img_dir, gt_path, subset=None):
384
+ self.args = args
385
+ self.img_dir = img_dir
386
+
387
+ with open(gt_path, 'r') as json_file:
388
+ json_list = list(json_file)
389
+ self.gt = [json.loads(json_str) for json_str in json_list]
390
+
391
+ if subset:
392
+ self.gt = self.gt[:subset]
393
+
394
+ def __len__(self):
395
+ return len(self.gt)
396
+
397
+ def __getitem__(self, idx):
398
+ img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
399
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
400
+
401
+ question_id = self.gt[idx]["question_id"]
402
+ question = self.gt[idx]["question"]
403
+ answer = self.gt[idx]["answer"]
404
+
405
+ return img, question_id, question, answer
406
+
407
+
408
+ class AI2DiagramNoMaskEvalDataset(Dataset):
409
+ def __init__(self, args, img_dir, gt_path, subset=None):
410
+ self.args = args
411
+ self.img_dir = img_dir
412
+
413
+ with open(gt_path, 'r') as json_file:
414
+ json_list = list(json_file)
415
+ self.gt = [json.loads(json_str) for json_str in json_list]
416
+
417
+ if subset:
418
+ self.gt = self.gt[:subset]
419
+
420
+ def __len__(self):
421
+ return len(self.gt)
422
+
423
+ def __getitem__(self, idx):
424
+ img_file_name = self.gt[idx]['image'].replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES")
425
+ img_path = os.path.join(self.img_dir, img_file_name)
426
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
427
+
428
+ question_id = self.gt[idx]["question_id"]
429
+ question = self.gt[idx]["question"]
430
+ answer = self.gt[idx]["answer"]
431
+
432
+ return img, question_id, question, answer
433
+
434
+
435
+ class RealworldQAEvalDataset(Dataset):
436
+ def __init__(self, args, img_dir, gt_path, subset=None):
437
+ self.args = args
438
+ self.img_dir = img_dir
439
+ self.gt = json.load(open(gt_path, encoding='utf-8'))
440
+
441
+ if subset:
442
+ self.gt = self.gt[:subset]
443
+
444
+ def __len__(self):
445
+ return len(self.gt)
446
+
447
+ def __getitem__(self, idx):
448
+ img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
449
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
450
+
451
+ question_id = int(self.gt[idx]['image'].replace(".webp", ""))
452
+ question = self.gt[idx]["question"]
453
+
454
+ if self.gt[idx]['question_type'] == "multi-choice":
455
+ choices = self.gt[idx]["choices"]
456
+ start_chr = 'A'
457
+ choices_str = ''
458
+ index2ans = {}
459
+ all_choices = []
460
+ for choice in choices:
461
+ all_choices.append(start_chr)
462
+ index2ans[start_chr] = choice
463
+ choices_str += f"{start_chr}. {choice}\n"
464
+ start_chr = chr(ord(start_chr) + 1)
465
+
466
+ question = question + '\n' + choices_str
467
+ question = question + "Answer with the option's letter from the given choices directly."
468
+ answer = chr(ord('A') + self.gt[idx]['correct_choice_index'])
469
+ else:
470
+ question = question + "\nAnswer the question using a single word or phrase."
471
+ answer = self.gt[idx]['answer']
472
+
473
+ return img, question_id, question, [answer]
474
+
475
+
476
+ class MathVistaEvalDataset(Dataset):
477
+ def __init__(self, args, task_cfg, gt_path=None):
478
+ self.args = args
479
+ self.task_cfg = task_cfg
480
+ self.dataset = load_dataset("AI4Math/MathVista")['testmini']
481
+
482
+ def __len__(self):
483
+ return len(self.dataset)
484
+
485
+ def __getitem__(self, idx):
486
+ img = self.dataset[idx]['decoded_image']
487
+ img = load_image(img.convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
488
+
489
+ question_id = self.dataset[idx]["pid"]
490
+ question = self.dataset[idx]["question"]
491
+ question_type = self.dataset[idx]["question_type"] # free_form or multi_choice
492
+ query = self.dataset[idx]["query"]
493
+ choices = self.dataset[idx]["choices"]
494
+ answer = self.dataset[idx]["answer"]
495
+
496
+ if question_type == 'multi_choice':
497
+ start_chr = 'A'
498
+ choices_str = ''
499
+ index2ans = {}
500
+ all_choices = []
501
+ for choice in choices:
502
+ all_choices.append(start_chr)
503
+ index2ans[start_chr] = choice
504
+ choices_str += f"{start_chr}. {choice}\n"
505
+ start_chr = chr(ord(start_chr) + 1)
506
+
507
+ question = question + '\n' + choices_str
508
+ question = question + "Answer with the option's letter from the given choices directly."
509
+ answer = chr(ord('A') + choices.index(answer))
510
+ else:
511
+ question = query.replace("Hint: ", "")
512
+ index2ans = {}
513
+ all_choices = []
514
+
515
+ return img, question_id, question_type, question, answer, str(index2ans), str(all_choices)
516
+
517
+
518
+ def construct_prompt_for_fewshot(sample):
519
+ config = {
520
+ "task_instructions": "",
521
+ "multi_choice_example_format": "{}\n{}Answer with the option's letter from the given choices directly.",
522
+ "short_ans_example_format": "{}\nAnswer the question using a single word or phrase."
523
+ }
524
+
525
+ question = sample['question'].strip()
526
+
527
+
528
+ options = eval(sample['options'])
529
+ example = ""
530
+ if sample['question_type'] == 'multiple-choice':
531
+ start_chr = 'A'
532
+ prediction_range = []
533
+ index2ans = {}
534
+ for option in options:
535
+ prediction_range.append(start_chr)
536
+ example += f"({start_chr}) {option}\n"
537
+ index2ans[start_chr] = option
538
+ start_chr = chr(ord(start_chr) + 1)
539
+ empty_prompt_sample_structure = config['multi_choice_example_format']
540
+ empty_prompt = empty_prompt_sample_structure.format(question, example)
541
+ res_dict = {'type': 'multichoice'}
542
+ res_dict['index2ans'] = index2ans
543
+ res_dict['correct_choice'] = sample['answer']
544
+ res_dict['all_choices'] = prediction_range
545
+ res_dict['empty_prompt'] = empty_prompt
546
+ if config['task_instructions']:
547
+ res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
548
+ else:
549
+ res_dict['final_input_prompt'] = empty_prompt
550
+
551
+ res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
552
+ else:
553
+ empty_prompt_sample_structure = config['short_ans_example_format']
554
+ empty_prompt = empty_prompt_sample_structure.format(question)
555
+ res_dict = {'type': 'open'}
556
+ res_dict['empty_prompt'] = empty_prompt
557
+ if config['task_instructions']:
558
+ res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
559
+ else:
560
+ res_dict['final_input_prompt'] = empty_prompt
561
+ res_dict['gt_content'] = sample['answer']
562
+
563
+ res_dict.update(sample)
564
+ return res_dict
565
+
566
+
567
+ def process_image_tag(q):
568
+ q = q.strip()
569
+
570
+ # heuristic way of removing <image 1>
571
+ if q == '<image 1>':
572
+ q = 'Answer the question in the image.'
573
+ elif ':<image 1>' in q:
574
+ q = q.replace(':<image 1>', ' in the image. ')
575
+ q = q.strip()
576
+ elif ': <image 1>' in q:
577
+ q = q.replace(': <image 1>', ' in the image. ')
578
+ q = q.strip()
579
+ elif '.<image 1>' in q or '. <image 1>' in q:
580
+ q_list = q.split('<image 1>')
581
+ q_list = [part.strip() for part in q_list if part.strip() != '']
582
+ q = ' '.join(q_list)
583
+ elif q.startswith('<image 1> '):
584
+ if q[10].isupper():
585
+ q = q.replace('<image 1>', '')
586
+ else:
587
+ q = q.replace('<image 1>', 'The image')
588
+ q = q.strip()
589
+ elif q.startswith('<image 1>'):
590
+ q = q.replace('<image 1>', '')
591
+ elif q.endswith('<image 1>?'):
592
+ q = q.replace('<image 1>', 'the image')
593
+ elif q.endswith('?<image 1>') or q.endswith('? <image 1>') or q.endswith('\n<image 1>'):
594
+ q = q.replace('<image 1>', '')
595
+ q = q.strip()
596
+ elif ' <image 1> ' in q:
597
+ q = q.replace('<image 1>', 'the image')
598
+ elif ' <image 1>' in q:
599
+ q = q.replace('<image 1>', 'the image')
600
+ elif '()<image 1>' in q:
601
+ q = q.replace('()<image 1>', '')
602
+ elif '(<image 1>)' in q:
603
+ q = q.replace('(<image 1>)', '')
604
+ elif '<image 1>.' in q:
605
+ q = q.replace("<image 1>.", ". ")
606
+ else:
607
+ q = q.replace("<image 1>", ". ")
608
+ q = q.strip()
609
+
610
+ # remove <image 2> to <image 8>
611
+ for i in range(2, 8):
612
+ q = q.replace(f"<image {i}>", "")
613
+
614
+ return q
615
+
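`process_image_tag` above rewrites or strips the `<image 1>` placeholders found in MMMU questions so that the remaining text reads naturally once the image token is prepended separately. Two illustrative calls (not part of the committed file):

```python
# Illustrative only: how MMMU-style placeholders are normalized.
print(process_image_tag("<image 1> What is the shape shown?"))
# -> "What is the shape shown?"
print(process_image_tag("What does <image 1> depict?"))
# -> "What does the image depict?"
```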
616
+
617
+ class MMMUProEvalDataset(Dataset):
618
+ def __init__(self, args, task_cfg, subset=None):
619
+ self.args = args
620
+ self.task_cfg = task_cfg
621
+ sub_dataset_list = []
622
+ # load_dataset will throw error if split is 'dev'
623
+ # 'dev' is part of the 'validation' and we need to manually split them
624
+
625
+ MMMU_path = "MMMU/MMMU_Pro"
626
+
627
+ _split = "test"
628
+
629
+ self.dataset = load_dataset(MMMU_path, "standard", split=_split)
630
+ if subset:
631
+ self.dataset = self.dataset[:subset]
632
+
633
+ def __len__(self):
634
+ return len(self.dataset)
635
+
636
+ def __getitem__(self, idx):
637
+ # ===== single-image =====
638
+ sample = self.dataset[idx]
639
+ sample = process_single_sample_pro(sample)
640
+ sample = construct_prompt_pro(sample, self.task_cfg)
641
+ img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
642
+
643
+ # img = img.reshape(-1, 3, self.args.img_h, self.args.img_w)
644
+
645
+ question_id = sample['id']
646
+ question = sample['final_input_prompt']
647
+ answer = sample['answer']
648
+
649
+ question = process_image_tag(question)
650
+ question = self.task_cfg['default_image_token'] + '\n' + question
651
+
652
+ if sample['question_type'] == 'multiple-choice':
653
+ index2ans = sample['index2ans']
654
+ all_choices = sample['all_choices']
655
+ else:
656
+ index2ans = {}
657
+ all_choices = []
658
+
659
+ return img, question_id, sample['subfield'], sample['question_type'], question, answer, str(index2ans), str \
660
+ (all_choices)
661
+
662
+
663
+ class MMMUEvalDataset(Dataset):
664
+ def __init__(self, args, task_cfg, subset=None, start_idx=None):
665
+ self.args = args
666
+ self.task_cfg = task_cfg
667
+ sub_dataset_list = []
668
+ # load_dataset will throw error if split is 'dev'
669
+ # 'dev' is part of the 'validation' and we need to manually split them
670
+
671
+ MMMU_path = "MMMU/MMMU"
672
+
673
+ _split = "test" if task_cfg["split"] == "test" else "validation"
674
+ for subject in CAT_SHORT2LONG.values():
675
+ sub_dataset = load_dataset(
676
+ MMMU_path, subject,
677
+ split=_split,
678
+ )
679
+ sub_dataset_list.append(sub_dataset)
680
+
681
+ dataset = concatenate_datasets(sub_dataset_list)
682
+
683
+ if task_cfg["split"] != "test":
684
+ dataset = [s for s in dataset if s['id'].startswith(task_cfg["split"])]
685
+
686
+ # dataset = [s for s in dataset if s['image_2'] is not None][1:]
687
+
688
+ self.dataset = dataset
689
+
690
+ if subset:
691
+ self.dataset = [dataset[i] for i in range(start_idx, min(start_idx + subset, len(dataset)))]
692
+ print(f"Evaluating a subset of dataset: {len(self.dataset)} from {start_idx} to {start_idx + subset}")
693
+
694
+ def __len__(self):
695
+ return len(self.dataset)
696
+
697
+ def __getitem__(self, idx):
698
+ # ===== single-image =====
699
+ sample = self.dataset[idx]
700
+ sample = process_single_sample(sample)
701
+ sample = construct_prompt(sample, self.task_cfg)
702
+
703
+ img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
704
+
705
+ question_id = sample['id']
706
+ question = sample['final_input_prompt']
707
+ answer = sample['answer']
708
+
709
+ question = process_image_tag(question)
710
+ question = self.task_cfg['default_image_token'] + '\n' + question
711
+
712
+
713
+ if sample['question_type'] == 'multiple-choice':
714
+ index2ans = sample['index2ans']
715
+ all_choices = sample['all_choices']
716
+ else:
717
+ index2ans = {}
718
+ all_choices = []
719
+
720
+ return img, question_id, sample['subfield'], sample['question_type'], question, answer, str(index2ans), str \
721
+ (all_choices)
722
+
723
+
724
+
725
+ class VizWizEvalDataset(Dataset):
726
+ def __init__(self, args, img_dir, question_path, subset=None):
727
+ self.args = args
728
+ self.img_dir = img_dir
729
+ self.questions = json.load(open(question_path, encoding='utf-8'))
730
+
731
+ def __len__(self):
732
+ return len(self.questions)
733
+
734
+ def __getitem__(self, idx):
735
+ img_path = os.path.join(self.img_dir, self.questions[idx]["image"])
736
+ img = load_image(img_path, max_num=6).to(torch.bfloat16)
737
+ question = self.questions[idx]["question"]
738
+ question_id = self.questions[idx]["image"]
739
+
740
+ return img, question_id, question
741
+
742
+
743
+ class MMBenchEvalDataset(Dataset):
744
+ def __init__(self, args, gt_path, subset=None):
745
+ self.args = args
746
+ df = pd.read_csv(gt_path, sep='\t')
747
+ self.dataset = []
748
+ for i, row in df.iterrows():
749
+ choices = []
750
+ for choice in ['A', 'B', 'C', 'D']:
751
+ if str(row[choice]) != 'nan':
752
+ choices.append(row[choice])
753
+
754
+ this_sample = {
755
+ 'index': row['index'],
756
+ 'question': row['question'],
757
+ 'hint': row['hint'],
758
+ 'category': row['category'],
759
+ 'image': Image.open(BytesIO(base64.b64decode(row['image']))),
760
+ 'choices': choices
761
+ }
762
+
763
+ # Only dev set gives the ground truth answer
764
+ if 'answer' in row.keys():
765
+ this_sample['answer'] = row['answer']
766
+ else:
767
+ this_sample['answer'] = ''
768
+
769
+ self.dataset.append(this_sample)
770
+
771
+ def __len__(self):
772
+ return len(self.dataset)
773
+
774
+ def __getitem__(self, idx):
775
+ img = load_image(self.dataset[idx]["image"].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
776
+
777
+ question = self.dataset[idx]["question"]
778
+ hint = self.dataset[idx]["hint"]
779
+ question_id = self.dataset[idx]["index"]
780
+ choices = self.dataset[idx]["choices"]
781
+ answer = self.dataset[idx]["answer"]
782
+
783
+ start_chr = 'A'
784
+ choices_str = ''
785
+ index2ans = {}
786
+ all_choices = []
787
+ for choice in choices:
788
+ all_choices.append(start_chr)
789
+ index2ans[start_chr] = choice
790
+ choices_str += f"{start_chr}. {choice}\n"
791
+ start_chr = chr(ord(start_chr) + 1)
792
+
793
+ question = question + '\n' + choices_str
794
+
795
+ return img, question_id, question, answer, str(index2ans), str(all_choices), self.dataset[idx]["question"]
796
+
797
+
798
+ def get_task_dataloader(task_name, task_cfg, args):
799
+ if "subset" in task_cfg.keys():
800
+ subset = task_cfg["subset"]
801
+ else:
802
+ subset = None
803
+
804
+ if task_name == "coco_caption":
805
+ dataset = COCOEvalDataset(args, task_cfg["image_dir"], subset)
806
+ elif task_name == "flickr30k_caption":
807
+ dataset = Flickr30KEvalDataset(args, task_cfg["image_dir"], subset)
808
+ elif task_name == "vqav2":
809
+ dataset = VQAv2EvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
810
+ elif task_name == "textvqa":
811
+ dataset = TextVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
812
+ elif task_name == "gqa":
813
+ dataset = GQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
814
+ elif task_name == "chartqa":
815
+ dataset = ChartQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
816
+ elif task_name == "okvqa":
817
+ dataset = OKVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], task_cfg["question_path"], subset)
818
+ elif task_name == "vizwiz":
819
+ dataset = VizWizEvalDataset(args, task_cfg["image_dir"], task_cfg["question_path"], subset)
820
+ elif task_name == "docvqa":
821
+ dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='val', subset=subset)
822
+ elif task_name == "docvqa_test":
823
+ dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='test', subset=subset)
824
+ elif task_name == "realworldqa":
825
+ dataset = RealworldQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
826
+ elif task_name == "mmmu":
827
+ dataset = MMMUEvalDataset(args, task_cfg, subset=args.subset, start_idx=args.start_idx)
828
+ elif task_name == "mmmu_pro":
829
+ dataset = MMMUProEvalDataset(args, task_cfg)
830
+ elif task_name == "mathvista":
831
+ dataset = MathVistaEvalDataset(args, task_cfg)
832
+ elif task_name == "mmbench":
833
+ dataset = MMBenchEvalDataset(args, task_cfg["gt_path"])
834
+ elif task_name == 'ocrbench':
835
+ dataset = OCRBenchEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
836
+ elif task_name == 'ai2diagram':
837
+ dataset = AI2DiagramEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
838
+ elif task_name == 'ai2diagram_nomask':
839
+ dataset = AI2DiagramNoMaskEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
840
+ else:
841
+ raise NotImplementedError(f"Task {task_name} is not supported yet.")
842
+
843
+ dataloader = DataLoader(
844
+ dataset,
845
+ batch_size=1,
846
+ shuffle=False,
847
+ pin_memory=True,
848
+ )
849
+
850
+ return dataloader
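`get_task_dataloader` above is the single entry point for obtaining a batch-size-1 DataLoader for any supported task. A rough sketch of driving it directly (illustrative only; the actual driver is `run_eval.py`, which is not part of this file, and the `image_dir`/`gt_path` placeholders in `eval/full_eval.yaml` must first be pointed at local data):

```python
# Illustrative only: build the ChartQA dataloader straight from the YAML config.
import yaml
from eval.eval_dataset import get_task_dataloader

with open("eval/full_eval.yaml") as f:
    task_cfg = yaml.safe_load(f)["datasets"]["chartqa"]

# args is only consulted by a few tasks (e.g. mmmu); the ChartQA dataset ignores it.
dataloader = get_task_dataloader("chartqa", task_cfg, args=None)
for img, question_id, question, answers in dataloader:
    pass  # feed img and question to the model here
```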
eval/full_eval.yaml ADDED
@@ -0,0 +1,188 @@
1
+ datasets:
2
+ coco_caption:
3
+ image_dir: "path/to/image"
4
+ gt_path: "path/to/ground_truth"
5
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\nGive a brief description of this image in one sentence.<|im_end|><|im_start|>assistant\n"
6
+ beam_search: True
7
+ beam_size: 1
8
+ output_max_len: 30
9
+ top_k: 3
10
+ temperature: 1.0
11
+
12
+ flickr30k_caption:
13
+ image_dir: "path/to/image"
14
+ gt_path: "path/to/ground_truth"
15
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\nGive a brief description of this image in one sentence.<|im_end|><|im_start|>assistant\n"
16
+ beam_search: True
17
+ beam_size: 1
18
+ output_max_len: 30
19
+ top_k: 3
20
+ temperature: 1.0
21
+
22
+ vqav2:
23
+ image_dir: "path/to/image"
24
+ gt_path: "path/to/ground_truth"
25
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word or phrase.<|im_end|><|im_start|>assistant\n"
26
+ beam_search: True
27
+ beam_size: 1
28
+ top_k: 1
29
+ top_p: 0.0
30
+ output_max_len: 8
31
+ temperature: 1.0
32
+
33
+ mmmu:
34
+ split: "validation"
35
+ beam_search: True
36
+ beam_size: 1
37
+ top_k: 1
38
+ top_p: 0.0
39
+ output_max_len: 1024
40
+ temperature: 1.0
41
+ apply_lemmatizer: False
42
+ task_instructions: ""
43
+ multi_choice_example_format: "{}\n{}\nAnswer with the option's letter from the given choices directly."
44
+ short_ans_example_format: "{}\nAnswer the question using a single word or phrase."
45
+ use_chat_format: True
46
+ conv_format: "yi_nous_sft"
47
+ default_image_token: "<image>"
48
+ prompt_offset: 4
49
+ answer_dict: "path/to/answer_dict_val.json"
50
+
51
+ textvqa:
52
+ split: "val"
53
+ image_dir: "path/to/image"
54
+ gt_path: "path/to/ground_truth"
55
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
56
+ beam_search: True
57
+ beam_size: 1
58
+ top_k: 1
59
+ top_p: 0.0
60
+ output_max_len: 10
61
+ temperature: 1.0
62
+
63
+ mathvista:
64
+ split: "testmini"
65
+ prompt: "<|im_start|>system\nYou are math expert. Use your math knowledge to calculate the answer.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
66
+ beam_search: True
67
+ beam_size: 1
68
+ top_k: 1
69
+ top_p: 0.0
70
+ output_max_len: 1024
71
+ temperature: 1.0
72
+
73
+ mmbench:
74
+ split: "dev"
75
+ gt_path: "path/to/ground_truth"
76
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}Answer with the option's letter from the given choices directly.<|im_end|><|im_start|>assistant\n"
77
+ beam_search: True
78
+ beam_size: 1
79
+ top_k: 1
80
+ top_p: 0.0
81
+ output_max_len: 10
82
+ temperature: 1.0
83
+ submission: False
84
+
85
+ chartqa:
86
+ split: "test"
87
+ image_dir: "path/to/image"
88
+ gt_path: "path/to/ground_truth"
89
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
90
+
91
+ beam_search: True
92
+ beam_size: 1
93
+ top_k: 1
94
+ top_p: 0.0
95
+ output_max_len: 20
96
+ temperature: 1.0
97
+
98
+ docvqa:
99
+ split: "val"
100
+ image_dir: "path/to/image"
101
+ gt_path: "path/to/ground_truth"
102
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
103
+ beam_search: True
104
+ beam_size: 1
105
+ top_k: 1
106
+ top_p: 0.0
107
+ output_max_len: 20
108
+ temperature: 1.0
109
+
110
+ realworldqa:
111
+ split: "test"
112
+ image_dir: "path/to/image"
113
+ gt_path: "path/to/ground_truth"
114
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
115
+ beam_search: True
116
+ beam_size: 1
117
+ top_k: 1
118
+ top_p: 0.0
119
+ output_max_len: 20
120
+ temperature: 1.0
121
+ submission: False
122
+
123
+ ocrbench:
124
+ split: "test"
125
+ image_dir: "path/to/image"
126
+ gt_path: "path/to/ground_truth"
127
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"
128
+ beam_search: True
129
+ beam_size: 1
130
+ top_k: 1
131
+ top_p: 0.0
132
+ output_max_len: 70
133
+ temperature: 1.0
134
+ submission: False
135
+
136
+ ai2diagram:
137
+ split: "test"
138
+ image_dir: "path/to/image"
139
+ gt_path: "path/to/ground_truth"
140
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
141
+ beam_search: True
142
+ beam_size: 1
143
+ top_k: 1
144
+ top_p: 0.0
145
+ output_max_len: 20
146
+ temperature: 1.0
147
+
148
+ ai2diagram_nomask:
149
+ split: "test"
150
+ image_dir: "path/to/image"
151
+ gt_path: "path/to/ground_truth"
152
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|><|im_start|>user\n<image>\n{}\nAnswer the question using a single word, phrase, or number.<|im_end|><|im_start|>assistant\n"
153
+ beam_search: True
154
+ beam_size: 1
155
+ top_k: 1
156
+ top_p: 0.0
157
+ output_max_len: 20
158
+ temperature: 1.0
159
+
160
+ mmmu_pro:
161
+ split: "validation"
162
+ beam_search: True
163
+ beam_size: 1
164
+ top_k: 1
165
+ top_p: 0.0
166
+ output_max_len: 10
167
+ temperature: 1.0
168
+ apply_lemmatizer: False
169
+ task_instructions: ""
170
+ multi_choice_example_format: "{}\n{}\nAnswer with the option's letter from the given choices directly."
171
+ short_ans_example_format: "{}\nAnswer the question using a single word or phrase."
172
+ use_chat_format: True
173
+ conv_format: "yi_nous_sft"
174
+ default_image_token: "<image>"
175
+ prompt_offset: 4
176
+ answer_dict: "path/to/answer_dict.json"
177
+
178
+ docvqa_test:
179
+ split: "test"
180
+ image_dir: "path/to/image"
181
+ gt_path: "path/to/ground_truth"
182
+ prompt: "<|im_start|>system\nFollow the user's instruction and answer questions.<|im_end|>\n<|im_start|>user\n<image>\n{}\nAnswer this question using the text in the image directly.<|im_end|>\n<|im_start|>assistant\n"
183
+ beam_search: True
184
+ beam_size: 1
185
+ top_k: 1
186
+ top_p: 0.0
187
+ output_max_len: 20
188
+ temperature: 1.0
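
Each task entry above pairs a chat-template prompt with greedy decoding settings (beam size 1, top-k 1, top-p 0.0) and a task-specific `output_max_len`. For reference, a minimal sketch of consuming one entry is shown below; it is not part of the committed files, the top-level `datasets` key mirrors the usage in `eval/mmmu_utils.py`, and the generation keyword names are illustrative assumptions rather than the exact mapping used inside `run_eval.py`.

```python
import yaml

with open("eval/full_eval.yaml") as f:
    tasks = yaml.safe_load(f)["datasets"]  # top-level key assumed, as in eval/mmmu_utils.py

cfg = tasks["docvqa"]
# Fill the task's prompt template with a question and collect its decoding settings.
prompt = cfg["prompt"].format("What is the invoice number?")
generation_config = dict(
    num_beams=cfg["beam_size"],        # beam_search: True with beam_size: 1 -> effectively greedy
    top_k=cfg["top_k"],
    top_p=cfg["top_p"],
    temperature=cfg["temperature"],
    max_new_tokens=cfg["output_max_len"],
)
print(prompt)
print(generation_config)
```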
eval/mmmu_utils.py ADDED
@@ -0,0 +1,663 @@
1
+ # Adapted from https://github.com/MMMU-Benchmark/MMMU/blob/main/mmmu/utils/data_utils.py
2
+
3
+ """Utils for data load, save, and process (e.g., prompt construction)"""
4
+
5
+ import os
6
+ import json
7
+ import yaml
8
+ import re
9
+
10
+ DOMAIN_CAT2SUB_CAT = {
11
+ 'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'],
12
+ 'Business': ['Accounting', 'Economics', 'Finance', 'Manage', 'Marketing'],
13
+ 'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics', ],
14
+ 'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine',
15
+ 'Pharmacy', 'Public_Health'],
16
+ 'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'],
17
+ 'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics',
18
+ 'Energy_and_Power', 'Materials', 'Mechanical_Engineering'],
19
+ }
20
+
21
+ CAT_SHORT2LONG = {
22
+ 'acc': 'Accounting',
23
+ 'agri': 'Agriculture',
24
+ 'arch': 'Architecture_and_Engineering',
25
+ 'art': 'Art',
26
+ 'art_theory': 'Art_Theory',
27
+ 'bas_med': 'Basic_Medical_Science',
28
+ 'bio': 'Biology',
29
+ 'chem': 'Chemistry',
30
+ 'cli_med': 'Clinical_Medicine',
31
+ 'cs': 'Computer_Science',
32
+ 'design': 'Design',
33
+ 'diag_med': 'Diagnostics_and_Laboratory_Medicine',
34
+ 'econ': 'Economics',
35
+ 'elec': 'Electronics',
36
+ 'ep': 'Energy_and_Power',
37
+ 'fin': 'Finance',
38
+ 'geo': 'Geography',
39
+ 'his': 'History',
40
+ 'liter': 'Literature',
41
+ 'manage': 'Manage',
42
+ 'mark': 'Marketing',
43
+ 'mate': 'Materials',
44
+ 'math': 'Math',
45
+ 'mech': 'Mechanical_Engineering',
46
+ 'music': 'Music',
47
+ 'phar': 'Pharmacy',
48
+ 'phys': 'Physics',
49
+ 'psy': 'Psychology',
50
+ 'pub_health': 'Public_Health',
51
+ 'socio': 'Sociology'
52
+ }
53
+
54
+
55
+ # DATA SAVING
56
+ def save_json(filename, ds):
57
+ with open(filename, 'w') as f:
58
+ json.dump(ds, f, indent=4)
59
+
60
+
61
+ def get_multi_choice_info(options):
62
+ """
63
+ Given the list of options for a multiple-choice question,
64
+ return the index2ans and all_choices.
65
+ """
66
+
67
+ start_chr = 'A'
68
+ all_choices = []
69
+ index2ans = {}
70
+ for i, option in enumerate(options):
71
+ index2ans[chr(ord(start_chr) + i)] = option
72
+ all_choices.append(chr(ord(start_chr) + i))
73
+
74
+ return index2ans, all_choices
75
+
76
+
77
+ def load_yaml(file_path):
78
+ with open(file_path, 'r') as stream:
79
+ try:
80
+ yaml_dict = yaml.safe_load(stream)
81
+ except yaml.YAMLError as exc:
82
+ print(exc)
83
+
84
+ return yaml_dict
85
+
86
+
87
+ def parse_img_path(text):
88
+ matches = re.findall("<img='(.*?)'>", text)
89
+ return matches
90
+
91
+
92
+ def process_single_sample(data):
93
+ question = data['question']
94
+ o_imgs_paths = []
95
+ for option in data['options']:
96
+ current_o_imgs_paths = parse_img_path(option)
97
+ for img_path in current_o_imgs_paths:
98
+ o_imgs_paths.append(img_path)
99
+
100
+ if len(o_imgs_paths) > 1: # multiple images in options, used for random selection
101
+ return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
102
+ 'image': None, 'question_type': data['question_type'], 'subfield': data['subfield']}
103
+ else:
104
+ return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
105
+ 'image': data['image_1'], 'question_type': data['question_type'], 'subfield': data['subfield']}
106
+
107
+
108
+ def process_single_sample_pro(data):
109
+ question = data['question']
110
+ o_imgs_paths = []
111
+ for option in data['options']:
112
+ current_o_imgs_paths = parse_img_path(option)
113
+ for img_path in current_o_imgs_paths:
114
+ o_imgs_paths.append(img_path)
115
+
116
+ if len(o_imgs_paths) > 1: # multiple images in options, used for random selection
117
+ return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
118
+ 'image': None, 'question_type': 'multiple-choice', 'subfield': data['subject']}
119
+ else:
120
+ return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
121
+ 'image': data['image_1'], 'question_type': 'multiple-choice', 'subfield': data['subject']}
122
+
123
+
124
+ # DATA SAVING
125
+ def save_json(filename, ds):
126
+ with open(filename, 'w') as f:
127
+ json.dump(ds, f, indent=4)
128
+
129
+
130
+ def save_jsonl(filename, data):
131
+ """
132
+ Save a dictionary of data to a JSON Lines file with the filename as key and caption as value.
133
+
134
+ Args:
135
+ filename (str): The path to the file where the data should be saved.
136
+ data (dict): The dictionary containing the data to save where key is the image path and value is the caption.
137
+ """
138
+ with open(filename, 'w', encoding='utf-8') as f:
139
+ for img_path, caption in data.items():
140
+ # Extract the base filename without the extension
141
+ base_filename = os.path.basename(img_path)
142
+ # Create a JSON object with the filename as the key and caption as the value
143
+ json_record = json.dumps({base_filename: caption}, ensure_ascii=False)
144
+ # Write the JSON object to the file, one per line
145
+ f.write(json_record + '\n')
146
+
147
+
148
+ def save_args(args, path_dir):
149
+ argsDict = args.__dict__
150
+ with open(path_dir + 'setting.txt', 'w') as f:
151
+ f.writelines('------------------ start ------------------' + '\n')
152
+ for eachArg, value in argsDict.items():
153
+ f.writelines(eachArg + ' : ' + str(value) + '\n')
154
+ f.writelines('------------------- end -------------------')
155
+
156
+
157
+ # DATA PROCESSING
158
+ def construct_prompt(sample, config):
159
+ question = sample['question'].strip()
160
+
161
+ # for i in range(8):
162
+ # question = question.replace(f" <image {i}> ", " ")
163
+ # question = question.replace(f" <image {i}>", " ")
164
+ # question = question.replace(f"<image {i}> ", " ")
165
+ # question = question.replace(f"<image {i}>", " ")
166
+ # question = question.strip()
167
+
168
+ options = eval(sample['options'])
169
+ example = ""
170
+ if sample['question_type'] == 'multiple-choice':
171
+ start_chr = 'A'
172
+ prediction_range = []
173
+ index2ans = {}
174
+ for option in options:
175
+ prediction_range.append(start_chr)
176
+ example += f"({start_chr}) {option}\n"
177
+ # example += f"{start_chr}. {option}\n"
178
+ index2ans[start_chr] = option
179
+ start_chr = chr(ord(start_chr) + 1)
180
+ # example = example.rstrip()
181
+ empty_prompt_sample_structure = config['multi_choice_example_format']
182
+ empty_prompt = empty_prompt_sample_structure.format(question, example)
183
+ res_dict = {'type': 'multichoice'}
184
+ res_dict['index2ans'] = index2ans
185
+ res_dict['correct_choice'] = sample['answer']
186
+ res_dict['all_choices'] = prediction_range
187
+ res_dict['empty_prompt'] = empty_prompt
188
+ if config['task_instructions']:
189
+ res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
190
+ else:
191
+ res_dict['final_input_prompt'] = empty_prompt
192
+
193
+ res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
194
+ else:
195
+ empty_prompt_sample_structure = config['short_ans_example_format']
196
+ empty_prompt = empty_prompt_sample_structure.format(question)
197
+ res_dict = {'type': 'open'}
198
+ res_dict['empty_prompt'] = empty_prompt
199
+ if config['task_instructions']:
200
+ res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
201
+ else:
202
+ res_dict['final_input_prompt'] = empty_prompt
203
+ res_dict['gt_content'] = sample['answer']
204
+
205
+ res_dict.update(sample)
206
+ return res_dict
207
+
208
+
209
+ def construct_prompt_pro(sample, config):
210
+ question = sample['question'].strip()
211
+
212
+ # for i in range(8):
213
+ # question = question.replace(f" <image {i}> ", " ")
214
+ # question = question.replace(f" <image {i}>", " ")
215
+ # question = question.replace(f"<image {i}> ", " ")
216
+ # question = question.replace(f"<image {i}>", " ")
217
+ # question = question.strip()
218
+
219
+ options = eval(sample['options'])
220
+
221
+ if len(options) == 1:
222
+ print("This is wrongly formatted. We correct to options[0].")
223
+ options = options[0]
224
+
225
+ example = ""
226
+ if sample['question_type'] == 'multiple-choice':
227
+ start_chr = 'A'
228
+ prediction_range = []
229
+ index2ans = {}
230
+ for option in options:
231
+ prediction_range.append(start_chr)
232
+ example += f"({start_chr}) {option}\n"
233
+ # example += f"{start_chr}. {option}\n"
234
+ index2ans[start_chr] = option
235
+ start_chr = chr(ord(start_chr) + 1)
236
+ # example = example.rstrip()
237
+ empty_prompt_sample_structure = config['multi_choice_example_format']
238
+ empty_prompt = empty_prompt_sample_structure.format(question, example)
239
+ res_dict = {'type': 'multichoice'}
240
+ res_dict['index2ans'] = index2ans
241
+ res_dict['correct_choice'] = sample['answer']
242
+ res_dict['all_choices'] = prediction_range
243
+ res_dict['empty_prompt'] = empty_prompt
244
+ if config['task_instructions']:
245
+ res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
246
+ else:
247
+ res_dict['final_input_prompt'] = empty_prompt
248
+
249
+ res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
250
+ else:
251
+ empty_prompt_sample_structure = config['short_ans_example_format']
252
+ empty_prompt = empty_prompt_sample_structure.format(question)
253
+ res_dict = {'type': 'open'}
254
+ res_dict['empty_prompt'] = empty_prompt
255
+ if config['task_instructions']:
256
+ res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
257
+ else:
258
+ res_dict['final_input_prompt'] = empty_prompt
259
+ res_dict['gt_content'] = sample['answer']
260
+
261
+ res_dict.update(sample)
262
+ return res_dict
263
+
264
+ """Response Parsing and Evaluation for various models"""
265
+ from typing import Dict
266
+
267
+ import re
268
+ import random
269
+
270
+ import numpy as np
271
+
272
+
273
+ # ----------- Process Multi-choice -------------
274
+ def parse_multi_choice_response(response, all_choices, index2ans):
275
+ """
276
+ Parse the prediction from the generated response.
277
+ Return the predicted index e.g., A, B, C, D.
278
+ """
279
+ for char in [',', '.', '!', '?', ';', ':', "'"]:
280
+ response = response.strip(char)
281
+ response = " " + response + " " # add space to avoid partial match
282
+
283
+ index_ans = True
284
+ ans_with_brack = False
285
+ candidates = []
286
+ for choice in all_choices: # e.g., (A) (B) (C) (D) A) B) C) D)
287
+ if f'({choice})' in response or f'{choice})' in response:
288
+ candidates.append(choice)
289
+ ans_with_brack = True
290
+
291
+ if len(candidates) == 0:
292
+ for choice in all_choices: # e.g., A B C D
293
+ if f' {choice} ' in response:
294
+ candidates.append(choice)
295
+
296
+ # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
297
+ if len(candidates) == 0 and len(response.split()) > 5:
298
+ for index, ans in index2ans.items():
299
+ if ans.lower() in response.lower():
300
+ candidates.append(index)
301
+ index_ans = False # it's content ans.
302
+
303
+ if len(candidates) == 0: # still not get answer, randomly choose one.
304
+ pred_index = all_choices[0]
305
+ elif len(candidates) > 1:
306
+ start_indexes = []
307
+ if index_ans:
308
+ if ans_with_brack:
309
+ for can in candidates:
310
+ index = response.rfind(f'({can})')
311
+ start_indexes.append(index) # -1 will be ignored anyway
312
+ # start_indexes = [generated_response.index(f'({can})') for can in candidates]
313
+ else:
314
+ for can in candidates:
315
+ index = response.rfind(f" {can} ")
316
+ start_indexes.append(index)
317
+ else:
318
+ for can in candidates:
319
+ index = response.lower().rfind(index2ans[can].lower())
320
+ start_indexes.append(index)
321
+ # get the last one
322
+ pred_index = candidates[np.argmax(start_indexes)]
323
+ else: # if only one candidate, use it.
324
+ pred_index = candidates[0]
325
+
326
+ return pred_index
327
+
328
+
329
+ # ----------- Process Open -------------
330
+ def check_is_number(string):
331
+ """
332
+ Check if the given string is a number.
333
+ """
334
+ try:
335
+ float(string.replace(',', ''))
336
+ return True
337
+ except ValueError:
338
+ # check if there's comma inside
339
+ return False
340
+
341
+
342
+ def normalize_str(string):
343
+ """
344
+ Normalize the string to lower case and convert it to a float number if possible.
345
+ """
346
+ # check if characters in the string
347
+
348
+ # if number, numerize it.
349
+ string = string.strip()
350
+
351
+ is_number = check_is_number(string)
352
+
353
+ if is_number:
354
+ string = string.replace(',', '')
355
+ string = float(string)
356
+ # leave 2 decimal
357
+ string = round(string, 2)
358
+ return [string]
359
+ else: # it's likely to be a string
360
+ # lower it
361
+ string = string.lower()
362
+ if len(string) == 1:
363
+ return [" " + string, string + " "] # avoid trivial matches
364
+ return [string]
365
+
366
+
367
+ def extract_numbers(string):
368
+ """
369
+ Extract all forms of numbers from a string with regex.
370
+ """
371
+ # Pattern for numbers with commas
372
+ pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b'
373
+ # Pattern for scientific notation
374
+ pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
375
+ # Pattern for simple numbers without commas
376
+ pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])'
377
+
378
+ # Extract numbers with commas
379
+ numbers_with_commas = re.findall(pattern_commas, string)
380
+ # Extract numbers in scientific notation
381
+ numbers_scientific = re.findall(pattern_scientific, string)
382
+ # Extract simple numbers without commas
383
+ numbers_simple = re.findall(pattern_simple, string)
384
+
385
+ # Combine all extracted numbers
386
+ all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
387
+ return all_numbers
388
+
389
+
390
+ def parse_open_response(response):
391
+ """
392
+ Parse the prediction from the generated response.
393
+ Return a list of predicted strings or numbers.
394
+ """
395
+
396
+ # content = content.strip("\n").strip(".").strip(" ")
397
+ def get_key_subresponses(response):
398
+ key_responses = []
399
+ response = response.strip().strip(".").lower()
400
+ sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response)
401
+ indicators_of_keys = ['could be ', 'so ', 'is ',
402
+ 'thus ', 'therefore ', 'final ', 'answer ', 'result ']
403
+ key_responses = []
404
+ for index, resp in enumerate(sub_responses):
405
+ # if last one, accept it's an equation (the entire response can be just one sentence with equation)
406
+ if index == len(sub_responses) - 1:
407
+ indicators_of_keys.extend(['='])
408
+ shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
409
+ for indicator in indicators_of_keys:
410
+ if indicator in resp:
411
+ if not shortest_key_response:
412
+ shortest_key_response = resp.split(indicator)[-1].strip()
413
+ else:
414
+ if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
415
+ shortest_key_response = resp.split(indicator)[-1].strip()
416
+ # key_responses.append(resp.split(indicator)[1].strip())
417
+
418
+ if shortest_key_response:
419
+ # and it's not trivial
420
+ if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
421
+ key_responses.append(shortest_key_response)
422
+ if len(key_responses) == 0: # did not find any
423
+ return [response]
424
+ return key_responses
425
+
426
+ # pdb.set_trace()
427
+ key_responses = get_key_subresponses(response)
428
+
429
+ pred_list = key_responses.copy() # keep the original string response
430
+ for resp in key_responses:
431
+ pred_list.extend(extract_numbers(resp))
432
+
433
+ tmp_pred_list = []
434
+ for i in range(len(pred_list)):
435
+ tmp_pred_list.extend(normalize_str(pred_list[i]))
436
+ pred_list = tmp_pred_list
437
+
438
+ # remove duplicates
439
+ pred_list = list(set(pred_list))
440
+
441
+ return pred_list
442
+
443
+
444
+ # ----------- Evaluation -------------
445
+
446
+ def eval_multi_choice(gold_i, pred_i):
447
+ """
448
+ Evaluate a multiple choice instance.
449
+ """
450
+ correct = False
451
+ # only if they are exactly the same do we consider it correct
452
+ if isinstance(gold_i, list):
453
+ for answer in gold_i:
454
+ if answer == pred_i:
455
+ correct = True
456
+ break
457
+ else: # gold_i is a string
458
+ if gold_i == pred_i:
459
+ correct = True
460
+ return correct
461
+
462
+
463
+ def eval_open(gold_i, pred_i):
464
+ """
465
+ Evaluate an open question instance
466
+ """
467
+ correct = False
468
+ if isinstance(gold_i, list):
469
+ # use float to avoid trivial matches
470
+ norm_answers = []
471
+ for answer in gold_i:
472
+ norm_answers.extend(normalize_str(answer))
473
+ else:
474
+ norm_answers = normalize_str(gold_i)
475
+ for pred in pred_i: # pred is already normalized in parse response phase
476
+ if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
477
+ for norm_ans in norm_answers:
478
+ # only see if the string answer in the string pred
479
+ if isinstance(norm_ans, str) and norm_ans in pred:
480
+ if not correct:
481
+ correct = True
482
+ break
483
+ else: # it's a float number
484
+ if pred in norm_answers:
485
+ if not correct:
486
+ correct = True
487
+ break
488
+ return correct
489
+
490
+
491
+ # ----------- Batch Evaluation -------------
492
+ def evaluate(samples):
493
+ """
494
+ Batch evaluation for multiple choice and open questions.
495
+ """
496
+ pred_correct = 0
497
+ judge_dict = dict()
498
+ for sample in samples:
499
+ gold_i = sample['answer']
500
+ pred_i = sample['parsed_pred']
501
+ if sample['question_type'] == 'multiple-choice':
502
+ correct = eval_multi_choice(gold_i, pred_i)
503
+ else: # open question
504
+ correct = eval_open(gold_i, pred_i)
505
+
506
+ if correct:
507
+ judge_dict[sample['id']] = 'Correct'
508
+ pred_correct += 1
509
+ else:
510
+ judge_dict[sample['id']] = 'Wrong'
511
+
512
+ if len(samples) == 0:
513
+ return judge_dict, {'acc': 0}
514
+ return judge_dict, {'acc': pred_correct / len(samples)}
515
+
516
+
517
+ # ----------- Calculate Accuracy -------------
518
+ def calculate_ins_level_acc(results: Dict):
519
+ """Calculate the instruction level accuracy for given Subject results"""
520
+ acc = 0
521
+ ins_num = 0
522
+ for cat_results in results.values():
523
+ acc += cat_results['acc'] * cat_results['num_example']
524
+ ins_num += cat_results['num_example']
525
+ if ins_num == 0:
526
+ return 0
527
+ return acc / ins_num
528
+
529
+
530
+ def mmmu_main_eval(output_dict, task_cfg):
531
+ answer_dict = json.load(open(task_cfg["answer_dict"]))
532
+
533
+ # group by category
534
+ output_dict_w_cat = {}
535
+ for data_id, parsed_pred in output_dict.items():
536
+ category = "_".join(data_id.split("_")[1:-1])
537
+ if category not in output_dict_w_cat:
538
+ output_dict_w_cat.update({category: {}})
539
+ output_dict_w_cat[category].update({data_id: parsed_pred})
540
+
541
+ # group by category
542
+ answer_dict_w_cat = {}
543
+ for data_id, parsed_pred in answer_dict.items():
544
+ category = "_".join(data_id.split("_")[1:-1])
545
+ if category not in answer_dict_w_cat:
546
+ answer_dict_w_cat.update({category: {}})
547
+ answer_dict_w_cat[category].update({data_id: parsed_pred})
548
+
549
+ evaluation_result = {}
550
+
551
+ for category in CAT_SHORT2LONG.values():
552
+ # print("Evaluating: {}".format(category))
553
+ # get cat_outputs and cat_answers
554
+ try:
555
+ cat_outputs = output_dict_w_cat[category]
556
+ cat_answers = answer_dict_w_cat[category]
557
+ except KeyError:
558
+ print("Skipping {}: category not found".format(category))
559
+ continue
560
+
561
+ exampels_to_eval = []
562
+ for data_id, parsed_pred in cat_outputs.items():
563
+ question_type = cat_answers[data_id]['question_type']
564
+ if question_type != 'multiple-choice':
565
+ parsed_pred = parse_open_response(parsed_pred) # mainly for type consistency (make it number, etc.)
566
+ else:
567
+ parsed_pred = parsed_pred
568
+
569
+ exampels_to_eval.append({
570
+ "id": data_id,
571
+ "question_type": question_type,
572
+ "answer": cat_answers[data_id]['ground_truth'],
573
+ "parsed_pred": parsed_pred
574
+ })
575
+
576
+ judge_dict, metric_dict = evaluate(exampels_to_eval)
577
+ metric_dict.update({"num_example": len(exampels_to_eval)})
578
+
579
+ evaluation_result[category] = metric_dict
580
+
581
+ printable_results = {}
582
+ # pdb.set_trace()
583
+ # add domain Subject
584
+ for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
585
+ in_domain_cat_results = {}
586
+ for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT
587
+ if cat_name in evaluation_result.keys():
588
+ in_domain_cat_results[cat_name] = evaluation_result[cat_name]
589
+ else:
590
+ pass
591
+ in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
592
+ in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()])
593
+ printable_results['Overall-' + domain] = {"num": int(in_domain_data_num),
594
+ "acc": round(in_domain_ins_acc, 4)
595
+ }
596
+ # add sub category
597
+ for cat_name, cat_results in in_domain_cat_results.items():
598
+ printable_results[cat_name] = {"num": int(cat_results['num_example']),
599
+ "acc": round(cat_results['acc'], 4)
600
+ }
601
+
602
+ # table.append(["-----------------------------", "-----", "----"])
603
+ all_ins_acc = calculate_ins_level_acc(evaluation_result)
604
+ printable_results['Overall'] = {
605
+ "num": sum([cat_results['num_example'] for cat_results in evaluation_result.values()]),
606
+ "acc": round(all_ins_acc, 4)
607
+ }
608
+
609
+ # print(printable_results)
610
+ return printable_results
611
+
612
+
613
+ if __name__ == '__main__':
614
+ # tasks = yaml.safe_load(open("eval_config/eval_mmmu_yi_oci.yaml"))['datasets']
615
+ tasks = yaml.safe_load(open("eval_config/eval_mmmu_yi.yaml"))['datasets']
616
+ print(tasks)
617
+
618
+ # with open("/lustre/fs4/portfolios/adlr/users/boxinw/llava-megatron-gen/checkpoints/test/eval_mmmu_iter500_merged.4node.json") as f:
619
+ with open("/lustre/fsw/portfolios/llmservice/users/boxinw/eval_mmmu_iter6000_merged.0.53.json") as f:
620
+ merged_results = json.load(f)
621
+
622
+ eval_samples = []
623
+ eval_output_dict = {}
624
+ for res in merged_results:
625
+ pred_ans = res["answer"].upper()
626
+ gt_ans = res['gt_answer']
627
+ if res['question_type'] == 'multiple-choice':
628
+ parsed_pred = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
629
+ if pred_ans != parsed_pred:
630
+ print(f"MC: Original: {pred_ans}, Parsed: {parsed_pred}")
631
+ eval_samples.append(
632
+ {
633
+ 'id': res['question_id'],
634
+ 'question_type': res['question_type'],
635
+ 'answer': res['gt_answer'], # the content in option, not answer index.
636
+ 'response': pred_ans,
637
+ 'parsed_pred': parsed_pred,
638
+ 'index2ans': res['index2ans'],
639
+ }
640
+ )
641
+ eval_output_dict[res['question_id']] = parsed_pred
642
+ else:
643
+ parsed_pred = parse_open_response(pred_ans)
644
+ if pred_ans != parsed_pred:
645
+ print(f"Open: Original: {pred_ans}, Parsed: {parsed_pred}")
646
+ eval_samples.append(
647
+ {
648
+ 'id': res['question_id'],
649
+ 'question_type': res['question_type'],
650
+ 'answer': res['gt_answer'],
651
+ 'response': pred_ans,
652
+ 'parsed_pred': parsed_pred,
653
+ }
654
+ )
655
+ eval_output_dict[res['question_id']] = pred_ans
656
+
657
+ json.dump(eval_output_dict, open("validation_mmmu_iter6000_merged.0.53.sorted.json", "w"), indent=4, sort_keys=True)
658
+
659
+
660
+ x = mmmu_main_eval(eval_output_dict,
661
+ task_cfg=tasks['mmmu'])
662
+
663
+ print(x)
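
The parsers above reduce free-form generations to a choice letter or a normalized answer list before scoring. Assuming the repository root is on `PYTHONPATH` (the same way `run_eval.py` imports this module), a short sanity check might look like the following; it is not part of the committed files.

```python
from eval.mmmu_utils import parse_multi_choice_response, parse_open_response, eval_open

index2ans = {"A": "red", "B": "blue", "C": "green", "D": "yellow"}
all_choices = ["A", "B", "C", "D"]

# Multiple-choice: bracketed letters are matched first, then bare letters,
# then option text; the predicted letter is returned.
print(parse_multi_choice_response("The answer is (B).", all_choices, index2ans))  # -> "B"

# Open-ended: candidate strings/numbers after indicator words ("is", "thus", ...)
# are extracted and normalized, then compared against the normalized gold answer.
preds = parse_open_response("Thus the total cost is 1,500 dollars.")
print(eval_open(["1500"], preds))  # -> True
```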
eval/requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ anls
2
+ datasets
3
+ pycocoevalcap
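
These three packages cover the ANLS, dataset-loading, and captioning metrics. `run_eval.py` additionally imports `torch`, `transformers`, `yaml`, `pycocotools`, and `pandas`, which are assumed to come from the base environment used to run the model itself. A typical setup step might be:

```bash
pip install -r eval/requirements.txt
# torch, transformers, pyyaml, pycocotools, and pandas are assumed to be
# provided by the base model-serving environment.
```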
eval/vqa_utils.py ADDED
@@ -0,0 +1,317 @@
1
+ # coding=utf-8
2
+
3
+ __author__ = "aagrawal"
4
+
5
+ # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
6
+ # (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
7
+ import sys
8
+ import re
9
+
10
+
11
+ class VQAEval:
12
+ def __init__(self, vqa=None, vqaRes=None, n=2):
13
+ self.n = n
14
+ self.accuracy = {}
15
+ self.evalQA = {}
16
+ self.evalQuesType = {}
17
+ self.evalAnsType = {}
18
+ self.vqa = vqa
19
+ self.vqaRes = vqaRes
20
+ if vqa is not None:
21
+ self.params = {"question_id": vqa.getQuesIds()}
22
+ self.contractions = {
23
+ "aint": "ain't",
24
+ "arent": "aren't",
25
+ "cant": "can't",
26
+ "couldve": "could've",
27
+ "couldnt": "couldn't",
28
+ "couldn'tve": "couldn't've",
29
+ "couldnt've": "couldn't've",
30
+ "didnt": "didn't",
31
+ "doesnt": "doesn't",
32
+ "dont": "don't",
33
+ "hadnt": "hadn't",
34
+ "hadnt've": "hadn't've",
35
+ "hadn'tve": "hadn't've",
36
+ "hasnt": "hasn't",
37
+ "havent": "haven't",
38
+ "hed": "he'd",
39
+ "hed've": "he'd've",
40
+ "he'dve": "he'd've",
41
+ "hes": "he's",
42
+ "howd": "how'd",
43
+ "howll": "how'll",
44
+ "hows": "how's",
45
+ "Id've": "I'd've",
46
+ "I'dve": "I'd've",
47
+ "Im": "I'm",
48
+ "Ive": "I've",
49
+ "isnt": "isn't",
50
+ "itd": "it'd",
51
+ "itd've": "it'd've",
52
+ "it'dve": "it'd've",
53
+ "itll": "it'll",
54
+ "let's": "let's",
55
+ "maam": "ma'am",
56
+ "mightnt": "mightn't",
57
+ "mightnt've": "mightn't've",
58
+ "mightn'tve": "mightn't've",
59
+ "mightve": "might've",
60
+ "mustnt": "mustn't",
61
+ "mustve": "must've",
62
+ "neednt": "needn't",
63
+ "notve": "not've",
64
+ "oclock": "o'clock",
65
+ "oughtnt": "oughtn't",
66
+ "ow's'at": "'ow's'at",
67
+ "'ows'at": "'ow's'at",
68
+ "'ow'sat": "'ow's'at",
69
+ "shant": "shan't",
70
+ "shed've": "she'd've",
71
+ "she'dve": "she'd've",
72
+ "she's": "she's",
73
+ "shouldve": "should've",
74
+ "shouldnt": "shouldn't",
75
+ "shouldnt've": "shouldn't've",
76
+ "shouldn'tve": "shouldn't've",
77
+ "somebody'd": "somebodyd",
78
+ "somebodyd've": "somebody'd've",
79
+ "somebody'dve": "somebody'd've",
80
+ "somebodyll": "somebody'll",
81
+ "somebodys": "somebody's",
82
+ "someoned": "someone'd",
83
+ "someoned've": "someone'd've",
84
+ "someone'dve": "someone'd've",
85
+ "someonell": "someone'll",
86
+ "someones": "someone's",
87
+ "somethingd": "something'd",
88
+ "somethingd've": "something'd've",
89
+ "something'dve": "something'd've",
90
+ "somethingll": "something'll",
91
+ "thats": "that's",
92
+ "thered": "there'd",
93
+ "thered've": "there'd've",
94
+ "there'dve": "there'd've",
95
+ "therere": "there're",
96
+ "theres": "there's",
97
+ "theyd": "they'd",
98
+ "theyd've": "they'd've",
99
+ "they'dve": "they'd've",
100
+ "theyll": "they'll",
101
+ "theyre": "they're",
102
+ "theyve": "they've",
103
+ "twas": "'twas",
104
+ "wasnt": "wasn't",
105
+ "wed've": "we'd've",
106
+ "we'dve": "we'd've",
107
+ "weve": "we've",
108
+ "werent": "weren't",
109
+ "whatll": "what'll",
110
+ "whatre": "what're",
111
+ "whats": "what's",
112
+ "whatve": "what've",
113
+ "whens": "when's",
114
+ "whered": "where'd",
115
+ "wheres": "where's",
116
+ "whereve": "where've",
117
+ "whod": "who'd",
118
+ "whod've": "who'd've",
119
+ "who'dve": "who'd've",
120
+ "wholl": "who'll",
121
+ "whos": "who's",
122
+ "whove": "who've",
123
+ "whyll": "why'll",
124
+ "whyre": "why're",
125
+ "whys": "why's",
126
+ "wont": "won't",
127
+ "wouldve": "would've",
128
+ "wouldnt": "wouldn't",
129
+ "wouldnt've": "wouldn't've",
130
+ "wouldn'tve": "wouldn't've",
131
+ "yall": "y'all",
132
+ "yall'll": "y'all'll",
133
+ "y'allll": "y'all'll",
134
+ "yall'd've": "y'all'd've",
135
+ "y'alld've": "y'all'd've",
136
+ "y'all'dve": "y'all'd've",
137
+ "youd": "you'd",
138
+ "youd've": "you'd've",
139
+ "you'dve": "you'd've",
140
+ "youll": "you'll",
141
+ "youre": "you're",
142
+ "youve": "you've",
143
+ }
144
+ self.manualMap = {
145
+ "none": "0",
146
+ "zero": "0",
147
+ "one": "1",
148
+ "two": "2",
149
+ "three": "3",
150
+ "four": "4",
151
+ "five": "5",
152
+ "six": "6",
153
+ "seven": "7",
154
+ "eight": "8",
155
+ "nine": "9",
156
+ "ten": "10",
157
+ }
158
+ self.articles = ["a", "an", "the"]
159
+
160
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
161
+ self.commaStrip = re.compile("(\d)(,)(\d)")
162
+ self.punct = [
163
+ ";",
164
+ r"/",
165
+ "[",
166
+ "]",
167
+ '"',
168
+ "{",
169
+ "}",
170
+ "(",
171
+ ")",
172
+ "=",
173
+ "+",
174
+ "\\",
175
+ "_",
176
+ "-",
177
+ ">",
178
+ "<",
179
+ "@",
180
+ "`",
181
+ ",",
182
+ "?",
183
+ "!",
184
+ ]
185
+
186
+ def evaluate(self, quesIds=None):
187
+ if quesIds == None:
188
+ quesIds = [quesId for quesId in self.params["question_id"]]
189
+ gts = {}
190
+ res = {}
191
+ for quesId in quesIds:
192
+ gts[quesId] = self.vqa.qa[quesId]
193
+ res[quesId] = self.vqaRes.qa[quesId]
194
+
195
+ # =================================================
196
+ # Compute accuracy
197
+ # =================================================
198
+ accQA = []
199
+ accQuesType = {}
200
+ accAnsType = {}
201
+ print("computing accuracy")
202
+ step = 0
203
+ for quesId in quesIds:
204
+ resAns = res[quesId]["answer"]
205
+ resAns = resAns.replace("\n", " ")
206
+ resAns = resAns.replace("\t", " ")
207
+ resAns = resAns.strip()
208
+ resAns = self.processPunctuation(resAns)
209
+ resAns = self.processDigitArticle(resAns)
210
+ gtAcc = []
211
+ gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
212
+ if len(set(gtAnswers)) > 1:
213
+ for ansDic in gts[quesId]["answers"]:
214
+ ansDic["answer"] = self.processPunctuation(ansDic["answer"])
215
+ for gtAnsDatum in gts[quesId]["answers"]:
216
+ otherGTAns = [
217
+ item for item in gts[quesId]["answers"] if item != gtAnsDatum
218
+ ]
219
+ matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
220
+ acc = min(1, float(len(matchingAns)) / 3)
221
+ gtAcc.append(acc)
222
+ quesType = gts[quesId]["question_type"]
223
+ ansType = gts[quesId]["answer_type"]
224
+ avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
225
+ accQA.append(avgGTAcc)
226
+ if quesType not in accQuesType:
227
+ accQuesType[quesType] = []
228
+ accQuesType[quesType].append(avgGTAcc)
229
+ if ansType not in accAnsType:
230
+ accAnsType[ansType] = []
231
+ accAnsType[ansType].append(avgGTAcc)
232
+ self.setEvalQA(quesId, avgGTAcc)
233
+ self.setEvalQuesType(quesId, quesType, avgGTAcc)
234
+ self.setEvalAnsType(quesId, ansType, avgGTAcc)
235
+ if step % 100 == 0:
236
+ self.updateProgress(step / float(len(quesIds)))
237
+ step = step + 1
238
+
239
+ self.setAccuracy(accQA, accQuesType, accAnsType)
240
+ print("Done computing accuracy")
241
+
242
+ def processPunctuation(self, inText):
243
+ outText = inText
244
+ for p in self.punct:
245
+ if (p + " " in inText or " " + p in inText) or (
246
+ re.search(self.commaStrip, inText) != None
247
+ ):
248
+ outText = outText.replace(p, "")
249
+ else:
250
+ outText = outText.replace(p, " ")
251
+ outText = self.periodStrip.sub("", outText, re.UNICODE)
252
+ return outText
253
+
254
+ def processDigitArticle(self, inText):
255
+ outText = []
256
+ tempText = inText.lower().split()
257
+ for word in tempText:
258
+ word = self.manualMap.setdefault(word, word)
259
+ if word not in self.articles:
260
+ outText.append(word)
261
+ else:
262
+ pass
263
+ for wordId, word in enumerate(outText):
264
+ if word in self.contractions:
265
+ outText[wordId] = self.contractions[word]
266
+ outText = " ".join(outText)
267
+ return outText
268
+
269
+ def setAccuracy(self, accQA, accQuesType, accAnsType):
270
+ self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
271
+ self.accuracy["perQuestionType"] = {
272
+ quesType: round(
273
+ 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
274
+ self.n,
275
+ )
276
+ for quesType in accQuesType
277
+ }
278
+ self.accuracy["perAnswerType"] = {
279
+ ansType: round(
280
+ 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
281
+ )
282
+ for ansType in accAnsType
283
+ }
284
+
285
+ def setEvalQA(self, quesId, acc):
286
+ self.evalQA[quesId] = round(100 * acc, self.n)
287
+
288
+ def setEvalQuesType(self, quesId, quesType, acc):
289
+ if quesType not in self.evalQuesType:
290
+ self.evalQuesType[quesType] = {}
291
+ self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
292
+
293
+ def setEvalAnsType(self, quesId, ansType, acc):
294
+ if ansType not in self.evalAnsType:
295
+ self.evalAnsType[ansType] = {}
296
+ self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
297
+
298
+ def updateProgress(self, progress):
299
+ barLength = 20
300
+ status = ""
301
+ if isinstance(progress, int):
302
+ progress = float(progress)
303
+ if not isinstance(progress, float):
304
+ progress = 0
305
+ status = "error: progress var must be float\r\n"
306
+ if progress < 0:
307
+ progress = 0
308
+ status = "Halt...\r\n"
309
+ if progress >= 1:
310
+ progress = 1
311
+ status = "Done...\r\n"
312
+ block = int(round(barLength * progress))
313
+ text = "\rFinished Percent: [{0}] {1}% {2}".format(
314
+ "#" * block + "-" * (barLength - block), int(progress * 100), status
315
+ )
316
+ sys.stdout.write(text)
317
+ sys.stdout.flush()
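
`VQAEval` is used by `run_eval.py` purely for its answer normalizers together with the standard VQA accuracy rule min(1, #matching human answers / 3). A minimal sketch of that scoring path (not part of the committed files) is:

```python
from eval.vqa_utils import VQAEval

vqa_tool = VQAEval()

# Normalize a prediction and the human answers the same way run_eval.py does.
pred = vqa_tool.processDigitArticle(vqa_tool.processPunctuation("Two."))   # -> "2"
gts = [vqa_tool.processDigitArticle(vqa_tool.processPunctuation(a))
       for a in ["2", "two", "2", "two dogs"]]

# Standard VQA accuracy: full credit once at least three annotators agree.
acc = min(1.0, sum(pred == g for g in gts) / 3.0)
print(pred, acc)  # -> 2 1.0
```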
run_eval.py ADDED
@@ -0,0 +1,702 @@
1
+ import ast
2
+ import json
3
+
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModel
6
+ import math
7
+ import time
8
+ import argparse
9
+ import yaml
10
+ import os
11
+
12
+ from eval.eval_dataset import get_task_dataloader, isNumber, get_anls_score
13
+ from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \
14
+ process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro
15
+ from eval.mmmu_utils import evaluate as evaluate_mmmu
16
+ from pycocotools.coco import COCO
17
+ from pycocoevalcap.eval import COCOEvalCap
18
+ from anls import anls_score
19
+ import pandas as pd
20
+ from eval.vqa_utils import VQAEval
21
+
22
+
23
+ def split_model():
24
+ device_map = {}
25
+ world_size = torch.cuda.device_count()
26
+ num_layers = 80
27
+ # Since the first GPU will be used for ViT, treat it as half a GPU.
28
+ num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
29
+ num_layers_per_gpu = [num_layers_per_gpu] * world_size
30
+ num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
31
+ layer_cnt = 0
32
+ for i, num_layer in enumerate(num_layers_per_gpu):
33
+ for j in range(num_layer):
34
+ device_map[f'language_model.model.layers.{layer_cnt}'] = i
35
+ layer_cnt += 1
36
+ device_map['vision_model'] = 0
37
+ device_map['mlp1'] = 0
38
+ device_map['language_model.model.tok_embeddings'] = 0
39
+ device_map['language_model.model.embed_tokens'] = 0
40
+ device_map['language_model.output'] = 0
41
+ device_map['language_model.model.norm'] = 0
42
+ device_map['language_model.lm_head'] = 0
43
+ device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
44
+
45
+ return device_map
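# Illustrative sketch, not part of the committed run_eval.py: split_model() pins the vision
# encoder, projector, embeddings, final norm, and output head to GPU 0 and distributes the
# 80 decoder layers over all GPUs, counting GPU 0 as half a GPU. A typical way to apply it:
#     device_map = split_model()
#     model = AutoModel.from_pretrained(path_to_model, torch_dtype=torch.bfloat16,
#                                       trust_remote_code=True, device_map=device_map).eval()
# (path_to_model and the from_pretrained arguments above are illustrative assumptions.)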
46
+
47
+
48
+ def generate_task_results(
49
+ task_name,
50
+ task_cfg,
51
+ args,
52
+ model,
53
+ tokenizer,
54
+ generation_config,
55
+ dataloader,
56
+ results_save_dir,
57
+ ):
58
+ results = []
59
+
60
+ if "prompt" in task_cfg:
61
+ prompt = task_cfg["prompt"]
62
+
63
+ if task_name == "mmmu" or task_name == 'mmmu_pro':
64
+ for i, (img, question_id, subfield, question_type, question, answer, index2ans, all_choices) in enumerate(
65
+ dataloader):
66
+ if img.dim() == 5:
67
+ img = img.squeeze(0)
68
+ question_id = question_id[0]
69
+ question_type = question_type[0]
70
+ question = question[0]
71
+ answer = answer[0]
72
+ subfield = subfield[0]
73
+ index2ans = ast.literal_eval(index2ans[0])
74
+ all_choices = ast.literal_eval(all_choices[0])
75
+
76
+ this_prompt = question
77
+
78
+ if task_name == 'mmmu':
79
+ categories = list(CAT_SHORT2LONG.values())
80
+ new_sys_msg = "You are expert in {} ({}). Read the image and use your knowledge in {} ({}) to answer the question."
81
+ for c in categories:
82
+ if c in question_id:
83
+ cat = c.lower().replace('_', ' ')
84
+ new_sys_msg = new_sys_msg.format(cat, subfield, cat, subfield)
85
+ break
86
+ else:
87
+ new_sys_msg = f"You are expert in {subfield}. Read the image and use your knowledge in {subfield} to answer the question."
88
+ model.system_message = new_sys_msg
89
+
90
+ generated_answer = model.chat(tokenizer, img, question, generation_config)
91
+ print("generated_answer", generated_answer)
92
+
93
+ results.append({
94
+ "question_id": question_id,
95
+ "question_type": question_type,
96
+ "answer": generated_answer,
97
+ "gt_answer": answer,
98
+ "index2ans": index2ans,
99
+ "all_choices": all_choices
100
+ })
101
+ print("idx:", i, results[-1])
102
+
103
+
104
+ elif task_name == "coco_caption" or task_name == "flickr30k_caption":
105
+
106
+ for i, (image_id, img,) in enumerate(dataloader):
107
+
108
+ if img.dim() == 5:
109
+ img = img.squeeze(0)
110
+
111
+ generated_answer = model.chat_without_chat_prompt(tokenizer, img, prompt, generation_config)
112
+ print("generated_answer", generated_answer)
113
+
114
+ results.append({"image_id": image_id.item(), "caption": generated_answer})
115
+ print("idx:", i, results[-1])
116
+
117
+ elif task_name in ["vqav2", "gqa", "okvqa", "textvqa", "chartqa", "docvqa", "realworldqa",
118
+ "ai2diagram", "ai2diagram_nomask", "docvqa_test"]:
119
+ for i, (img, question_id, question, answer,) in enumerate(dataloader):
120
+ if img.dim() == 5:
121
+ img = img.squeeze(0)
122
+
123
+ # Only works for batch size = 1
124
+ # need to implement collate function when we use bs > 1 in the future
125
+ question = question[0]
126
+ if task_name == 'ai2diagram' or task_name == 'ai2diagram_nomask':
127
+ question_id = question_id[0]
128
+ else:
129
+ question_id = question_id[0].item()
130
+
131
+ if type(answer) == list:
132
+ answer = [ans[0] for ans in answer]
133
+ else:
134
+ answer = [answer[0]]
135
+
136
+ # Need to change if using batch size > 1 in the future
137
+ this_prompt = prompt.format(question)
138
+
139
+ generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
140
+ print("generated_answer", generated_answer)
141
+
142
+ results.append({"question_id": question_id, "answer": generated_answer, "gt_answer": answer})
143
+ print("idx:", i, results[-1])
144
+
145
+
146
+ elif task_name == "ocrbench":
147
+ for i, (img, question_id, question, answer, dataset_name, data_type,) in enumerate(
148
+ dataloader):
149
+ if img.dim() == 5:
150
+ img = img.squeeze(0)
151
+
152
+ question_id = question_id[0]
153
+ question = question[0]
154
+ answer = answer[0]
155
+ dataset_name = dataset_name[0]
156
+ data_type = data_type[0]
157
+
158
+ this_prompt = prompt.format(question)
159
+
160
+ generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
161
+ print("generated_answer", generated_answer)
162
+
163
+ results.append({"question_id": question_id, "answer": generated_answer, "gt_answer": answer,
164
+ "dataset_name": dataset_name, "type": data_type})
165
+ print("idx:", i, results[-1])
166
+
167
+
168
+ elif task_name == "mathvista":
169
+ for i, (
170
+ img, question_id, question_type, question, answer, index2ans, all_choices,) in enumerate(
171
+ dataloader):
172
+ if img.dim() == 5:
173
+ img = img.squeeze(0)
174
+
175
+ question_id = question_id[0]
176
+ question = question[0]
177
+ question_type = question_type[0]
178
+ answer = answer[0]
179
+
180
+ index2ans = ast.literal_eval(index2ans[0])
181
+ all_choices = ast.literal_eval(all_choices[0])
182
+
183
+ this_prompt = prompt.format(question)
184
+
185
+ generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
186
+ print("generated_answer", generated_answer)
187
+
188
+ results.append({
189
+ "question_id": question_id,
190
+ "question_type": question_type,
191
+ "answer": generated_answer,
192
+ "gt_answer": answer,
193
+ "index2ans": index2ans,
194
+ "all_choices": all_choices
195
+ })
196
+ print("idx:", i, results[-1])
197
+
198
+
199
+ elif task_name == "mmbench":
200
+ for i, (
201
+ img, question_id, question, answer, index2ans, all_choices, original_q,) in enumerate(
202
+ dataloader):
203
+ if img.dim() == 5:
204
+ img = img.squeeze(0)
205
+
206
+ question_id = question_id[0]
207
+ question = question[0]
208
+ answer = answer[0]
209
+ original_q = original_q[0]
210
+
211
+ index2ans = ast.literal_eval(index2ans[0])
212
+ all_choices = ast.literal_eval(all_choices[0])
213
+
214
+ this_prompt = prompt.format(question)
215
+
216
+ generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
217
+ print("generated_answer", generated_answer)
218
+
219
+ results.append({
220
+ "question_id": question_id.item(),
221
+ "question": original_q,
222
+ "answer": generated_answer,
223
+ "gt_answer": answer,
224
+ "index2ans": index2ans,
225
+ "all_choices": all_choices
226
+ })
227
+ print("idx:", i, results[-1])
228
+
229
+
230
+ elif task_name == "vizwiz":
231
+ for i, (img, question_id, question,) in enumerate(dataloader):
232
+ if img.dim() == 5:
233
+ img = img.squeeze(0)
234
+
235
+ question = question[0]
236
+ question_id = question_id[0]
237
+
238
+ this_prompt = prompt.format(question)
239
+
240
+ generated_answer = model.chat_without_chat_prompt(tokenizer, img, this_prompt, generation_config)
241
+ print("generated_answer", generated_answer)
242
+
243
+ results.append({"image": question_id, "answer": generated_answer})
244
+ print("idx:", i, results[-1])
245
+
246
+ else:
247
+ raise NotImplementedError(f"Task {task_name} is not supported yet.")
248
+
249
+ os.makedirs(results_save_dir, exist_ok=True)
250
+ if args.subset is not None:
251
+ results_save_path = os.path.join(results_save_dir,
252
+ f"eval_{task_name}_subset_{args.subset}_start_{args.start_idx}.json")
253
+ else:
254
+ results_save_path = os.path.join(results_save_dir,
255
+ f"eval_{task_name}.json")
256
+ print("Saving to ", results_save_path)
257
+ json.dump(results, open(results_save_path, 'w'))
258
+
259
+
260
+ def calc_task_metrics(args, task_name, task_cfg, results_save_dir):
261
+ if args.subset is not None:
262
+ merged_results_path = os.path.join(results_save_dir,
263
+ f"eval_{task_name}_subset_{args.subset}_start_{args.start_idx}.json")
264
+ else:
265
+ merged_results_path = os.path.join(results_save_dir,
266
+ f"eval_{task_name}.json")
267
+ merged_results = json.load(open(merged_results_path, "r"))
268
+
269
+ if task_name == "coco_caption" or task_name == "flickr30k_caption":
270
+ # Calculate scores
271
+ coco = COCO(task_cfg["gt_path"])
272
+ coco_result = coco.loadRes(merged_results_path)
273
+ coco_eval = COCOEvalCap(coco, coco_result)
274
+ coco_eval.params["image_id"] = coco_result.getImgIds()
275
+ coco_eval.evaluate()
276
+
277
+ # Print and save scores
278
+ print(f"====== {task_name} scores ======")
279
+ with open(os.path.join(results_save_dir, "scores.txt"), "a") as f:
280
+ f.write(f"{task_name} scores:\n")
281
+ for k, v in coco_eval.eval.items():
282
+ msg = f"{k} = {v * 100:.3f}"
283
+ print(msg)
284
+ f.write(msg + "\n")
285
+ f.write("\n")
286
+ return coco_eval.eval["CIDEr"]
287
+ elif task_name == "vqav2" or task_name == "okvqa":
288
+ vqa_tool = VQAEval()
289
+ all_acc = []
290
+ for res in merged_results:
291
+ pred_ans = res["answer"]
292
+ pred_ans = vqa_tool.processPunctuation(pred_ans)
293
+ pred_ans = vqa_tool.processDigitArticle(pred_ans)
294
+
295
+ gt_ans = res["gt_answer"]
296
+ gt_ans = [vqa_tool.processPunctuation(ans) for ans in gt_ans]
297
+ gt_ans = [vqa_tool.processDigitArticle(ans) for ans in gt_ans]
298
+
299
+ num_match = sum([pred_ans == ans for ans in gt_ans])
300
+ acc = min(1.0, num_match / 3.0)
301
+ all_acc.append(acc)
302
+ acc_avg = sum(all_acc) / len(all_acc) * 100
303
+ print(f"===== {task_name} Accuracy {acc_avg:.2f}% =====")
304
+ with open(os.path.join(results_save_dir, "scores.txt"), "a") as f:
305
+ f.write(f"{task_name} Accuracy = {acc_avg:.2f}%\n\n")
306
+ return acc_avg
307
+
308
+ elif task_name == "textvqa":
309
+ vqa_tool = VQAEval()
310
+ all_acc = []
311
+ for res in merged_results:
312
+ pred_ans = res["answer"]
313
+ pred_ans = vqa_tool.processPunctuation(pred_ans)
314
+ pred_ans = vqa_tool.processDigitArticle(pred_ans)
315
+
316
+ gt_ans = res["gt_answer"]
317
+ gt_ans = [vqa_tool.processPunctuation(ans) for ans in gt_ans]
318
+ gt_ans = [vqa_tool.processDigitArticle(ans) for ans in gt_ans]
319
+
320
+ num_match = sum([pred_ans == ans for ans in gt_ans])
321
+ acc = min(1.0, num_match / 3.0)
322
+ all_acc.append(acc)
323
+ acc_avg = sum(all_acc) / len(all_acc) * 100
324
+ print(
325
+ f"===== {task_name} Accuracy {acc_avg:.2f}% (need to submit to EvalAI for the accurate accuracy) =====")
326
+ with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
327
+ f.write(
328
+ f"{task_name} Accuracy = {acc_avg:.2f}% (need to submit to EvalAI for the accurate accuracy)\n\n")
329
+ return acc_avg
330
+
331
+ elif task_name in ["gqa", "realworldqa", "ai2diagram", "ai2diagram_nomask"]:
332
+ vqa_tool = VQAEval()
333
+ acc = 0
334
+ for res in merged_results:
335
+ pred_ans = res["answer"]
336
+ pred_ans = vqa_tool.processPunctuation(pred_ans)
337
+ pred_ans = vqa_tool.processDigitArticle(pred_ans)
338
+
339
+ gt_ans = res["gt_answer"][0]
340
+ gt_ans = vqa_tool.processPunctuation(gt_ans)
341
+ gt_ans = vqa_tool.processDigitArticle(gt_ans)
342
+
343
+ if pred_ans == gt_ans:
344
+ acc += 1
345
+ acc = acc / len(merged_results) * 100
346
+ print(f"===== {task_name} Accuracy {acc:.2f}% =====")
347
+ with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
348
+ f.write(f"{task_name} Accuracy = {acc:.2f}%\n\n")
349
+ return acc
350
+
351
+ elif task_name == "docvqa" or task_name == "docvqa_test":
352
+ vqa_tool = VQAEval()
353
+ anls = 0
354
+ for res in merged_results:
355
+ pred_ans = res["answer"]
356
+ gt_ans = res["gt_answer"]
357
+ anls += get_anls_score(pred=pred_ans, gold_labels=gt_ans, threshold=0.5)
358
+ anls = anls / len(merged_results) * 100
359
+ print(f"===== {task_name} ANLS {anls:.2f}% =====")
360
+ with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
361
+ f.write(f" {task_name} ANLS = {anls:.2f}%\n\n")
362
+ return anls
363
+
364
+ elif task_name == "chartqa":
365
+ vqa_tool = VQAEval()
366
+ acc = 0
367
+ for res in merged_results:
368
+ pred_ans = res["answer"]
369
+ pred_ans = vqa_tool.processPunctuation(pred_ans)
370
+ pred_ans = vqa_tool.processDigitArticle(pred_ans)
371
+
372
+ gt_ans = res["gt_answer"][0]
373
+ gt_ans = vqa_tool.processPunctuation(gt_ans)
374
+ gt_ans = vqa_tool.processDigitArticle(gt_ans)
375
+
376
+ # ChartQA uses relaxed accuracy:
377
+ # "We consider an answer to be correct if it is within 5% of the gold answer.
378
+ # For non-numeric answers, we still need an exact match to consider an answer to be correct."
379
+ if isNumber(pred_ans) and isNumber(gt_ans):
380
+ pred_ans = float(pred_ans)
381
+ gt_ans = float(gt_ans)
382
+ if pred_ans >= (gt_ans * 0.95) and pred_ans <= (gt_ans * 1.05):
383
+ acc += 1
384
+ elif pred_ans == gt_ans:
385
+ acc += 1
386
+ acc = acc / len(merged_results) * 100
387
+ print(f"===== {task_name} Accuracy {acc:.2f}% =====")
388
+ with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
389
+ f.write(f"{task_name} Accuracy = {acc:.2f}%\n\n")
390
+ return acc
391
+
392
+ elif task_name == 'ocrbench':
393
+
394
+ from collections import defaultdict
395
+ OCRBench_score = {"Regular Text Recognition": 0, "Irregular Text Recognition": 0,
396
+ "Artistic Text Recognition": 0, "Handwriting Recognition": 0,
397
+ "Digit String Recognition": 0, "Non-Semantic Text Recognition": 0,
398
+ "Scene Text-centric VQA": 0, "Doc-oriented VQA": 0, "Doc-oriented VQA": 0,
399
+ "Key Information Extraction": 0, "Handwritten Mathematical Expression Recognition": 0}
400
+
401
+ for res in merged_results:
402
+ predict = res["answer"]
403
+ answers = res["gt_answer"]
404
+
405
+ dataset_name = res["dataset_name"]
406
+ ocr_type = res["type"]
407
+
408
+ # data[i]['result'] = 0
409
+ if dataset_name == "HME100k":
410
+ if type(answers) == list:
411
+ for j in range(len(answers)):
412
+ answer = answers[j].strip().replace("\n", " ").replace(" ", "")
413
+ predict = predict.strip().replace("\n", " ").replace(" ", "")
414
+ if answer in predict:
415
+ OCRBench_score[ocr_type] += 1
416
+ else:
417
+ answers = answers.strip().replace("\n", " ").replace(" ", "")
418
+ predict = predict.strip().replace("\n", " ").replace(" ", "")
419
+ if answers in predict:
420
+ OCRBench_score[ocr_type] += 1
421
+ else:
422
+ if type(answers) == list:
423
+ for j in range(len(answers)):
424
+ answer = answers[j].lower().strip().replace("\n", " ")
425
+ predict = predict.lower().strip().replace("\n", " ")
426
+ if answer in predict:
427
+ OCRBench_score[ocr_type] += 1
428
+ else:
429
+ answers = answers.lower().strip().replace("\n", " ")
430
+ predict = predict.lower().strip().replace("\n", " ")
431
+ if answers in predict:
432
+ OCRBench_score[ocr_type] += 1
433
+
434
+ recognition_score = OCRBench_score['Regular Text Recognition'] + OCRBench_score['Irregular Text Recognition'] + \
435
+ OCRBench_score['Artistic Text Recognition'] + OCRBench_score['Handwriting Recognition'] + \
436
+ OCRBench_score['Digit String Recognition'] + OCRBench_score['Non-Semantic Text Recognition']
437
+ Final_score = recognition_score + OCRBench_score['Scene Text-centric VQA'] + OCRBench_score[
438
+ 'Doc-oriented VQA'] + OCRBench_score['Key Information Extraction'] + OCRBench_score[
439
+ 'Handwritten Mathematical Expression Recognition']
440
+ result_log = f"\n###########################OCRBench##############################\n\
441
+ Text Recognition(Total 300):{recognition_score}\n\
442
+ ------------------Details of Recognition Score-------------------\n\
443
+ Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}\n\
444
+ Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}\n\
445
+ Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}\n\
446
+ Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}\n\
447
+ Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}\n\
448
+ Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}\n\
449
+ ----------------------------------------------------------------\n\
450
+ Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}\n\
451
+ ----------------------------------------------------------------\n\
452
+ Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}\n\
453
+ ----------------------------------------------------------------\n\
454
+ Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}\n\
455
+ ----------------------------------------------------------------\n\
456
+ Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}\n\
457
+ ----------------------Final Score-------------------------------\n\
458
+ Final Score(Total 1000): {Final_score}\n"
459
+
460
+ print(f"===== {task_name} Final_score(Total 1000): {Final_score:.2f}% ===== {result_log}")
461
+ with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
462
+ f.write(f" {task_name} Accuracy = {Final_score:.2f}% {result_log}\n\n")
463
+ return Final_score
464
+
465
+
466
+ elif task_name == "mathvista":
467
+ import re
468
+ def extra_processing(text):
469
+
470
+ ## max decimal point capped to 2 decimal point
+ regex = re.compile(r'^\d+\.\d+$')
+ decimal = regex.findall(text)
+
+ if len(decimal) > 0:
+ non_decimal = len(decimal[0].split(".")[0])
+
+ # if decimal values are all 0, trim them
+
+ decimal_digits = [int(d) for d in decimal[0].split(".")[1]]
+ if sum(decimal_digits) == 0:
+ text = decimal[0][:non_decimal]
+ else:
+ text = decimal[0][:non_decimal + 3]
+
+ ## remove %
+ text = text.replace("%", "")
+ try:
+ if text[-1] == ".":
+ text = text[:-1]
+ except:
+ print(text)
+
+ return text
+
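+ # extract_answer recovers the numeric answer from free-form text: it first looks for an "answer is <number>"
+ # pattern, then falls back to the first number found in the string.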
+ def extract_answer(text):
+
+ alphabet = re.findall(r'[a-zA-Z]+', text)
+ if len(alphabet) > 0 and "e+" not in text:
+
+ template_1 = re.findall(r'answer is -*\d+\.*\d*', text)
+ if len(template_1) > 0:
+ text = template_1[0]
+
+ numbers = re.findall(r'-*\d+\.*\d*', text)
+ text = numbers[0] if len(numbers) > 0 else text
+
+ return text
+
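+ # Multiple-choice predictions are mapped to an option letter; open-ended predictions are normalized with the
+ # standard VQA punctuation/article processing plus the helpers above, then compared to the ground truth by exact match.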
+ vqa_tool = VQAEval()
+ acc = 0
+ for res in merged_results:
+ pred_ans = res["answer"]
+ if res['question_type'] == 'multi_choice':
+ pred_ans = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
+ else:
+ pred_ans = vqa_tool.processPunctuation(pred_ans)
+ pred_ans = vqa_tool.processDigitArticle(pred_ans)
+
+ pred_ans = extra_processing(pred_ans)
+ pred_ans = extract_answer(pred_ans)
+
+ gt_ans = res["gt_answer"]
+ if res['question_type'] != 'multi_choice':
+ gt_ans = vqa_tool.processPunctuation(gt_ans)
+ gt_ans = vqa_tool.processDigitArticle(gt_ans)
+ gt_ans = extra_processing(gt_ans)
+
+ if pred_ans == gt_ans:
+ acc += 1
+ acc = acc / len(merged_results) * 100
+ print(f"===== {task_name} Accuracy {acc:.2f}% =====")
+ with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
+ f.write(f" {task_name} Accuracy = {acc:.2f}%\n\n")
+ return acc
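+ # MMBench: the dev split has labels, so accuracy is computed locally; for the test split a submission file
+ # is generated for the official evaluation server instead.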
+ elif task_name == "mmbench":
+ if task_cfg['split'] == 'dev':
+ acc = 0
+ for res in merged_results:
+ gt_ans = res["gt_answer"]
+ pred_ans = res["answer"]
+ pred_ans = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
+
+ if pred_ans == gt_ans:
+ acc += 1
+ acc = acc / len(merged_results) * 100
+ print(f"===== {task_name} Accuracy {acc:.2f}% =====")
+ with open(os.path.abspath(os.path.join(results_save_dir, "scores.txt")), "a") as f:
+ f.write(f"{task_name} Accuracy = {acc:.2f}%\n\n")
+
+ return acc
+
+ # Generate submission.xlsx for test set online evaluation, "https://mmbench.opencompass.org.cn/mmbench-submission"
+ if task_cfg['submission']:
+ res_df = pd.DataFrame(
+ {
+ "index": [r['question_id'] for r in merged_results],
+ "question": [r['question'] for r in merged_results],
+ "A": [r['index2ans']['A'] if 'A' in r['index2ans'] else None for r in merged_results],
+ "B": [r['index2ans']['B'] if 'B' in r['index2ans'] else None for r in merged_results],
+ "C": [r['index2ans']['C'] if 'C' in r['index2ans'] else None for r in merged_results],
+ "D": [r['index2ans']['D'] if 'D' in r['index2ans'] else None for r in merged_results],
+ "prediction": [parse_multi_choice_response(r['answer'], r['all_choices'], r['index2ans']) for r in
+ merged_results],
+ },
+ columns=['index', 'question', 'A', 'B', 'C', 'D', 'prediction']
+ )
+ res_df.to_excel(os.path.join(results_save_dir, "submission.xlsx"))
+
+
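+ # VizWiz: predictions are only saved locally; the result file has to be uploaded for official scoring.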
+ elif task_name == "vizwiz":
+ print(
+ f"VizWiz result file is saved at: {merged_results_path}\n"
+ f"Upload manually or use this CLI `python evaluation/upload_vizwiz.py --result {merged_results_path} --token <your_eval_user_token>`."
+ )
+
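+ # MMMU / MMMU-Pro: parse each prediction according to its question type and delegate scoring to mmmu_main_eval.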
+ elif task_name == "mmmu" or task_name == 'mmmu_pro':
+ def extract_answer(text):
+ import re
+ # Regular expression to find content inside \answer{xxx}
+ match = re.search(r'\\answer\{(.*?)\}', text)
+ if match:
+ return match.group(1) # Return the content inside the braces
+ return text # Return the original string if no match is found
+
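+ # Collect parsed predictions keyed by question_id: multiple-choice answers keep the chosen option letter,
+ # open-ended answers keep the extracted answer string.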
+ eval_samples = []
+ eval_output_dict = {}
+ for res in merged_results:
+ pred_ans = res["answer"]
+ gt_ans = res['gt_answer']
+ if res['question_type'] == 'multiple-choice':
+ parsed_pred = parse_multi_choice_response(pred_ans, res['all_choices'], res['index2ans'])
+ eval_samples.append(
+ {
+ 'id': res['question_id'],
+ 'question_type': res['question_type'],
+ 'answer': res['gt_answer'], # the content in option, not answer index.
+ 'response': pred_ans,
+ 'parsed_pred': parsed_pred,
+ 'index2ans': res['index2ans'],
+ }
+ )
+ eval_output_dict[res['question_id']] = parsed_pred
+ else:
+ pred_ans = extract_answer(pred_ans) # for sft v9prov3, we observe answers are within \answer{xxx}
+ parsed_pred = parse_open_response(pred_ans)
+ eval_samples.append(
+ {
+ 'id': res['question_id'],
+ 'question_type': res['question_type'],
+ 'answer': res['gt_answer'],
+ 'response': pred_ans,
+ 'parsed_pred': parsed_pred,
+ }
+ )
+ eval_output_dict[res['question_id']] = pred_ans
+ if args.subset is not None:
+ eval_output_dict_path = os.path.join(results_save_dir,
+ f"eval_output_dict_{task_name}_subset_{args.subset}_start_{args.start_idx}.json")
+ else:
+ eval_output_dict_path = os.path.join(results_save_dir, f"eval_output_dict_{task_name}.json")
+ json.dump(eval_output_dict, open(eval_output_dict_path, "w"), indent=4, sort_keys=True)
+
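+ # mmmu_main_eval returns per-category accuracies; only the aggregate 'Overall' entries are printed and logged,
+ # and the overall accuracy is returned.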
+ mmmu_results = mmmu_main_eval(eval_output_dict, task_cfg)
+ with open(os.path.join(results_save_dir, "scores.txt"), "a") as f:
+ f.write(f"{task_name} {task_cfg['split']}:\n")
+ for cat, cat_val in mmmu_results.items():
+ if 'Overall' in cat:
+ cat = cat.replace("Overall-", "")
+ print(f'{cat}: {cat_val["acc"] * 100:.2f}')
+ f.write(f'{cat}: {cat_val["acc"] * 100:.2f}\n')
+ return mmmu_results['Overall']['acc']
+
+ else:
+ raise NotImplementedError(f"Task {task_name} is not supported yet.")
+
+
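+ # evaluate() runs every requested zero-shot task end-to-end: build the task dataloader, generate predictions
+ # with greedy decoding (do_sample=False), then compute the task metric.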
+ def evaluate(args, model, tokenizer, tasks):
+ start_time = time.time()
+
+ results_save_dir = os.path.join(args.result_save_path)
+ if not os.path.exists(results_save_dir):
+ os.makedirs(results_save_dir)
+
+ task_names = list(tasks.keys())
+
+ print(f"Evaluating tasks: {task_names}")
+
+ with torch.no_grad():
+ # for task_name in task_names:
+ for task_name in args.zero_shot_eval_tasks:
+ task_cfg = tasks[task_name]
+ print(f"Preparing dataloader for {task_name}...")
+
+ dataloader = get_task_dataloader(task_name, task_cfg, args)
+
+ generation_config = dict(max_new_tokens=1024, do_sample=False)
+
+ print("Start generating...")
+ generate_task_results(task_name, task_cfg, args, model, tokenizer, generation_config,
+ dataloader, results_save_dir)
+
+ print("Start calculating task metric...")
+ calc_task_metrics(args, task_name, task_cfg, results_save_dir)
+
+ end_time = time.time()
+
+ print(f"Evaluation takes {(end_time - start_time) / 60:.1f} minutes in total!")
+
+
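+ # Entry point: load NVLM-D-72B in bfloat16, sharded across GPUs with the device map from split_model(),
+ # then parse the CLI arguments and run the evaluation.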
+ if __name__ == '__main__':
+ path = "nvidia/NVLM-D-72B"
+ print("Loading model... from", path)
+ device_map = split_model()
+
+ start = time.time()
+ model = AutoModel.from_pretrained(
+ path,
+ torch_dtype=torch.bfloat16,
+ low_cpu_mem_usage=True,
+ use_flash_attn=True,
+ device_map=device_map,
+ trust_remote_code=True).eval()
+ end = time.time()
+ print("loading model takes:", end - start)
+
+ print(model)
+
+ tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
+
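+ # --zero-shot-eval-tasks selects which benchmarks to run; --start-idx and --subset are intended for evaluating
+ # a slice of a dataset (they are also encoded in the saved result filenames).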
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config-path', type=str, required=True,
+ help='YAML file configuring evaluation datasets')
+ parser.add_argument('--result-save-path', type=str, default=os.path.join(path, 'eval_results'))
+ parser.add_argument('--zero-shot-eval-tasks', nargs='+', type=str, default=['mmmu'])
+ parser.add_argument('--start-idx', type=int, default=0)
+ parser.add_argument('--subset', type=int, default=None)
+
+ args = parser.parse_args()
+
+ tasks = yaml.safe_load(open(args.config_path))['datasets']
+
+ evaluate(args, model, tokenizer, tasks)