dtruong46me's picture
Upload 29 files
97e4014 verified
import sys, os
import argparse
from bert_score import BERTScorer
from transformers import (
T5Tokenizer,
T5ForConditionalGeneration,
AutoTokenizer
)
import warnings
warnings.filterwarnings("ignore")
from huggingface_hub import login
from datasets import load_dataset, Dataset
path = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, path)
from preprocessing import *
def _filter_similar_queries(dsp, scorer, queries, threshold=0.7):
    """Remove queries that are near-duplicates of a later query in the list.

    Every pair ``(i, j)`` with ``i > j`` is scored via
    ``dsp.semantic_filtering(scorer, queries[i], queries[j])``. Query ``j`` is
    discarded when any later query ``i`` scores above ``threshold``. Surviving
    queries are returned in their original order.
    """
    n = len(queries)
    kept = []
    for j in range(n):
        # `any` stops at the first later query that is too similar, so fewer
        # scorer calls are made than when filling a full pairwise matrix.
        is_duplicate = any(
            dsp.semantic_filtering(scorer, queries[i], queries[j]) > threshold
            for i in range(j + 1, n)
        )
        if not is_duplicate:
            kept.append(queries[j])
    return kept


def create_qds_triplet(datapath, split, start_index, end_index,
                       max_attempts=10, similarity_threshold=0.7) -> Dataset:
    """Build a (query, dialogue, summary) dataset from a dialogue/summary split.

    For each dialogue/summary pair in rows ``[start_index, end_index)``:

    1. FLAN-T5 generates candidate queries from the summary.
    2. Only queries for which ``dsp.text_based_filtering`` answers "yes" are kept.
    3. If more than one query survives, near-duplicates are removed with
       ``dsp.semantic_filtering`` using a ``BERTScorer``.
    4. Each surviving query becomes one output row with its dialogue and summary.

    Args:
        datapath: Dataset name/path passed to ``datasets.load_dataset``.
        split: Split to load (e.g. ``"train"``).
        start_index: First row to process (inclusive).
        end_index: Row to stop at (exclusive). ``-1`` (the CLI default) or
            ``None`` means "through the last row". A literal ``[start:-1]``
            slice would silently drop the final example.
        max_attempts: Maximum query-generation rounds per sample. If no
            answerable query is found within this many rounds, the sample is
            skipped instead of retrying forever.
        similarity_threshold: Semantic-filtering score above which the
            earlier of two similar queries is dropped.

    Returns:
        A ``Dataset`` with ``query``, ``dialogue`` and ``summary`` columns.

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    # -1 is the CLI's "no upper bound" sentinel; map it to an open-ended slice.
    stop = None if end_index is None or end_index == -1 else end_index
    data = load_dataset(datapath, split=split)
    data = Dataset.from_dict(data[start_index:stop])

    scorer = BERTScorer(lang="en", rescale_with_baseline=True)

    CHECKPOINT = "google/flan-t5-large"
    tokenizer = T5Tokenizer.from_pretrained(CHECKPOINT)
    model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT)

    qds_triplet = {
        "query": [],
        "dialogue": [],
        "summary": []
    }

    dsp = DialogSumDataset(
        tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
    )

    for dialogue, summary in zip(data["dialogue"], data["summary"]):
        answerable_queries = []
        attempts = 0
        while not answerable_queries and attempts < max_attempts:
            attempts += 1
            queries = dsp.generate_queries(model, tokenizer, summary, num_queries=5)
            for query in queries:
                ## Text based filtering
                output = dsp.text_based_filtering(model, tokenizer, query, summary)
                # NOTE(review): substring match, so any output merely
                # containing "yes" is accepted.
                if "yes" in output.lower():
                    answerable_queries.append(query)

        n = len(answerable_queries)
        print("Length of answerable queries:", n, end=" ### ")

        if n == 0:
            print(f"No answerable query after {max_attempts} attempts, skipping sample", end=" ### ")
            selected = []
        elif n == 1:
            selected = answerable_queries
        else:
            selected = _filter_similar_queries(
                dsp, scorer, answerable_queries, threshold=similarity_threshold
            )
            print("Length of filtered queries:", len(selected), end=" ### ")

        for query in selected:
            qds_triplet["query"].append(query)
            qds_triplet["dialogue"].append(dialogue)
            qds_triplet["summary"].append(summary)

        print("Length of inputs:", len(qds_triplet["summary"]))

    return Dataset.from_dict(qds_triplet)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate (query, dialogue, summary) triplets and push them to the Huggingface Hub."
    )
    parser.add_argument("--datapath", type=str, default="knkarthick/dialogsum",
                        help="Dataset name or path passed to datasets.load_dataset.")
    parser.add_argument("--huggingface_hub_token", type=str, default="",
                        help="Hub access token. If omitted, existing Huggingface Hub credentials are used.")
    parser.add_argument("--split", type=str, default="train",
                        help="Dataset split to process.")
    parser.add_argument("--start_index", type=int, default=0,
                        help="First row to process (inclusive).")
    parser.add_argument("--end_index", type=int, default=-1,
                        help="Row to stop at (exclusive); -1 means through the last row.")
    args = parser.parse_args()

    # Echo the configuration, but never print the Hub token itself:
    # this output typically ends up in job logs.
    shown_args = {
        k: ("***" if k == "huggingface_hub_token" and v else v)
        for k, v in vars(args).items()
    }
    print("=========================================")
    print('\n'.join(f' + {k}={v}' for k, v in shown_args.items()))
    print("=========================================")

    if args.huggingface_hub_token:
        login(token=args.huggingface_hub_token)
        print("Successfully logged in to Huggingface Hub")
    else:
        # Don't call login() with an empty string; rely on credentials
        # huggingface_hub already has (cached login or HF_TOKEN) for push_to_hub.
        print("No --huggingface_hub_token given; using existing Huggingface Hub credentials")

    qds_triplet = create_qds_triplet(args.datapath, args.split, args.start_index, args.end_index)

    save_name = f"dialogsum-{args.split}-{args.start_index}-{args.end_index}"
    qds_triplet.push_to_hub(save_name)
    print(f"Saved to: {save_name}")