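"""Evaluate a summarization checkpoint on the DialogSum test split with ROUGE.

The script loads a model checkpoint, generates summaries under a configurable
GenerationConfig, scores them with evaluation_rouge, and appends the resulting
metrics (tagged with the checkpoint name) to a CSV file.
"""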
import warnings

# Silence library warnings (e.g. from transformers/datasets) to keep evaluation logs readable.
warnings.filterwarnings("ignore")

from datasets import load_dataset

import os
import sys

import pandas as pd
import argparse

# Make the project root importable so that the `src` package resolves when
# this script is run directly.
path = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, path)

from src.model.model import load_model
from src.evaluate.evaluation import evaluation_rouge
from transformers import GenerationConfig


def save_metrics_to_csv(results, resultpath, checkpoint):
    """Append one row of metrics, tagged with the checkpoint name, to a CSV file."""
    results["checkpoint"] = checkpoint

    # Convert results to a single-row DataFrame
    df = pd.DataFrame([results])

    # Create the output directory if needed, then write the header only when the
    # file is first created; otherwise append without a header.
    os.makedirs(os.path.dirname(resultpath) or ".", exist_ok=True)
    if not os.path.isfile(resultpath):
        df.to_csv(resultpath, index=False)
    else:
        df.to_csv(resultpath, mode='a', header=False, index=False)


def main():
    parser = argparse.ArgumentParser(description="Evaluation metric")
    parser.add_argument("--datapath", type=str, default="knkarthick/dialogsum")
    parser.add_argument("--checkpoint", type=str, default="google/flan-t5-base")
    parser.add_argument("--resultpath", type=str, default="results/rouge_score.csv")
    
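    # Decoding parameters forwarded to GenerationConfig below.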
    parser.add_argument("--min_new_tokens", type=int, default=10)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.9)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)

    args = parser.parse_args()

    print("=========================================")
    print('\n'.join(f' + {k}={v}' for k, v in vars(args).items()))
    print("=========================================")
    
    checkpoint = args.checkpoint

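    # Note: temperature/top_p/top_k only influence generation when sampling is
    # enabled (do_sample=True) downstream; they are ignored under greedy or
    # beam decoding.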
    generation_config = GenerationConfig(
        min_new_tokens=args.min_new_tokens,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k
    )

    # The DialogSum test split is loaded from a fixed dataset id here; the
    # --datapath argument is parsed but not consumed by this call.
    data = load_dataset("binwang/InstructDS_datasets", "DialogSum", split="test")

    model = load_model(checkpoint)
    print(f"Loaded model from: {checkpoint}")

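    # Generate summaries for the test split and compute ROUGE scores.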
    results = evaluation_rouge(model, data, generation_config)
    
    print("--------------------------")
    for k, v in results.items():
        print(f"{k}: {v}")
    print("--------------------------")

    save_metrics_to_csv(results, args.resultpath, checkpoint)
    print(f"Results saved to: {args.resultpath}")


if __name__ == "__main__":
    main()