"""
Donut
Copyright (c) 2022-present NAVER Corp.
MIT License
"""
import argparse
import json
import os
import re
from pathlib import Path

import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm

from donut import DonutModel, JSONParseEvaluator, load_json, save_json


def test(args):
    # Load the fine-tuned Donut checkpoint; run in half precision on GPU when available.
    pretrained_model = DonutModel.from_pretrained(args.pretrained_model_name_or_path)

    if torch.cuda.is_available():
        pretrained_model.half()
        pretrained_model.to("cuda")

    pretrained_model.eval()

    if args.save_path:
        os.makedirs(os.path.dirname(args.save_path), exist_ok=True)

    predictions = []
    ground_truths = []
    accs = []

    evaluator = JSONParseEvaluator()
    dataset = load_dataset(args.dataset_name_or_path, split=args.split)

    for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
        ground_truth = json.loads(sample["ground_truth"])

        # DocVQA needs the question embedded in the prompt; other tasks only use the task start token.
        if args.task_name == "docvqa":
            output = pretrained_model.inference(
                image=sample["image"],
                prompt=f"<s_{args.task_name}><s_question>{ground_truth['gt_parses'][0]['question'].lower()}</s_question><s_answer>",
            )["predictions"][0]
        else:
            output = pretrained_model.inference(image=sample["image"], prompt=f"<s_{args.task_name}>")["predictions"][0]

        if args.task_name == "rvlcdip":
            # Document classification: exact match on the predicted class label.
            gt = ground_truth["gt_parse"]
            score = float(output["class"] == gt["class"])
        elif args.task_name == "docvqa":
            # Note: DocVQA was evaluated on the official evaluation server;
            # this script reports an exact-match score instead.
            gt = ground_truth["gt_parses"]
            answers = set(qa_parse["answer"] for qa_parse in gt)
            score = float(output["answer"] in answers)
        else:
            # Information extraction: Tree Edit Distance (TED) based accuracy.
            gt = ground_truth["gt_parse"]
            score = evaluator.cal_acc(output, gt)

        accs.append(score)
        predictions.append(output)
        ground_truths.append(gt)

    scores = {
        "ted_accuracies": accs,
        "ted_accuracy": np.mean(accs),
        "f1_accuracy": evaluator.cal_f1(predictions, ground_truths),
    }
    print(
        f"Total number of samples: {len(accs)}, "
        f"Tree Edit Distance (TED) based accuracy score: {scores['ted_accuracy']}, "
        f"F1 accuracy score: {scores['f1_accuracy']}"
    )

    if args.save_path:
        scores["predictions"] = predictions
        scores["ground_truths"] = ground_truths
        save_json(args.save_path, scores)

    return predictions


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_name_or_path", type=str)
    parser.add_argument("--dataset_name_or_path", type=str)
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--task_name", type=str, default=None)
    parser.add_argument("--save_path", type=str, default=None)
    args, left_argv = parser.parse_known_args()

    if args.task_name is None:
        # Default to the dataset's base name, e.g. "cord-v2" for "naver-clova-ix/cord-v2".
        args.task_name = os.path.basename(args.dataset_name_or_path)

    predictions = test(args)
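
# Example invocation (a sketch, assuming this file is saved as test.py; the model and
# dataset identifiers below are illustrative, substitute your own fine-tuned checkpoint
# and dataset):
#
#   python test.py \
#       --pretrained_model_name_or_path naver-clova-ix/donut-base-finetuned-cord-v2 \
#       --dataset_name_or_path naver-clova-ix/cord-v2 \
#       --split test \
#       --save_path ./result/cord-v2_scores.json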