# File size: 2,669 Bytes
# 97e4014
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import sys

from datasets import Dataset

import evaluate
import torch

import logging

# ---------------------------- Logging setup ----------------------------
# Configure the root logger (INFO level, timestamped records), then create
# the module-level logger used throughout this file.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------

from transformers import AutoModelForSeq2SeqLM

# Resolve the parent of this file's directory and put it at the front of
# sys.path — presumably so the `model.model` import that follows can be found
# when this file is run from inside its own folder.
path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, path)

from model.model import Model


class RougeEvaluation:
    """Thin wrapper around the ``rouge`` metric loaded through ``evaluate``."""

    def __init__(self) -> None:
        # Load the metric once and keep it on the instance for later scoring calls.
        self.rouge_metric = evaluate.load("rouge")

    def compute_rouge_metric(self, generated_summary, reference_summary) -> dict:
        """Score generated summaries against reference summaries.

        Args:
            generated_summary: Forwarded to the metric as ``predictions``.
            reference_summary: Forwarded to the metric as ``references``.

        Returns:
            The value returned by ``rouge_metric.compute`` with aggregation
            and stemming both enabled.
        """
        scoring_options = {"use_aggregator": True, "use_stemmer": True}
        return self.rouge_metric.compute(
            predictions=generated_summary,
            references=reference_summary,
            **scoring_options,
        )


def evaluation_rouge(model: Model, data: Dataset, generation_config) -> dict:
    """Generate a summary for every dialogue in ``data`` and score them with ROUGE.

    Args:
        model: Wrapper exposing ``get_model()`` and
            ``generate_summary(text, generation_config, do_sample=...)``.
        data: Dataset with a ``"dialogue"`` column (inputs) and a
            ``"summary"`` column (reference summaries).
        generation_config: Passed through unchanged to
            ``model.generate_summary``.

    Returns:
        The result of ``RougeEvaluation.compute_rouge_metric`` with an extra
        ``"gen_len"`` entry: the mean whitespace-separated word count of the
        generated summaries (``0`` when nothing was generated).
    """
    # Side effect on the caller's object: the wrapper's ``base_model`` is
    # (re)assigned from ``get_model()`` before any generation happens.
    model.base_model = model.get_model()

    dialogues = data["dialogue"]
    human_summaries = list(data["summary"])

    prefix = "Summarize the following dialogue:\n###\n"
    suffix = "\n### Summary: "

    # NOTE: a query-focused prompt variant (query taken from data["answer"],
    # plus a length hint of ~20% of the dialogue's word count) was prototyped
    # here; it is not active.

    model_summaries = []
    for idx, dialogue in enumerate(dialogues):
        prompt = prefix + dialogue + suffix

        # Lightweight progress marker on stdout: one index per dialogue.
        print(idx, end="# ")
        # do_sample=False is passed explicitly on every call.
        model_summaries.append(
            model.generate_summary(prompt, generation_config, do_sample=False)
        )

    logger.info("Evaluating summaries...")

    results = RougeEvaluation().compute_rouge_metric(model_summaries, human_summaries)

    # Average generated length in whitespace-separated words; guard the
    # empty-dataset case to avoid dividing by zero.
    generated_lengths = [len(summary.split()) for summary in model_summaries]
    results["gen_len"] = (
        sum(generated_lengths) / len(generated_lengths) if generated_lengths else 0
    )

    return results