"""Build (query, dialogue, summary) triplets from DialogSum for query-focused
dialogue summarization.

For each (dialogue, summary) pair, candidate queries are generated with
Flan-T5, kept only when the model judges them answerable from the summary
(text-based filtering), and deduplicated with BERTScore (semantic filtering).
The resulting triplets are pushed to the Hugging Face Hub.
"""

import argparse
import os
import sys
import warnings

warnings.filterwarnings("ignore")

from bert_score import BERTScorer
from datasets import Dataset, load_dataset
from huggingface_hub import login
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# Make the local preprocessing module importable regardless of the working directory.
path = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, path)

from preprocessing import DialogSumDataset

def create_qds_triplet(datapath, split, start_index, end_index) -> Dataset:
    """Generate (query, dialogue, summary) triplets for rows
    [start_index, end_index) of the given dataset split."""
    data = load_dataset(datapath, split=split)
    # Treat end_index == -1 as "through the end of the split"; a plain
    # data[start_index:-1] slice would silently drop the last example.
    if end_index == -1:
        end_index = None
    data = Dataset.from_dict(data[start_index:end_index])

    # BERTScore is used for semantic (near-duplicate) filtering of queries.
    scorer = BERTScorer(lang="en", rescale_with_baseline=True)

    # Flan-T5 both generates candidate queries and judges their answerability.
    CHECKPOINT = "google/flan-t5-large"
    tokenizer = T5Tokenizer.from_pretrained(CHECKPOINT)
    model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT)

    qds_triplet = {
        "query": [],
        "dialogue": [],
        "summary": [],
    }

    dsp = DialogSumDataset(
        tokenizer=AutoTokenizer.from_pretrained(CHECKPOINT)
    )

    for dialogue, summary in zip(data["dialogue"], data["summary"]):
        answerable_queries = []

        # Keep sampling until at least one generated query passes the
        # answerability check (retries indefinitely if none ever does).
        while len(answerable_queries) < 1:
            queries = dsp.generate_queries(model, tokenizer, summary, num_queries=5)

            for query in queries:
                # Text-based filtering: keep a query only if the model answers
                # "yes" when asked whether the summary can answer it.
                output = dsp.text_based_filtering(model, tokenizer, query, summary)
                if "yes" in output.lower():
                    answerable_queries.append(query)

        n = len(answerable_queries)
        print("Length of answerable queries:", n, end="  ###  ")

        if n == 1:
            # A single answerable query needs no deduplication.
            qds_triplet["query"].append(answerable_queries[0])
            qds_triplet["dialogue"].append(dialogue)
            qds_triplet["summary"].append(summary)
        if n > 1:
            # Semantic filtering: score every pair of queries with BERTScore
            # and drop the earlier query of any pair that is too similar.
            filtered_queries = []
            scores = [[0.0] * n for _ in range(n)]

            for i in range(n):
                for j in range(n):
                    if i > j:  # lower triangle only; similarity is symmetric
                        scores[i][j] = dsp.semantic_filtering(
                            scorer, answerable_queries[i], answerable_queries[j]
                        )

            keep_indices = set(range(n))
            for i in range(n):
                for j in range(n):
                    if scores[i][j] > 0.7 and i > j:
                        keep_indices.discard(j)

            for i in sorted(keep_indices):
                filtered_queries.append(answerable_queries[i])

            print("Length of filtered queries:", len(filtered_queries), end="  ###  ")

            for query in filtered_queries:
                qds_triplet["query"].append(query)
                qds_triplet["dialogue"].append(dialogue)
                qds_triplet["summary"].append(summary)

        print("Length of inputs:", len(qds_triplet["summary"]))

    return Dataset.from_dict(qds_triplet)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--datapath", type=str, default="knkarthick/dialogsum")
    parser.add_argument("--huggingface_hub_token", type=str, default="")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--start_index", type=int, default=0)
    parser.add_argument("--end_index", type=int, default=-1)
    args = parser.parse_args()

    print("=========================================")
    print("\n".join(f" + {k}={v}" for k, v in vars(args).items()))
    print("=========================================")

    # A valid write token is required, since the result is pushed to the Hub.
    login(token=args.huggingface_hub_token)
    print("Successfully logged in to Huggingface Hub")

    qds_triplet = create_qds_triplet(args.datapath, args.split, args.start_index, args.end_index)

    save_name = f"dialogsum-{args.split}-{args.start_index}-{args.end_index}"
    qds_triplet.push_to_hub(save_name)
    print(f"Saved to: {save_name}")