dtruong46me commited on
Commit
97e4014
·
verified ·
1 Parent(s): 559114d

Upload 29 files

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Dinh Truong Phan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,109 @@
1
- ---
2
- title: Dialogue Text Summarization
3
- emoji:
4
- colorFrom: yellow
5
- colorTo: blue
6
- sdk: streamlit
7
- sdk_version: 1.35.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Problem Description
2
+
3
+ This project aims to develop a system capable of automatically **summarizing short dialogue text**. This addresses the challenge of extracting concise yet informative summaries from conversational exchanges, enabling users to **quickly grasp the information of the dialogues**.
4
+
5
+ Summarizing these conversations can be valuable for various applications, such as:
6
+ - Streamlining information retrieval in customer service interactions
7
+ - Condensing meeting discussions for efficient review
8
+ - Providing concise overviews of chat conversations on social media platforms
9
+
10
+ This project tackles the task of automatically generating concise summaries, saving users time and effort while improving comprehension.
11
+
12
+ ![](assets/image2.png)
13
+
14
+ <p align="center"><i>Source: Google Research</i></p>
15
+
16
+ **Input:** Dialogue text
17
+
18
+ Example:
19
+ ```
20
+ Matt: Do you want to go for date?
21
+ Agnes: Wow! You caught me out with this question Matt.
22
+ ...
23
+ Agnes: See you on saturday.
24
+ Matt: Yes, looking forward to it.
25
+ Agnes: Me too.
26
+ ```
27
+
28
+ **Output:** Summarized dialogue
29
+
30
+ Example:
31
+ ```
32
+ Matt invites Agnes for a date to get to know each other better. They'll go to the Georgian restaurant in Kazimierz on Saturday at 6 pm, and he'll pick her up on the way to the place.
33
+ ```
34
+
35
+ # Dataset
36
+
37
+ We'll utilize the `DialogSum` dataset accessible from 🤗**Hugging Face** (https://huggingface.co/datasets/knkarthick/dialogsum) and **Paper** (https://arxiv.org/pdf/2105.06762.pdf). This dataset comprises real-life dialogue scenarios paired with corresponding manually crafted summaries and dialogue topics.
38
+
39
+ `DialogSum` is a large-scale dialogue summarization dataset, consisting of **13,460** (Plus 100 holdout data for topic generation) dialogues with corresponding manually labeled summaries and topics.
40
+
41
+ Here's a sample of the `DialogSum` dataset structure:
42
+
43
+
44
+ |id|dialogue|summary|topic|
45
+ |-|-|-|-|
46
+ |train_3|#Person1#: Why didn't you tell me you had a girlfriend? #Person2#: Sorry, I thought you knew. ... #Person1#: Oh, you men! You are all the same.|#Person1#'s angry because #Person2# didn't tell #Person1# that #Person2# had a girlfriend and would marry her.|have a girl friend|
47
+ |train_16|#Person1#: Tell me something about your Valentine's Day. ...#Person2#: Yeah, that is what the holiday is for, isn't it?|#Person2# tells #Person1# their Valentine's Day. #Person1# feels it's romantic.|Valentine's Day|
48
+ |...|...|...|...|
49
+
50
+ **Distribution of dataset**
51
+
52
+ |Dialogue|Summary|Dialogue + Summary|
53
+ |:-:|:-:|:-:|
54
+ |![](assets/hist_dialogue.png)|![](assets/hist_summary.png)|![](assets/hist_dialogue+summary.png)|
55
+
56
+ # Method
57
+
58
+ ### Pre-trained Language Models:
59
+
60
+ This project explores two powerful LLMs well-suited for dialogue summarization:
61
+
62
+ - **FLAN-T5:** This model excels at understanding complex relationships within text, making it effective in summarizing the nuances of conversations.
63
+ - **BART:** This model boasts strong capabilities in text generation tasks, making it adept at generating informative and well-structured summaries.
64
+
65
+ ### Fine-tuning Techniques:
66
+
67
+ To tailor these LLMs specifically for dialogue summarization, we will investigate several fine-tuning approaches:
68
+
69
+ - Instruction Fine-tuning
70
+ - Parameter Efficient Fine Tuning (PEFT)
71
+ + Low-Rank Adaptation **(LoRA)**
72
+ + Quantized Low-Rank Adaptation **(QLoRA)**
73
+
74
+ # Installation
75
+
76
+ ```
77
+ !git clone "https://github.com/dtruong46me/dialogue-text-summarization.git"
78
+ ```
79
+
80
+ # Contributions
81
+
82
+ **Supervisor:** Prof. Le Thanh Huong
83
+
84
+ **Student Group:**
85
+
86
+ |No.|Name|Student ID|Email|
87
+ |:-:|-|:-:|-|
88
+ |1|Phan Dinh Truong (Leader)|20214937|truong.pd214937@sis.hust.edu.vn|
89
+ |2|Nguyen Tung Luong|20214913|luong.nt214913@sis.hust.edu.vn|
90
+ |3|Vu Tuan Minh|20210597|minh.vt210597@sis.hust.edu.vn|
91
+ |4|Hoang Tu Quyen|20214929|quyen.ht214929@sis.hust.edu.vn|
92
+
93
+ # [Bonus] How to run Streamlit on Kaggle
94
+
95
+ ```
96
+ !pip install -q streamlit
97
+ ```
98
+
99
+ ```
100
+ !wget -q -O - ipv4.icanhazip.com
101
+ ```
102
+
103
+ ```
104
+ !npm install -g localtunnel -q
105
+ ```
106
+
107
+ ```
108
+ !streamlit run "/kaggle/working/dialogue-text-summarization/streamlit_app.py" & npx localtunnel --port 8501
109
+ ```
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from transformers import GenerationConfig, BartModel, BartTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM
4
+ import torch
5
+ import time
6
+
7
+ import sys, os
8
+
9
+ path = os.path.abspath(os.path.dirname(__file__))
10
+ sys.path.insert(0, path)
11
+
12
+ from gen_summary import generate_summary
13
+
14
+
15
+ st.title("Dialogue Text Summarization")
16
+ st.caption("Natural Language Processing Project 20232")
17
+
18
+ st.write("---")
19
+
20
+ with st.sidebar:
21
+ checkpoint = st.selectbox("Model", options=[
22
+ "Choose model",
23
+ "dtruong46me/train-bart-base",
24
+ "dtruong46me/flant5-small",
25
+ "dtruong46me/flant5-base",
26
+ "dtruong46me/flan-t5-s",
27
+ "ntluongg/bart-base-luong"
28
+ ])
29
+ st.button("Model detail", use_container_width=True)
30
+ st.write("-----")
31
+ st.write("**Generate Options:**")
32
+ min_new_tokens = st.number_input("Min new tokens", min_value=1, max_value=64, value=10)
33
+ max_new_tokens = st.number_input("Max new tokens", min_value=64, max_value=128, value=64)
34
+ temperature = st.number_input("Temperature", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
35
+ top_k = st.number_input("Top_k", min_value=1, max_value=50, step=1, value=20)
36
+ top_p = st.number_input("Top_p", min_value=0.01, max_value=1.00, step=0.01, value=1.0)
37
+
38
+
39
+ height = 200
40
+
41
+ input_text = st.text_area("Dialogue", height=height)
42
+
43
+ generation_config = GenerationConfig(
44
+ min_new_tokens=min_new_tokens,
45
+ max_new_tokens=320,
46
+ temperature=temperature,
47
+ top_p=top_p,
48
+ top_k=top_k
49
+ )
50
+
51
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
+
53
+ if checkpoint=="Choose model":
54
+ tokenizer = None
55
+ model = None
56
+
57
+ if checkpoint!="Choose model":
58
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
59
+ model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
60
+
61
+
62
+
63
+ if st.button("Submit"):
64
+ st.write("---")
65
+ st.write("## Summary")
66
+
67
+ if checkpoint=="Choose model":
68
+ st.error("Please selece a model!")
69
+
70
+ else:
71
+ if input_text=="":
72
+ st.error("Please enter a dialogue!")
73
+ st.write(generate_summary(model, " ".join(input_text.split()), generation_config, tokenizer))
assets/distribution.png ADDED
assets/hist_dialogue+summary.png ADDED
assets/hist_dialogue.png ADDED
assets/hist_summary.png ADDED
assets/image2.png ADDED
gen_summary.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, GenerationConfig, TextStreamer, AutoModelForSeq2SeqLM
3
+
4
+ import logging
5
+
6
+ import warnings
7
+ warnings.filterwarnings("ignore")
8
+
9
+ # = = = = = = = = = = = Logging Setup = = = = = = = = = = = = =
10
+ logger = logging.getLogger(__name__)
11
+ logging.basicConfig(
12
+ format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
13
+ datefmt = "%m/%d/%Y %H:%M:%S",
14
+ level = logging.INFO,
15
+ )
16
+ # = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
17
+
18
+ def generate_summary(model, input_text, generation_config, tokenizer, st_container=None) -> str:
19
+
20
+ try:
21
+ prefix = "Summarize the following conversation: \n###\n"
22
+ suffix = "\n### Summary:"
23
+
24
+ input_ids = tokenizer.encode(prefix + input_text + "The generated summary should be around " + str(0.15*len(input_text)) + " words." + suffix, return_tensors="pt")
25
+ output_ids = model.generate(input_ids, do_sample=True, generation_config=generation_config)
26
+
27
+ if "bart" in model.name_or_path and model.name_or_path != "dtruong46me/bart-base-qds":
28
+ output_ids[0][1] = 2
29
+
30
+ # streamer = TextStreamer(tokenizer, skip_special_tokens=True)
31
+ # model.generate(input_ids, streamer=streamer, do_sample=True, decoder_start_token_id=2, generation_config=generation_config)
32
+ # logger.info("\nComplete generate summary!")
33
+ output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
34
+ return output_text
35
+
36
+ except Exception as e:
37
+ print(f"Error while generating: {e}")
38
+ raise e
39
+
40
+ if __name__=="__main__":
41
+ input = "#Person1#: Ms. Dawson, I need you to take a dictation for me. #Person2#: Yes, sir... #Person1#: This should go out as an intra-office memorandum to all employees by this afternoon. Are you ready? #Person2#: Yes, sir. Go ahead. #Person1#: Attention all staff... Effective immediately, all office communications are restricted to email correspondence and official memos. The use of Instant Message programs by employees during working hours is strictly prohibited. #Person2#: Sir, does this apply to intra-office communications only? Or will it also restrict external communications? #Person1#: It should apply to all communications, not only in this office between employees, but also any outside communications. #Person2#: But sir, many employees use Instant Messaging to communicate with their clients. #Person1#: They will just have to change their communication methods. I don't want any - one using Instant Messaging in this office. It wastes too much time! Now, please continue with the memo. Where were we? #Person2#: This applies to internal and external communications. #Person1#: Yes. Any employee who persists in using Instant Messaging will first receive a warning and be placed on probation. At second offense, the employee will face termination. Any questions regarding this new policy may be directed to department heads. #Person2#: Is that all? #Person1#: Yes. Please get this memo typed up and distributed to all employees before 4 pm."
42
+ target1 = "Ms. Dawson helps #Person1# to write a memo to inform every employee that they have to change the communication method and should not use Instant Messaging anymore."
43
+ target2 = "In order to prevent employees from wasting time on Instant Message programs, #Person1# decides to terminate the use of those programs and asks Ms. Dawson to send out a memo to all employees by the afternoon."
44
+ target3 = "Ms. Dawson takes a dictation for #Person1# about prohibiting the use of Instant Message programs in the office. They argue about its reasonability but #Person1# still insists."
45
+
46
+ generation_config = GenerationConfig(
47
+ min_new_tokens=10,
48
+ max_new_tokens=256,
49
+ temperature=0.9,
50
+ top_p=1.0,
51
+ top_k=50
52
+ )
53
+
54
+ checkpoint = "dtruong46me/bart-base-qds2"
55
+
56
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
58
+
59
+ model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)
60
+
61
+ generate_summary(model, input, generation_config, tokenizer)
62
+ print("\n==============\n")
63
+
64
+ print("Human base line:\n", target1, end="\n\n")
65
+ print("Human base line:\n", target2, end="\n\n")
66
+ print("Human base line:\n", target3, end="\n\n")
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets
2
+ huggingface_hub
3
+ nltk
4
+ numpy
5
+ pandas
6
+ peft
7
+ replicate
8
+ streamlit
9
+ torch
10
+ transformers==4.36.1
11
+ wandb
12
+ evaluate
13
+ rouge_score
14
+ bert_score
results/.gitignore ADDED
File without changes
results/rouge_score.csv ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ rouge1,rouge2,rougeL,rougeLsum,gen_len,checkpoint
2
+ 0.39233350039050524,0.1331263872944557,0.30561232240272806,0.305581876074012,25.568,dtruong46me/flant5-small
3
+ 0.42773411047439297,0.16070313389865537,0.33964372087731554,0.33971528751465496,24.633333333333333,dtruong46me/flant5-base
4
+ 0.4436612424628238,0.18215770435271772,0.3574836391515892,0.3575112795473217,25.358,dtruong46me/train-bart-base
5
+ 0.44596490799011734,0.1791041702437794,0.36099829444161424,0.3612203644902555,18.72,dtruong46me/bart-base-instructds2
6
+ 0.5335,0.2672,0.5084,0,0,human-annotated-summary
7
+ 0.4728,0.2118,0.4483,0,0,bart-large-in-paper
8
+ 0.5165,0.2981,0.4336,0.4337,23.187,dtruong46me/bart-base-qds
9
+ 0.4061788843274445,0.1588224274185049,0.3175643149646888,0.3207910509892517,26.058,dtruong46me/flan-t5-s
run_evaluation.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore")
3
+
4
+ from datasets import load_dataset
5
+
6
+ import os, sys
7
+
8
+ import pandas as pd
9
+ import argparse
10
+
11
+ path = os.path.abspath(os.path.join(os.path.dirname(__file__)))
12
+ sys.path.insert(0, path)
13
+
14
+ from src.model.model import load_model
15
+ from src.evaluate.evaluation import evaluation_rouge
16
+ from transformers import GenerationConfig
17
+
18
+
19
+ def save_metrics_to_csv(results, resultpath, checkpoint):
20
+
21
+ results["checkpoint"] = checkpoint
22
+
23
+ # Convert results to DataFrame
24
+ df = pd.DataFrame([results])
25
+
26
+ if not os.path.isfile(resultpath):
27
+ df.to_csv(resultpath, index=False)
28
+ else:
29
+ df.to_csv(resultpath, mode='a', header=False, index=False)
30
+
31
+
32
+ def main():
33
+ parser = argparse.ArgumentParser(description="Evaluation metric")
34
+ parser.add_argument("--datapath", type=str, default="knkarthick/dialogsum")
35
+ parser.add_argument("--checkpoint", type=str, default="google/flan-t5-base")
36
+ parser.add_argument("--resultpath", type=str, default="results/rouge_score.csv")
37
+
38
+ parser.add_argument("--min_new_tokens", type=int, default=10)
39
+ parser.add_argument("--max_new_tokens", type=int, default=256)
40
+ parser.add_argument("--temperature", type=float, default=0.9)
41
+ parser.add_argument("--top_p", type=float, default=1.0)
42
+ parser.add_argument("--top_k", type=int, default=50)
43
+
44
+ args = parser.parse_args()
45
+
46
+ print("=========================================")
47
+ print('\n'.join(f' + {k}={v}' for k, v in vars(args).items()))
48
+ print("=========================================")
49
+
50
+ datapath = args.datapath
51
+ checkpoint = args.checkpoint
52
+
53
+ generation_config = GenerationConfig(
54
+ min_new_tokens=args.min_new_tokens,
55
+ max_new_tokens=args.max_new_tokens,
56
+ temperature=args.temperature,
57
+ top_p=args.top_p,
58
+ top_k=args.top_k
59
+ )
60
+
61
+ data = load_dataset("binwang/InstructDS_datasets", "DialogSum", split="test")
62
+
63
+ model = load_model(checkpoint)
64
+ print(f"Loaded model from: {checkpoint}")
65
+
66
+ results = evaluation_rouge(model, data, generation_config)
67
+
68
+ print("--------------------------")
69
+ for k, v in results.items():
70
+ print(f"{k}: {v}")
71
+ print("--------------------------")
72
+
73
+ save_metrics_to_csv(results, args.resultpath, checkpoint)
74
+ print(f"Results saved to: {args.resultpath}")
75
+
76
+ if __name__ == "__main__":
77
+ main()
run_training.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wandb
2
+ from huggingface_hub import login
3
+
4
+ import warnings
5
+ warnings.filterwarnings("ignore")
6
+
7
+ import os
8
+ import sys
9
+
10
+ path = os.path.abspath(os.path.join(os.path.dirname(__file__)))
11
+ sys.path.insert(0, path)
12
+
13
+ from src.pipelines.training_pipeline import training_pipeline
14
+ from src.utils import parse_args
15
+
16
+ def main():
17
+ # Load argument parser
18
+ args = parse_args()
19
+ print(f"\033[92mLoaded argument parsers\033[00m")
20
+
21
+ # Load token ID
22
+ huggingface_hub_token = args.huggingface_hub_token
23
+ wandb_token = args.wandb_token
24
+
25
+ if wandb_token:
26
+ os.environ["WANDB_PROJECT"] = "nlp_project"
27
+
28
+ # Login to Huggingface Hub and WandB
29
+ login(token=huggingface_hub_token)
30
+ print("\033[92mSuccessful login to Huggingface Hub\033[00m")
31
+
32
+ wandb.login(key=wandb_token)
33
+ print("\033[92mSuccessful login to WandB\033[00m")
34
+
35
+ training_pipeline(args)
36
+ print("\033[92mFinish training pipeline\033[00m")
37
+
38
+ if __name__=='__main__':
39
+ main()
setup.sh ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ echo "Hello"
2
+ echo "..."
3
+ pip install -q --upgrade pip
4
+ pip install -q -U datasets
5
+ pip install -q transformers
6
+ pip install -q -r "/kaggle/working/dialogue-text-summarization/requirements.txt"
7
+ echo "---------"
8
+ echo "Set up complete!"
src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/data/create_dataset.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import sys, os
3
+
4
+ import argparse
5
+
6
+ from bert_score import BERTScorer
7
+
8
+ from transformers import (
9
+ T5Tokenizer,
10
+ T5ForConditionalGeneration,
11
+ AutoTokenizer
12
+ )
13
+
14
+ import warnings
15
+ warnings.filterwarnings("ignore")
16
+
17
+ from huggingface_hub import login
18
+
19
+ from datasets import load_dataset, Dataset
20
+
21
+ path = os.path.abspath(os.path.dirname(__file__))
22
+ sys.path.insert(0, path)
23
+
24
+ from preprocessing import *
25
+
26
+ def create_qds_triplet(datapath, split, start_index, end_index) -> Dataset:
27
+ data = load_dataset(datapath, split=split)
28
+ data = Dataset.from_dict(data[start_index:end_index])
29
+
30
+ scorer = BERTScorer(lang="en", rescale_with_baseline=True)
31
+
32
+ CHECKPOINT = "google/flan-t5-large"
33
+ tokenizer = T5Tokenizer.from_pretrained(CHECKPOINT)
34
+ model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT)
35
+
36
+ qds_triplet = {
37
+ "query": [],
38
+ "dialogue": [],
39
+ "summary": []
40
+ }
41
+
42
+ dsp = DialogSumDataset(
43
+ tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
44
+ )
45
+
46
+ for dialogue, summary in zip(data["dialogue"], data["summary"]):
47
+ answerable_queries = []
48
+
49
+ while len(answerable_queries) < 1:
50
+ queries = dsp.generate_queries(model, tokenizer, summary, num_queries=5)
51
+
52
+ for query in queries:
53
+ ## Text based filtering
54
+ output = dsp.text_based_filtering(model, tokenizer, query, summary)
55
+ if "yes" in output.lower():
56
+ answerable_queries.append(query)
57
+
58
+ n = len(answerable_queries)
59
+ print("Length of answerable queries:", n, end=" ### ")
60
+
61
+ if n == 1:
62
+ qds_triplet["query"].append(answerable_queries[0])
63
+ qds_triplet["dialogue"].append(dialogue)
64
+ qds_triplet["summary"].append(summary)
65
+
66
+ if n > 1:
67
+ filtered_queries = []
68
+ scores = [[0.0]*n for _ in range(n)]
69
+
70
+ for i in range(n):
71
+ for j in range(n):
72
+ if i > j:
73
+ scores[i][j] = dsp.semantic_filtering(scorer, answerable_queries[i], answerable_queries[j])
74
+
75
+ keep_indices = set(range(n))
76
+ for i in range(n):
77
+ for j in range(n):
78
+ if scores[i][j] > 0.7 and i > j:
79
+ keep_indices.discard(j)
80
+
81
+ for i in sorted(keep_indices):
82
+ filtered_queries.append(answerable_queries[i])
83
+
84
+ print("Length of filtered queries:", len(filtered_queries), end=" ### ")
85
+
86
+ for query in filtered_queries:
87
+ qds_triplet["query"].append(query)
88
+ qds_triplet["dialogue"].append(dialogue)
89
+ qds_triplet["summary"].append(summary)
90
+
91
+ print("Length of inputs:", len(qds_triplet["summary"]))
92
+
93
+ return Dataset.from_dict(qds_triplet)
94
+
95
+ if __name__=="__main__":
96
+ parser = argparse.ArgumentParser()
97
+ parser.add_argument("--datapath", type=str, default="knkarthick/dialogsum")
98
+ parser.add_argument("--huggingface_hub_token", type=str, default="")
99
+ parser.add_argument("--split", type=str, default="train")
100
+ parser.add_argument("--start_index", type=int, default=0)
101
+ parser.add_argument("--end_index", type=int, default=-1)
102
+ args = parser.parse_args()
103
+
104
+ print("=========================================")
105
+ print('\n'.join(f' + {k}={v}' for k, v in vars(args).items()))
106
+ print("=========================================")
107
+
108
+ login(token=args.huggingface_hub_token)
109
+ print("Successfully logged in to Huggingface Hub")
110
+
111
+ qds_triplet = create_qds_triplet(args.datapath, args.split, args.start_index, args.end_index)
112
+
113
+ save_name = f"dialogsum-{args.split}-{args.start_index}-{args.end_index}"
114
+ qds_triplet.push_to_hub(save_name)
115
+ print(f"Saved to: {save_name}")
src/data/ingest_data.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from datasets import load_dataset
3
+ from datasets import DatasetDict, Dataset
4
+ import random
5
+ from transformers import set_seed
6
+
7
+
8
+ def ingest_data(datapath: str) -> DatasetDict:
9
+ set_seed(42)
10
+
11
+ QDS_LIMIT = 6000
12
+ if "," in datapath:
13
+ datapaths = datapath.split(",")
14
+
15
+ datapath1 = "binwang/InstructDS_datasets"
16
+ datapath2 = "binwang/InstructDS_datasets"
17
+
18
+ all_train_data = []
19
+ origin_train_dialogsum = load_dataset(datapath1, "DialogSum", split="train")
20
+ qds_dialogsum = load_dataset(datapath2, "DialogSum_QDS", split="train")
21
+
22
+ new_data1 = []
23
+ for sample in origin_train_dialogsum:
24
+ new_sample = {
25
+ "instruction": "Please summarize the following dialogue.",
26
+ "input": sample["dialogue"],
27
+ "output": sample["summary"]
28
+ }
29
+ new_data1.append(new_sample)
30
+ origin_train_dialogsum = new_data1
31
+ all_train_data.extend(origin_train_dialogsum)
32
+
33
+ print("Len of origin_train_dialogsum: ", len(origin_train_dialogsum))
34
+ print("Len of all train data 1: ", len(all_train_data))
35
+
36
+ new_data2 = []
37
+ for sample in qds_dialogsum:
38
+ new_sample = {
39
+ "instruction": "Please answer the following question.",
40
+ "input": sample["dialogue"],
41
+ "output": sample["summary"]
42
+ }
43
+ new_data2.append(new_sample)
44
+ qds_dialogsum = new_data2
45
+ qds_dialogsum = random.sample(qds_dialogsum, QDS_LIMIT)
46
+ all_train_data.extend(qds_dialogsum)
47
+ print("Len of all train data 2: ", len(all_train_data))
48
+
49
+
50
+ naive_all_train_data_dict = {
51
+ "instruction": [item["instruction"] for item in all_train_data],
52
+ "input": [item["input"] for item in all_train_data],
53
+ "output": [item["output"] for item in all_train_data]
54
+ }
55
+
56
+ print("Len of naive_all_train_data_dict: ", len(naive_all_train_data_dict["instruction"]))
57
+
58
+ subset_train_data = all_train_data
59
+ with_len_train_data_dict = {
60
+ "instruction": [item["instruction"] + f" The output should be {len(item['output'].split())} words long." for item in subset_train_data],
61
+ "input": [item["input"] for item in subset_train_data],
62
+ "output": [item["output"] for item in subset_train_data]
63
+ }
64
+
65
+ print("Len of with_len_train_data_dict: ", len(with_len_train_data_dict["instruction"]))
66
+
67
+ all_train_data_dict = {
68
+ "instruction": naive_all_train_data_dict["instruction"] + with_len_train_data_dict["instruction"],
69
+ "input": naive_all_train_data_dict["input"] + with_len_train_data_dict["input"],
70
+ "output": naive_all_train_data_dict["output"] + with_len_train_data_dict["output"]
71
+ }
72
+
73
+ print("Len of all_train_data_dict: ", len(all_train_data_dict["instruction"]))
74
+
75
+ raw_train_data = Dataset.from_dict(all_train_data_dict)
76
+ train_data = raw_train_data.shuffle()
77
+
78
+ print(type(train_data))
79
+ print(train_data["instruction"][:10])
80
+ print(train_data["input"][:10])
81
+ print(train_data["output"][:10])
82
+
83
+ print("===================", len(train_data), "===================")
84
+
85
+ # Validation data
86
+ all_validation_data = []
87
+ origin_validation_dialogsum = load_dataset(datapath1, "DialogSum", split="validation")
88
+
89
+ new_data1 = []
90
+ for sample in origin_validation_dialogsum:
91
+ new_sample = {
92
+ "instruction": "Please summarize the following dialogue.",
93
+ "input": sample["dialogue"],
94
+ "output": sample["summary"]
95
+ }
96
+ new_data1.append(new_sample)
97
+
98
+ origin_validation_dialogsum = new_data1
99
+ all_validation_data.extend(origin_validation_dialogsum)
100
+
101
+ all_validation_data_dict = {
102
+ "instruction": [item["instruction"] for item in all_validation_data],
103
+ "input": [item["input"] for item in all_validation_data],
104
+ "output": [item["output"] for item in all_validation_data]
105
+ }
106
+
107
+ raw_validation_data = Dataset.from_dict(all_validation_data_dict)
108
+ validation_data = raw_validation_data.shuffle()
109
+
110
+ return DatasetDict({
111
+ "train": train_data,
112
+ "validation": validation_data
113
+ })
src/data/merge_dataset.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os, sys
3
+
4
+ import argparse
5
+
6
+ from datasets import load_dataset, concatenate_datasets, Dataset
7
+ from huggingface_hub import login
8
+
9
+ path = os.path.abspath(os.path.dirname(__file__))
10
+ sys.path.insert(0, path)
11
+
12
+ def merge_dataset(datapaths) -> Dataset:
13
+ datapaths = datapaths.split(",")
14
+ dataset = load_dataset(datapaths[0], split="train")
15
+
16
+ for i in range(1, len(datapaths)):
17
+ data = load_dataset(datapaths[i], split="train")
18
+ data = concatenate_datasets([dataset, data])
19
+
20
+ return dataset
21
+
22
+
23
+ if __name__=="__main__":
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("--datapaths", type=str, default="")
26
+ parser.add_argument("--huggingface_hub_token", type=str, default="")
27
+ parser.add_argument("--split", type=str, default="train")
28
+ args = parser.parse_args()
29
+
30
+ print("=========================================")
31
+ print('\n'.join(f' + {k}={v}' for k, v in vars(args).items()))
32
+ print("=========================================")
33
+
34
+ login(token=args.huggingface_hub_token)
35
+ print("Successfully logged in to Huggingface Hub")
36
+
37
+ dataset = merge_dataset(datapaths=args.datapaths)
38
+
39
+ DATASET_ID = "qds-triplet-dialogsum"
40
+ dataset.push_to_hub(DATASET_ID)
41
+ print(f"Successful push to Huggingface Hub: {DATASET_ID}")
src/data/preprocessing.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from datasets import DatasetDict, Dataset
3
+ import random
4
+ from bert_score import BERTScorer
5
+
6
+ from transformers import (
7
+ T5Tokenizer,
8
+ T5ForConditionalGeneration
9
+ )
10
+
11
+ class DialogSumDataset:
12
+ def __init__(self, tokenizer, use_contrastive_loss=False, tokenizing_strategy=1) -> None:
13
+ self.tokenizer = tokenizer
14
+ self.use_contrastive_loss = use_contrastive_loss
15
+ self.tokenizing_strategy = tokenizing_strategy
16
+
17
+ def handle_data(self, data: DatasetDict) -> DatasetDict:
18
+ try:
19
+ self.tokenizer.pad_token = self.tokenizer.eos_token
20
+ tokenized_dataset = data.map(self.preprocess_function, batched=True)
21
+ tokenized_dataset = tokenized_dataset.remove_columns([key for key in data["train"][0].keys()])
22
+
23
+ print("+++++++++++++++++++")
24
+ print(tokenized_dataset)
25
+ print("+++++++++++++++++++")
26
+
27
+ return tokenized_dataset
28
+
29
+ except Exception as e:
30
+ print(f"\033[31m\nError while tokenizing data: {e}\033[00m")
31
+ raise e
32
+
33
+ def preprocess_function(self, data: Dataset) -> Dataset:
34
+ ###
35
+ if self.tokenizing_strategy<=2:
36
+ prefix = "Summarize the following conversation:\n###\n"
37
+ suffix = "\n###\nSummary: "
38
+ inputs = [prefix + input + suffix for input in data["dialogue"]]
39
+ targets = data["summary"]
40
+
41
+ if self.tokenizing_strategy==1:
42
+ max_source_length = 1024
43
+ max_target_length = 176
44
+
45
+ if self.tokenizing_strategy==2:
46
+ max_source_length = 1224
47
+ max_target_length = 176
48
+
49
+ if self.tokenizing_strategy==3:
50
+ inputs = ["### Instruction: " + instruction + "\n### Input: " + input + "\n### Response: " for instruction, input in zip(data["instruction"], data["input"])]
51
+ targets = data["output"]
52
+
53
+ max_source_length = 1024
54
+ max_target_length = 176
55
+
56
+ data["input_ids"] = self.tokenizer(inputs, max_length=max_source_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
57
+ # data["attention_mask"] = self.tokenizer(inputs, max_length=max_source_length, padding="max_length", truncation=True, return_tensors="pt").attention_mask
58
+ data["labels"] = self.tokenizer(targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
59
+
60
+ # Generate negative examples:
61
+ if self.use_contrastive_loss==True:
62
+ negative_summaries = self.generate_negative_examples(data["summary"])
63
+ data["negative_labels"] = self.tokenizer(negative_summaries, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="pt").input_ids
64
+ print("Complete generate negative examples!")
65
+
66
+ label_ignore_ids = []
67
+ for label in data["labels"]:
68
+ label_example = [l if l != 0 else -100 for l in label]
69
+ label_ignore_ids.append(label_example)
70
+
71
+ data["labels"] = label_ignore_ids
72
+
73
+ return data
74
+
75
+ ## Create Negetive Example for Contrastive Learning
76
+ def generate_negative_examples(self, summaries):
77
+ negative_summaries = []
78
+ for summary in summaries:
79
+ words = summary.split()
80
+ random.shuffle(words)
81
+ negative_summaries.append(" ".join(words))
82
+ return negative_summaries
83
+
84
+ ## Create Instruction Dataset
85
+ def generate_queries(self, model, tokenizer, summary, num_queries):
86
+ input_text = "Generate an answerable and specific question based on the following context:. ###\nContext: " + summary
87
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids
88
+ outputs = model.generate(input_ids, max_length=64, num_return_sequences=num_queries, do_sample=True)
89
+ queries = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
90
+ return queries
91
+
92
+ def text_based_filtering(self, model, tokenizer, query, summary):
93
+ input_text = "Is the question fully answerable from the context without any guessing, yes or no?###\nQuestion: " + query + "###\nContext: " + summary + "###Answer: "
94
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids
95
+ output_ids = model.generate(input_ids, num_return_sequences=1)
96
+ output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
97
+ return output_text
98
+
99
+ def semantic_filtering(self, scorer, query1, query2):
100
+ score = scorer.score([query1], [query2])[0]
101
+ return score
102
+
103
+
104
+ def preprocessing_data(data: DatasetDict, tokenizer, use_contrastive_loss=False, tokenizing_strategy=False) -> DatasetDict:
105
+ try:
106
+ dataset_ds = DialogSumDataset(tokenizer, use_contrastive_loss, tokenizing_strategy)
107
+ tokenized_data = dataset_ds.handle_data(data)
108
+
109
+ return tokenized_data
110
+
111
+ except Exception as e:
112
+ print(f"\nError while pre-processing data: {e}")
113
+ raise e
src/evaluate/evaluation.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ from datasets import Dataset
5
+
6
+ import evaluate
7
+ import torch
8
+
9
+ import logging
10
+
11
+ # = = = = = = = = = = = Logging Setup = = = = = = = = = = = = =
12
+ logger = logging.getLogger(__name__)
13
+ logging.basicConfig(
14
+ format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
15
+ datefmt = "%m/%d/%Y %H:%M:%S",
16
+ level = logging.INFO,
17
+ )
18
+ # = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
19
+
20
+ from transformers import AutoModelForSeq2SeqLM
21
+
22
+ path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
23
+ sys.path.insert(0, path)
24
+
25
+ from model.model import Model
26
+
27
+
28
+ class RougeEvaluation:
29
+ def __init__(self) -> None:
30
+ self.rouge_metric = evaluate.load("rouge")
31
+
32
+ def compute_rouge_metric(self, generated_summary, reference_summary) -> dict:
33
+ results = self.rouge_metric.compute(
34
+ predictions=generated_summary,
35
+ references=reference_summary,
36
+ use_aggregator=True,
37
+ use_stemmer=True
38
+ )
39
+ return results
40
+
41
+
42
+ def evaluation_rouge(model: Model, data: Dataset, generation_config) -> dict:
43
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+ model.base_model = model.get_model()
45
+
46
+ dialogues = data["dialogue"]
47
+
48
+ human_summaries = [summary for summary in data["summary"]]
49
+
50
+ model_summaries = []
51
+
52
+ prefix = "Summarize the following dialogue:\n###\n"
53
+ suffix = "\n### Summary: "
54
+
55
+ # print("\n******************************")
56
+ # idx = 0
57
+ # for answer, dialogue in zip(data["answer"], data["dialogue"]):
58
+ # prefix = "Please summarize the following dialogue focused on the context query:"
59
+ # input = prefix + "\n### Queryr: " + answer + "\n### Dialogue: " + dialogue + "\n### The summary should be around " + str(int(0.2*len(dialogue.split()))) + " words." + "\n### Summary: "
60
+
61
+ for idx, dialogue in enumerate(dialogues):
62
+ input = prefix + dialogue + suffix
63
+
64
+ print(idx, end="# ")
65
+ output_text = model.generate_summary(input, generation_config, do_sample=False)
66
+
67
+ model_summaries.append(output_text)
68
+ idx += 1
69
+
70
+ logger.info("Evaluating summaries...")
71
+
72
+ rouge_evaluator = RougeEvaluation()
73
+
74
+ results = rouge_evaluator.compute_rouge_metric(model_summaries, human_summaries)
75
+
76
+ generated_lengths = [len(summary.split()) for summary in model_summaries]
77
+ average_gen_len = sum(generated_lengths) / len(generated_lengths) if generated_lengths else 0
78
+
79
+ results["gen_len"] = average_gen_len
80
+
81
+ return results
src/evaluate/rouge_metric.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import nltk
3
+ import numpy as np
4
+ from nltk.tokenize import sent_tokenize
5
+
6
+ from transformers import AutoTokenizer
7
+
8
+ import os
9
+ import sys
10
+
11
+ path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
12
+ sys.path.insert(0, path)
13
+
14
+
15
+ def postprocess_text(preds, labels):
16
+ nltk.download("punkt")
17
+
18
+ preds = [pred.strip() for pred in preds]
19
+ labels = [label.strip() for label in labels]
20
+
21
+ preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
22
+ labels = ["\n".join(sent_tokenize(label)) for label in labels]
23
+
24
+ return preds, labels
25
+
26
+
27
+ def compute_metrics(eval_preds, tokenizer, metric):
28
+ preds, labels = eval_preds
29
+ if isinstance(preds, tuple):
30
+ preds = preds[0]
31
+
32
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
33
+ # Replace -100 in the labels as we can't decode them.
34
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
35
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
36
+
37
+ # Some simple post-processing
38
+ decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
39
+
40
+ # metric = evaluate.load("rouge")
41
+ rouge_results = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
42
+ rouge_results = {k: round(v * 100, 4) for k, v in rouge_results.items()}
43
+
44
+ results = {
45
+ "rouge1": rouge_results["rouge1"],
46
+ "rouge2": rouge_results["rouge2"],
47
+ "rougeL": rouge_results["rougeL"],
48
+ "rougeLsum": rouge_results["rougeLsum"],
49
+ "gen_len": np.mean([np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds])
50
+ }
51
+
52
+ return results
src/model/model.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForSeq2SeqLM,
6
+ )
7
+
8
+ from peft import (
9
+ get_peft_model,
10
+ )
11
+
12
+ class Model:
13
+ def __init__(self, checkpoint):
14
+ self.checkpoint = checkpoint
15
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
17
+ self.base_model = None
18
+
19
+ def get_model(self):
20
+ return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint)
21
+
22
+ def get_peft(self, lora_config):
23
+ return get_peft_model(self.base_model, lora_config)
24
+
25
+ def prepare_quantize(self, bnb_config):
26
+ return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint,
27
+ quantization_config=bnb_config,
28
+ device_map={"":0},
29
+ trust_remote_code=True)
30
+ # self.base_model.gradient_checkpointing_enable()
31
+ # self.base_model = prepare_model_for_kbit_training(self.base_model)
32
+
33
+
34
+ def generate_summary(self, input_text, generation_config, do_sample=True):
35
+ input_ids = self.tokenizer.encode(input_text, return_tensors="pt", max_length=1024, truncation=True, padding="max_length")
36
+ output_ids = self.base_model.generate(input_ids=input_ids, do_sample=do_sample, generation_config=generation_config)
37
+
38
+ if "bart" in self.checkpoint:
39
+ output_ids[0][1] = 2
40
+
41
+ output_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
42
+ print(f"\033[94mSummary: {output_text}\n\033[00m")
43
+ return output_text
44
+
45
+ class BartSum(Model):
46
+ def __init__(self, checkpoint):
47
+ super().__init__(checkpoint)
48
+ self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
49
+
50
+ def get_model(self):
51
+ return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint)
52
+
53
+
54
+ class FlanT5Sum(Model):
55
+ def __init__(self, checkpoint):
56
+ super().__init__(checkpoint)
57
+ self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
58
+
59
+ def get_model(self):
60
+ return AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint)
61
+
62
+
63
+ def load_model(checkpoint):
64
+
65
+ try:
66
+ if "bart" in checkpoint:
67
+ print(f"\033[92mLoad Bart model from checkpoint: {checkpoint}\033[00m")
68
+ return BartSum(checkpoint)
69
+
70
+ if "flan" in checkpoint:
71
+ print(f"\033[92mLoad Flan-T5 model from checkpoint: {checkpoint}\033[00m")
72
+ return FlanT5Sum(checkpoint)
73
+
74
+ else:
75
+ print(f"\033[92mLoad general model from checkpoint: {checkpoint}\033[00m")
76
+ return Model(checkpoint)
77
+
78
+ except Exception as e:
79
+ print("Error while loading model: {e}")
80
+ raise e
src/pipelines/deploy_pipeline.py ADDED
File without changes
src/pipelines/training_pipeline.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import numpy as np
5
+ import nltk
6
+
7
+ from nltk.tokenize import sent_tokenize
8
+ from transformers import (
9
+ Seq2SeqTrainer,
10
+ AutoTokenizer,
11
+ AutoModelForSeq2SeqLM
12
+ )
13
+
14
+ from peft import get_peft_model, prepare_model_for_kbit_training
15
+
16
+ path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
17
+ sys.path.insert(0, path)
18
+
19
+ from utils import *
20
+
21
+ # from model.models import load_model
22
+ from model.model import load_model
23
+ from data.preprocessing import preprocessing_data
24
+ from data.ingest_data import ingest_data
25
+
26
+ import evaluate
27
+
28
+
29
+ def training_pipeline(args: argparse.Namespace):
30
+ try:
31
+ print("=========================================")
32
+ print('\n'.join(f' + {k}={v}' for k, v in vars(args).items()))
33
+ print("=========================================")
34
+
35
+ import torch
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+
38
+ model = load_model(args.checkpoint)
39
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
40
+ print(type(tokenizer))
41
+
42
+ if (args.lora == False):
43
+ print("lora=Fasle, quantize=False")
44
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(args.checkpoint).to(device)
45
+ # model.base_model = model.get_model()
46
+ # model.base_model.to(device)
47
+
48
+ else:
49
+ from peft import LoraConfig, TaskType
50
+ from transformers import BitsAndBytesConfig
51
+ import torch
52
+ # Define LoRA Config
53
+ lora_config = LoraConfig(
54
+ r=args.lora_rank,
55
+ lora_alpha=args.lora_alpha,
56
+ target_modules=args.target_modules.split(","),
57
+ lora_dropout=args.lora_dropout,
58
+ bias="none",
59
+ task_type=TaskType.SEQ_2_SEQ_LM
60
+ )
61
+
62
+ if (args.quantize == True):
63
+ print("Quantize=True, lora=True")
64
+ bnb_config = BitsAndBytesConfig(
65
+ load_in_4bit=True,
66
+ bnb_4bit_use_double_quant=True,
67
+ bnb_4bit_quant_type="nf4",
68
+ bnb_4bit_compute_dtype=torch.bfloat16
69
+ )
70
+
71
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(args.checkpoint,
72
+ quantization_config=bnb_config,
73
+ device_map={"":0},
74
+ trust_remote_code=True)
75
+ base_model = prepare_model_for_kbit_training(base_model)
76
+
77
+ if (args.quantize==False):
78
+ print("Quantize=False, lora=True")
79
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(args.checkpoint).to(device)
80
+
81
+ # add LoRA adaptor
82
+ print("Base model:", model.base_model)
83
+ base_model = get_peft_model(base_model, lora_config)
84
+ base_model.print_trainable_parameters()
85
+
86
+
87
+ # Load data from datapath
88
+ data = ingest_data(args.datapath)
89
+ print("\033[92m[+] Complete loading dataset!\033[00m")
90
+
91
+ # Pre-processing data
92
+ data = preprocessing_data(data, tokenizer, use_contrastive_loss=args.use_contrastive_loss, tokenizing_strategy=args.tokenizing_strategy)
93
+ print("\033[92m[+] Complete pre-processing dataset!\033[00m")
94
+
95
+ # Load training arguments
96
+ training_args = load_training_arguments(args)
97
+ print("\033[92m[+] Complete loading training arguments!\033[00m")
98
+
99
+ # Load metric
100
+ metric = evaluate.load("rouge")
101
+ nltk.download("punkt")
102
+
103
+ def postprocess_text(preds, labels):
104
+ preds = [pred.strip() for pred in preds]
105
+ labels = [label.strip() for label in labels]
106
+
107
+ preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
108
+ labels = ["\n".join(sent_tokenize(label)) for label in labels]
109
+
110
+ return preds, labels
111
+
112
+ def compute_metric(eval_preds):
113
+ preds, labels = eval_preds
114
+ if isinstance(preds, tuple):
115
+ preds = preds[0]
116
+
117
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
118
+ # Replace -100 in the labels as we can't decode them.
119
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
120
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
121
+
122
+ # Some simple post-processing
123
+ decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
124
+
125
+ # metric = evaluate.load("rouge")
126
+ rouge_results = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
127
+ rouge_results = {k: round(v * 100, 4) for k, v in rouge_results.items()}
128
+
129
+ results = {
130
+ "rouge1": rouge_results["rouge1"],
131
+ "rouge2": rouge_results["rouge2"],
132
+ "rougeL": rouge_results["rougeL"],
133
+ "rougeLsum": rouge_results["rougeLsum"],
134
+ "gen_len": np.mean([np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds])
135
+ }
136
+
137
+ return results
138
+
139
+ # Load trainer
140
+ if args.use_contrastive_loss==True:
141
+ trainer = ContrastiveLearningTrainer(model=base_model,
142
+ train_dataset=data["train"],
143
+ eval_dataset=data["validation"],
144
+ tokenizer=tokenizer,
145
+ compute_metrics=compute_metric)
146
+
147
+ if args.use_contrastive_loss==False:
148
+ trainer = Seq2SeqTrainer(model=base_model,
149
+ args=training_args,
150
+ train_dataset=data["train"],
151
+ eval_dataset=data["validation"],
152
+ tokenizer=tokenizer,
153
+ compute_metrics=compute_metric)
154
+
155
+ print("\033[92m[+] Complete loading trainer!\033[00m")
156
+
157
+ # Train model
158
+ trainer.train()
159
+ print("\033[92m[+] Complete training!\033[00m")
160
+
161
+ # Push to Huggingface Hub
162
+ trainer.push_to_hub()
163
+ print("\033[92m [+] Complete pushing model to hub!\033[00m")
164
+
165
+ except Exception as e:
166
+ print(f"\033[31m\nError while training: {e}\033[00m")
167
+ raise e
168
+
src/test/test_rouge.py ADDED
File without changes
src/utils.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import os
4
+ import sys
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from transformers import (
10
+ Seq2SeqTrainingArguments,
11
+ Seq2SeqTrainer,
12
+ )
13
+
14
+ path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
15
+ sys.path.insert(0, path)
16
+
17
+ # from src.evaluate.rouge_metric import compute_metrics
18
+
19
+ def parse_args() -> argparse.Namespace:
20
+ parser = argparse.ArgumentParser(description="Fine tuning LLM for Dialogue Text Summarization")
21
+ parser.add_argument("--huggingface_hub_token", type=str, default=None)
22
+ parser.add_argument("--wandb_token", type=str, default=None)
23
+
24
+ parser.add_argument("--checkpoint", type=str, default="google/flan-t5-base")
25
+ parser.add_argument("--datapath", type=str, default="knkarthick/dialogsum")
26
+
27
+ parser.add_argument("--output_dir", type=str, default="fine-tuned-flant5")
28
+ parser.add_argument("--overwrite_output_dir", action="store_true")
29
+
30
+ parser.add_argument("--num_train_epochs", type=int, default=3)
31
+ parser.add_argument("--per_device_train_batch_size", type=int, default=4)
32
+ parser.add_argument("--per_device_eval_batch_size", type=int, default=4)
33
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
34
+
35
+ parser.add_argument("--learning_rate", type=float, default=0.00005)
36
+ parser.add_argument("--weight_decay", type=float, default=0.005)
37
+
38
+ parser.add_argument("--evaluation_strategy", type=str, default="no")
39
+ parser.add_argument("--save_strategy", type=str, default="no")
40
+
41
+ parser.add_argument("--logging_strategy", type=str, default="steps")
42
+ parser.add_argument("--logging_steps", type=int, default=1000)
43
+ parser.add_argument("--save_total_limit", type=int, default=1)
44
+
45
+ parser.add_argument("--report_to", type=str, default="wandb")
46
+ parser.add_argument("--run_name", type=str, default="flan-t5-base-model")
47
+
48
+ parser.add_argument("--predict_with_generate", action="store_true")
49
+
50
+ parser.add_argument("--min_new_tokens", type=int, default=10)
51
+ parser.add_argument("--max_new_tokens", type=int, default=256)
52
+ parser.add_argument("--temperature", type=float, default=0.9)
53
+ parser.add_argument("--top_p", type=float, default=1.0)
54
+ parser.add_argument("--top_k", type=int, default=50)
55
+
56
+ parser.add_argument("--lora", action="store_true")
57
+ parser.add_argument("--quantize", action="store_true")
58
+
59
+ parser.add_argument("--lora_rank", type=int, default=8)
60
+ parser.add_argument("--lora_alpha", type=int, default=16)
61
+ parser.add_argument("--target_modules", type=str, default="q,v")
62
+ parser.add_argument("--lora_dropout", type=float, default=0.05)
63
+
64
+ parser.add_argument("--use_contrastive_loss", action="store_true")
65
+ parser.add_argument("--tokenizing_strategy", type=int, default=1)
66
+
67
+ args = parser.parse_args()
68
+ return args
69
+
70
+
71
+ def load_training_arguments(args):
72
+ try:
73
+ training_args = Seq2SeqTrainingArguments(
74
+ output_dir=args.output_dir,
75
+ overwrite_output_dir=args.overwrite_output_dir,
76
+
77
+ num_train_epochs=args.num_train_epochs,
78
+ per_device_train_batch_size=args.per_device_train_batch_size,
79
+ per_device_eval_batch_size=args.per_device_eval_batch_size,
80
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
81
+
82
+ learning_rate=args.learning_rate,
83
+ weight_decay=args.weight_decay,
84
+
85
+ evaluation_strategy=args.evaluation_strategy,
86
+ save_strategy=args.save_strategy,
87
+
88
+ logging_strategy=args.logging_strategy,
89
+ logging_steps=args.logging_steps,
90
+ save_total_limit=args.save_total_limit,
91
+
92
+ report_to=args.report_to,
93
+ run_name=args.run_name,
94
+
95
+ predict_with_generate=args.predict_with_generate
96
+ )
97
+
98
+ return training_args
99
+
100
+ except Exception as e:
101
+ print(f"Error while loading training arguments: {e}")
102
+ raise e
103
+
104
+ class ContrastiveLoss(nn.Module):
105
+ def __init__(self, margin=1.0):
106
+ super(ContrastiveLoss, self).__init__()
107
+ self.margin = margin
108
+ self.cosine_similarity = nn.CosineSimilarity(dim=1, eps=1e-6)
109
+
110
+ def forward(self, dialgue_embeddings, pos_summary_embeddings, neg_summary_embeddings):
111
+ pos_sim = self.cosine_similarity(dialgue_embeddings, pos_summary_embeddings)
112
+ neg_sim = self.cosine_similarity(dialgue_embeddings, neg_summary_embeddings)
113
+ loss = torch.mean(1-pos_sim) + torch.clamp(neg_sim-self.margin, min=0.0)
114
+
115
+ return loss
116
+
117
+ class ContrastiveLearningTrainer(Seq2SeqTrainer):
118
+ def compute_loss(model, inputs, return_outputs=False):
119
+ output = model(**inputs)
120
+ lm_loss = output.loss
121
+
122
+ dialogue_embeddings = model.encoder(inputs["input_ids"]).last_hidden_state
123
+ pos_summary_embeddings = model.encoder(inputs["labels"]).last_hidden_state
124
+ neg_summary_embeddings = model.encoder(inputs["negative_labels"]).last_hidden_state
125
+
126
+ contrastive_loss = ContrastiveLoss(margin=1.0)(dialogue_embeddings, pos_summary_embeddings, neg_summary_embeddings)
127
+
128
+ # Combine losses
129
+ total_loss = lm_loss + contrastive_loss
130
+
131
+ return (total_loss, output) if return_outputs else total_loss
test_streaming.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import replicate
3
+ import os
4
+ from transformers import AutoTokenizer, GenerationConfig, AutoModelForSeq2SeqLM
5
+ import torch
6
+
7
+ # Set Replicate API token
8
+ with st.sidebar:
9
+ st.title('Dialogue Text Summarization')
10
+ if 'REPLICATE_API_TOKEN' in st.secrets:
11
+ replicate_api = st.secrets['REPLICATE_API_TOKEN']
12
+ else:
13
+ replicate_api = st.text_input('Enter Replicate API token:', type='password')
14
+ if not (replicate_api.startswith('r8_') and len(replicate_api) == 40):
15
+ st.warning('Please enter your Replicate API token.', icon='⚠️')
16
+ st.markdown("**Don't have an API token?** Head over to [Replicate](https://replicate.com) to sign up for one.")
17
+
18
+ os.environ['REPLICATE_API_TOKEN'] = replicate_api
19
+ st.subheader("Adjust model parameters")
20
+ min_new_tokens = st.slider('Min new tokens', min_value=1, max_value=256, step=1, value=10)
21
+ temperature = st.slider('Temperature', min_value=0.01, max_value=1.00, step=0.01, value=1.0)
22
+ top_k = st.slider('Top_k', min_value=1, max_value=50, step=1, value=20)
23
+ top_p = st.slider('Top_p', min_value=0.01, max_value=1.00, step=0.01, value=1.0)
24
+
25
+ # Initialize model and tokenizer
26
+ checkpoint = "dtruong46me/train-bart-base"
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
29
+ model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)
30
+
31
+ st.title("Dialogue Text Summarization")
32
+ st.caption("Natural Language Processing Project 20232")
33
+ st.write("---")
34
+
35
+ input_text = st.text_area("Dialogue", height=200)
36
+
37
+ generation_config = GenerationConfig(
38
+ min_new_tokens=min_new_tokens,
39
+ max_new_tokens=320,
40
+ temperature=temperature,
41
+ top_p=top_p,
42
+ top_k=top_k
43
+ )
44
+
45
+ def generate_summary(model, input_text, generation_config, tokenizer):
46
+ prefix = "Summarize the following conversation: \n\n###"
47
+ suffix = "\n\nSummary:"
48
+ input_ids = tokenizer.encode(prefix + input_text + suffix, return_tensors="pt").to(model.device)
49
+ prompt_str = tokenizer.decode(input_ids[0], skip_special_tokens=True)
50
+ return prompt_str
51
+
52
+ def stream_summary(prompt_str, temperature, top_p):
53
+ for event in replicate.stream(
54
+ "snowflake/snowflake-arctic-instruct",
55
+ input={"prompt": prompt_str,
56
+ "prompt_template": r"{prompt}",
57
+ "temperature": temperature,
58
+ "top_p": top_p}):
59
+ yield str(event['output'])
60
+
61
+ if st.button("Submit"):
62
+ st.write("---")
63
+ st.write("## Summary")
64
+
65
+ if not replicate_api:
66
+ st.error("Please enter your Replicate API token!")
67
+ elif not input_text:
68
+ st.error("Please enter a dialogue!")
69
+ else:
70
+ prompt_str = generate_summary(model, input_text, generation_config, tokenizer)
71
+ summary_container = st.empty()
72
+
73
+ summary_text = ""
74
+ for output in stream_summary(prompt_str, temperature, top_p):
75
+ summary_text += output
76
+ summary_container.text(summary_text)