import json
import os
import sys
import time
import yaml
import spacy
import ast
from PIL import Image
from glob import glob
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
from io import BytesIO
import base64
from anls import anls_score
import torch
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torchvision.transforms as T
from eval import conversation as conversation_lib
from eval.mmmu_utils import CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, parse_multi_choice_response, parse_open_response, \
    process_single_sample, construct_prompt, mmmu_main_eval, process_single_sample_pro, construct_prompt_pro
from eval.mmmu_utils import evaluate as evaluate_mmmu
from torchvision.transforms.functional import InterpolationMode
from datasets import load_dataset, concatenate_datasets

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image, input_size=448, max_num=6, decoded=False):
    if not decoded:
        image = Image.open(image).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
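# Illustrative note (not part of the original pipeline): `load_image` returns a stacked
# tile tensor of shape (num_tiles, 3, input_size, input_size), where num_tiles is the
# number of aspect-ratio-matched crops chosen by `dynamic_preprocess`, plus one global
# thumbnail whenever more than one crop is produced. For example, with the defaults
# (input_size=448, max_num=6), an exactly 2:1 landscape image maps to a 2x1 grid plus
# the thumbnail:
#
#   pixel_values = load_image("example.jpg")   # hypothetical path
#   pixel_values.shape                          # torch.Size([3, 3, 448, 448])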
def levenshtein_distance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]


def get_anls_score(pred, gold_labels, threshold, llava_eval=False):
    values = []
    for answer in gold_labels:
        # preprocess both the answers - gt and prediction
        gt_answer = ' '.join(answer.strip().lower().split())
        det_answer = ' '.join(pred.strip().lower().split())

        dist = levenshtein_distance(gt_answer, det_answer)
        length = max(len(answer.upper()), len(pred.upper()))
        values.append(0.0 if length == 0 else float(dist) / float(length))

    question_result = 1 - min(values)

    if llava_eval:
        question_result = 1.0 if question_result >= threshold else 0.0
    else:
        if question_result < threshold:
            question_result = 0

    return question_result
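# Worked example of the ANLS scoring above (hypothetical strings, threshold 0.5):
# levenshtein_distance("graz", "gra") is 1 and the longer string has length 4, so the
# normalized distance is 0.25 and the per-question score is 1 - 0.25 = 0.75, which
# survives the 0.5 threshold. A prediction further than half the characters away from
# every gold label is clamped to 0 (or binarized to 0/1 when llava_eval=True).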
def isNumber(n: str):
    try:
        float(n)
        return True
    except ValueError:
        return False


class COCOEvalDataset(Dataset):
    def __init__(self, args, img_dir, subset=None):
        self.args = args
        self.img_files = sorted(glob(os.path.join(img_dir, "*")))

        if subset:
            self.img_files = self.img_files[:subset]

        self.image_ids = [int(img_file.split("_")[-1].split(".")[0]) for img_file in self.img_files]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        img_path = self.img_files[idx]
        img = load_image(img_path, max_num=6).to(torch.bfloat16)

        return self.image_ids[idx], img


class Flickr30KEvalDataset(Dataset):
    def __init__(self, args, img_dir, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.test_samples = json.load(open(os.path.join(img_dir, "flickr30k_test.json"), encoding='utf-8'))

        if subset:
            self.test_samples = self.test_samples[:subset]

    def __len__(self):
        return len(self.test_samples)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.test_samples[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        image_id = int(self.test_samples[idx]["image"].split("/")[-1].replace(".jpg", ""))

        return image_id, img


class VQAv2EvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answer"]

        return img, question_id, question, answer


class TextVQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))['data']

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]["image_id"] + '.jpg')
        if not os.path.exists(img_path):
            img_path = img_path.replace('.jpg', '.png')
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answers"]

        return img, question_id, question, answer


class GQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))
        self.gt = [{
            "question_id": int(k),
            "image": v['imageId'] + ".jpg",
            "question": v['question'],
            "answer": v['answer']
        } for k, v in self.gt.items()]

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answer"]

        return img, question_id, question, [answer]


class ChartQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))

        for i in range(len(self.gt)):
            self.gt[i]['question_id'] = i

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]["imgname"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["query"]
        answer = self.gt[idx]["label"]

        return img, question_id, question, [answer]


class OKVQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, question_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))['annotations']
        self.questions = json.load(open(question_path, 'r'))['questions']

        if subset:
            self.gt = self.gt[:subset]

        qid2q = {q['question_id']: q['question'] for q in self.questions}
        for ann in self.gt:
            ann['answers'] = [ans['answer'] for ans in ann['answers']]
            ann['question'] = qid2q[ann['question_id']]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_id = str(self.gt[idx]["image_id"])
        img_id = '0' * (12 - len(img_id)) + img_id
        img_file_name = f"COCO_val2014_{img_id}.jpg"
        img_path = os.path.join(self.img_dir, img_file_name)
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answers"]

        return img, question_id, question, answer


class DocVQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, split='val', subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))['data']

        if subset:
            self.gt = self.gt[:subset]

        self.split = split

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]['image'].split('/')[-1])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question_id = self.gt[idx]["questionId"]
        question = self.gt[idx]["question"]
        if self.split == 'val':
            answer = self.gt[idx]["answers"]
        else:
            answer = ['']

        return img, question_id, question, answer


class OCRBenchEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]['image_path'])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        dataset_name = self.gt[idx]["dataset_name"]
        question_id = f"{idx}"
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answers"]
        data_type = self.gt[idx]["type"]

        return img, question_id, question, answer, dataset_name, data_type
class AI2DiagramEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        with open(gt_path, 'r') as json_file:
            json_list = list(json_file)
        self.gt = [json.loads(json_str) for json_str in json_list]

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answer"]

        return img, question_id, question, answer


class AI2DiagramNoMaskEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        with open(gt_path, 'r') as json_file:
            json_list = list(json_file)
        self.gt = [json.loads(json_str) for json_str in json_list]

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_file_name = self.gt[idx]['image'].replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES")
        img_path = os.path.join(self.img_dir, img_file_name)
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question_id = self.gt[idx]["question_id"]
        question = self.gt[idx]["question"]
        answer = self.gt[idx]["answer"]

        return img, question_id, question, answer


class RealworldQAEvalDataset(Dataset):
    def __init__(self, args, img_dir, gt_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.gt = json.load(open(gt_path, encoding='utf-8'))

        if subset:
            self.gt = self.gt[:subset]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.gt[idx]['image'])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question_id = int(self.gt[idx]['image'].replace(".webp", ""))
        question = self.gt[idx]["question"]

        if self.gt[idx]['question_type'] == "multi-choice":
            choices = self.gt[idx]["choices"]
            start_chr = 'A'
            choices_str = ''
            index2ans = {}
            all_choices = []
            for choice in choices:
                all_choices.append(start_chr)
                index2ans[start_chr] = choice
                choices_str += f"{start_chr}. {choice}\n"
                start_chr = chr(ord(start_chr) + 1)

            question = question + '\n' + choices_str
            question = question + "Answer with the option's letter from the given choices directly."
            answer = chr(ord('A') + self.gt[idx]['correct_choice_index'])
        else:
            question = question + "\nAnswer the question using a single word or phrase."
            answer = self.gt[idx]['answer']

        return img, question_id, question, [answer]
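# Note on the multiple-choice prompt format (illustration only): RealworldQAEvalDataset
# above and the MathVista/MMBench datasets below build the prompt the same way, by
# enumerating the choices as capital letters and appending a fixed instruction. For a
# hypothetical sample with choices ["cat", "dog"], the model sees:
#
#   <question text>
#   A. cat
#   B. dog
#   Answer with the option's letter from the given choices directly.
#
# and the ground-truth answer is stored as the matching letter ("A" or "B").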
class MathVistaEvalDataset(Dataset):
    def __init__(self, args, task_cfg, gt_path=None):
        self.args = args
        self.task_cfg = task_cfg
        self.dataset = load_dataset("AI4Math/MathVista")['testmini']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = self.dataset[idx]['decoded_image']
        img = load_image(img.convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = self.dataset[idx]["pid"]
        question = self.dataset[idx]["question"]
        question_type = self.dataset[idx]["question_type"]  # free_form or multi_choice
        query = self.dataset[idx]["query"]
        choices = self.dataset[idx]["choices"]
        answer = self.dataset[idx]["answer"]

        if question_type == 'multi_choice':
            start_chr = 'A'
            choices_str = ''
            index2ans = {}
            all_choices = []
            for choice in choices:
                all_choices.append(start_chr)
                index2ans[start_chr] = choice
                choices_str += f"{start_chr}. {choice}\n"
                start_chr = chr(ord(start_chr) + 1)

            question = question + '\n' + choices_str
            question = question + "Answer with the option's letter from the given choices directly."
            answer = chr(ord('A') + choices.index(answer))
        else:
            question = query.replace("Hint: ", "")
            index2ans = {}
            all_choices = []

        return img, question_id, question_type, question, answer, str(index2ans), str(all_choices)


def construct_prompt_for_fewshot(sample):
    config = {
        "task_instructions": "",
        "multi_choice_example_format": "{}\n{}Answer with the option's letter from the given choices directly.",
        "short_ans_example_format": "{}\nAnswer the question using a single word or phrase."
    }

    question = sample['question'].strip()
    options = eval(sample['options'])
    example = ""
    if sample['question_type'] == 'multiple-choice':
        start_chr = 'A'
        prediction_range = []
        index2ans = {}
        for option in options:
            prediction_range.append(start_chr)
            example += f"({start_chr}) {option}\n"
            index2ans[start_chr] = option
            start_chr = chr(ord(start_chr) + 1)

        empty_prompt_sample_structure = config['multi_choice_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question, example)

        res_dict = {'type': 'multichoice'}
        res_dict['index2ans'] = index2ans
        res_dict['correct_choice'] = sample['answer']
        res_dict['all_choices'] = prediction_range
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt

        res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
    else:
        empty_prompt_sample_structure = config['short_ans_example_format']
        empty_prompt = empty_prompt_sample_structure.format(question)

        res_dict = {'type': 'open'}
        res_dict['empty_prompt'] = empty_prompt
        if config['task_instructions']:
            res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
        else:
            res_dict['final_input_prompt'] = empty_prompt

        res_dict['gt_content'] = sample['answer']

    res_dict.update(sample)
    return res_dict


def process_image_tag(q):
    """Heuristically rewrite a question so its inline <image 1> placeholder reads
    naturally (the image itself is supplied to the model separately)."""
    q = q.strip()

    # heuristic way of removing <image 1>
    if q == '<image 1>':
        q = 'Answer the question in the image.'
    elif ':<image 1>' in q:
        q = q.replace(':<image 1>', ' in the image. ')
        q = q.strip()
    elif ': <image 1>' in q:
        q = q.replace(': <image 1>', ' in the image. ')
        q = q.strip()
    elif '<image 1>.' in q or '<image 1>. ' in q:
        q_list = q.split('<image 1>')
        q_list = [part.strip() for part in q_list if part.strip() != '']
        q = ' '.join(q_list)
    elif q.startswith('<image 1> '):
        if q[10].isupper():
            q = q.replace('<image 1>', '')
        else:
            q = q.replace('<image 1>', 'The image')
        q = q.strip()
    elif q.startswith('<image 1>'):
        q = q.replace('<image 1>', '')
    elif q.endswith('<image 1>?'):
        q = q.replace('<image 1>', 'the image')
    elif q.endswith('?<image 1>') or q.endswith('? <image 1>') or q.endswith('\n<image 1>'):
        q = q.replace('<image 1>', '')
        q = q.strip()
    elif ' <image 1>' in q:
        q = q.replace('<image 1>', 'the image')
    elif '(<image 1>)' in q:
        q = q.replace('(<image 1>)', '')
    elif '.<image 1>' in q:
        q = q.replace(".<image 1>", ". ")
    else:
        q = q.replace("<image 1>", ". ")
    q = q.strip()

    # remove <image 2> to <image 7>
    for i in range(2, 8):
        q = q.replace(f"<image {i}>", "")

    return q


class MMMUProEvalDataset(Dataset):
    def __init__(self, args, task_cfg, subset=None):
        self.args = args
        self.task_cfg = task_cfg
        sub_dataset_list = []

        # load_dataset will throw error if split is 'dev'
        # 'dev' is part of the 'validation' and we need to manually split them
        MMMU_path = "MMMU/MMMU_Pro"
        _split = "test"
        self.dataset = load_dataset(MMMU_path, "standard", split=_split)

        if subset:
            self.dataset = self.dataset[:subset]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # ===== single-image =====
        sample = self.dataset[idx]
        sample = process_single_sample_pro(sample)
        sample = construct_prompt_pro(sample, self.task_cfg)

        img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
        # img = img.reshape(-1, 3, self.args.img_h, self.args.img_w)

        question_id = sample['id']
        question = sample['final_input_prompt']
        answer = sample['answer']

        question = process_image_tag(question)
        question = self.task_cfg['default_image_token'] + '\n' + question

        if sample['question_type'] == 'multiple-choice':
            index2ans = sample['index2ans']
            all_choices = sample['all_choices']
        else:
            index2ans = {}
            all_choices = []

        return img, question_id, sample['subfield'], sample['question_type'], question, answer, str(index2ans), \
            str(all_choices)


class MMMUEvalDataset(Dataset):
    def __init__(self, args, task_cfg, subset=None, start_idx=None):
        self.args = args
        self.task_cfg = task_cfg
        sub_dataset_list = []

        # load_dataset will throw error if split is 'dev'
        # 'dev' is part of the 'validation' and we need to manually split them
        MMMU_path = "MMMU/MMMU"
        _split = "test" if task_cfg["split"] == "test" else "validation"
        for subject in CAT_SHORT2LONG.values():
            sub_dataset = load_dataset(
                MMMU_path, subject,
                split=_split,
            )
            sub_dataset_list.append(sub_dataset)

        dataset = concatenate_datasets(sub_dataset_list)
        if task_cfg["split"] != "test":
            dataset = [s for s in dataset if s['id'].startswith(task_cfg["split"])]
        # dataset = [s for s in dataset if s['image_2'] is not None][1:]

        self.dataset = dataset
        if subset:
            self.dataset = [dataset[i] for i in range(start_idx, min(start_idx + subset, len(dataset)))]
            print(f"Evaluating a subset of dataset: {len(self.dataset)} from {start_idx} to {start_idx + subset}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # ===== single-image =====
        sample = self.dataset[idx]
        sample = process_single_sample(sample)
        sample = construct_prompt(sample, self.task_cfg)

        img = load_image(sample['image'].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)

        question_id = sample['id']
        question = sample['final_input_prompt']
        answer = sample['answer']

        question = process_image_tag(question)
        question = self.task_cfg['default_image_token'] + '\n' + question

        if sample['question_type'] == 'multiple-choice':
            index2ans = sample['index2ans']
            all_choices = sample['all_choices']
        else:
            index2ans = {}
            all_choices = []

        return img, question_id, sample['subfield'], sample['question_type'], question, answer, str(index2ans), \
            str(all_choices)
class VizWizEvalDataset(Dataset):
    def __init__(self, args, img_dir, question_path, subset=None):
        self.args = args
        self.img_dir = img_dir
        self.questions = json.load(open(question_path, encoding='utf-8'))

        if subset:
            self.questions = self.questions[:subset]

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.questions[idx]["image"])
        img = load_image(img_path, max_num=6).to(torch.bfloat16)
        question = self.questions[idx]["question"]
        question_id = self.questions[idx]["image"]

        return img, question_id, question


class MMBenchEvalDataset(Dataset):
    def __init__(self, args, gt_path, subset=None):
        self.args = args
        df = pd.read_csv(gt_path, sep='\t')

        self.dataset = []
        for i, row in df.iterrows():
            choices = []
            for choice in ['A', 'B', 'C', 'D']:
                if str(row[choice]) != 'nan':
                    choices.append(row[choice])

            this_sample = {
                'index': row['index'],
                'question': row['question'],
                'hint': row['hint'],
                'category': row['category'],
                'image': Image.open(BytesIO(base64.b64decode(row['image']))),
                'choices': choices
            }

            # Only dev set gives the ground truth answer
            if 'answer' in row.keys():
                this_sample['answer'] = row['answer']
            else:
                this_sample['answer'] = ''

            self.dataset.append(this_sample)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img = load_image(self.dataset[idx]["image"].convert("RGB"), max_num=6, decoded=True).to(torch.bfloat16)
        question = self.dataset[idx]["question"]
        hint = self.dataset[idx]["hint"]
        question_id = self.dataset[idx]["index"]
        choices = self.dataset[idx]["choices"]
        answer = self.dataset[idx]["answer"]

        start_chr = 'A'
        choices_str = ''
        index2ans = {}
        all_choices = []
        for choice in choices:
            all_choices.append(start_chr)
            index2ans[start_chr] = choice
            choices_str += f"{start_chr}. {choice}\n"
            start_chr = chr(ord(start_chr) + 1)

        question = question + '\n' + choices_str

        return img, question_id, question, answer, str(index2ans), str(all_choices), self.dataset[idx]["question"]


def get_task_dataloader(task_name, task_cfg, args):
    if "subset" in task_cfg.keys():
        subset = task_cfg["subset"]
    else:
        subset = None

    if task_name == "coco_caption":
        dataset = COCOEvalDataset(args, task_cfg["image_dir"], subset)
    elif task_name == "flickr30k_caption":
        dataset = Flickr30KEvalDataset(args, task_cfg["image_dir"], subset)
    elif task_name == "vqav2":
        dataset = VQAv2EvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "textvqa":
        dataset = TextVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "gqa":
        dataset = GQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "chartqa":
        dataset = ChartQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "okvqa":
        dataset = OKVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], task_cfg["question_path"], subset)
    elif task_name == "vizwiz":
        dataset = VizWizEvalDataset(args, task_cfg["image_dir"], task_cfg["question_path"], subset)
    elif task_name == "docvqa":
        dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='val', subset=subset)
    elif task_name == "docvqa_test":
        dataset = DocVQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], split='test', subset=subset)
    elif task_name == "realworldqa":
        dataset = RealworldQAEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == "mmmu":
        dataset = MMMUEvalDataset(args, task_cfg, subset=args.subset, start_idx=args.start_idx)
    elif task_name == "mmmu_pro":
        dataset = MMMUProEvalDataset(args, task_cfg)
    elif task_name == "mathvista":
        dataset = MathVistaEvalDataset(args, task_cfg)
    elif task_name == "mmbench":
        dataset = MMBenchEvalDataset(args, task_cfg["gt_path"])
    elif task_name == 'ocrbench':
        dataset = OCRBenchEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == 'ai2diagram':
        dataset = AI2DiagramEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    elif task_name == 'ai2diagram_nomask':
        dataset = AI2DiagramNoMaskEvalDataset(args, task_cfg["image_dir"], task_cfg["gt_path"], subset)
    else:
        raise NotImplementedError(f"Task {task_name} is not supported yet.")

    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )

    return dataloader
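# Illustrative usage sketch (hypothetical paths and a hypothetical `args` object; the
# task configs are expected to come from the evaluation driver, e.g. loaded from YAML):
#
#   task_cfg = {"image_dir": "/data/textvqa/train_images",
#               "gt_path": "/data/textvqa/TextVQA_0.5.1_val.json"}
#   loader = get_task_dataloader("textvqa", task_cfg, args)
#   for img, question_id, question, answers in loader:
#       ...  # batch_size is 1, so each iteration yields a single sample's tensors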