dtruong46me commited on
Commit
8fcd344
1 Parent(s): 45b0f52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -73
app.py CHANGED
@@ -1,73 +1,132 @@
1
- import streamlit as st
2
-
3
- from transformers import GenerationConfig, BartModel, BartTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM
4
- import torch
5
- import time
6
-
7
- import sys, os
8
-
9
- path = os.path.abspath(os.path.dirname(__file__))
10
- sys.path.insert(0, path)
11
-
12
- from gen_summary import generate_summary
13
-
14
-
15
- st.title("Dialogue Text Summarization")
16
- st.caption("Natural Language Processing Project 20232")
17
-
18
- st.write("---")
19
-
20
- with st.sidebar:
21
- checkpoint = st.selectbox("Model", options=[
22
- "Choose model",
23
- "dtruong46me/train-bart-base",
24
- "dtruong46me/flant5-small",
25
- "dtruong46me/flant5-base",
26
- "dtruong46me/flan-t5-s",
27
- "ntluongg/bart-base-luong"
28
- ])
29
- st.button("Model detail", use_container_width=True)
30
- st.write("-----")
31
- st.write("**Generate Options:**")
32
- min_new_tokens = st.number_input("Min new tokens", min_value=1, max_value=64, value=10)
33
- max_new_tokens = st.number_input("Max new tokens", min_value=64, max_value=128, value=64)
34
- temperature = st.number_input("Temperature", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
35
- top_k = st.number_input("Top_k", min_value=1, max_value=50, step=1, value=20)
36
- top_p = st.number_input("Top_p", min_value=0.01, max_value=1.00, step=0.01, value=1.0)
37
-
38
-
39
- height = 200
40
-
41
- input_text = st.text_area("Dialogue", height=height)
42
-
43
- generation_config = GenerationConfig(
44
- min_new_tokens=min_new_tokens,
45
- max_new_tokens=320,
46
- temperature=temperature,
47
- top_p=top_p,
48
- top_k=top_k
49
- )
50
-
51
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
-
53
- if checkpoint=="Choose model":
54
- tokenizer = None
55
- model = None
56
-
57
- if checkpoint!="Choose model":
58
- tokenizer = AutoTokenizer.from_pretrained(checkpoint)
59
- model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
60
-
61
-
62
-
63
- if st.button("Submit"):
64
- st.write("---")
65
- st.write("## Summary")
66
-
67
- if checkpoint=="Choose model":
68
- st.error("Please selece a model!")
69
-
70
- else:
71
- if input_text=="":
72
- st.error("Please enter a dialogue!")
73
- st.write(generate_summary(model, " ".join(input_text.split()), generation_config, tokenizer))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+
4
+ from transformers import GenerationConfig, BartModel, BartTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, TextStreamer
5
+ import torch
6
+ import time
7
+
8
+ import sys, os
9
+
10
+ path = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.insert(0, path)
12
+
13
+ from gen_summary import generate_summary
14
+
15
+
16
+ st.title("Dialogue Text Summarization")
17
+ st.caption("Natural Language Processing Project 20232")
18
+
19
+ st.write("---")
20
+
21
+
22
+ class StreamlitTextStreamer(TextStreamer):
23
+ def __init__(self, tokenizer, st_container, st_info_container, skip_prompt=False, **decode_kwargs):
24
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
25
+ self.st_container = st_container
26
+ self.st_info_container = st_info_container
27
+ self.text = ""
28
+ self.start_time = None
29
+ self.first_token_time = None
30
+ self.total_tokens = 0
31
+
32
+ def on_finalized_text(self, text: str, stream_end: bool=False):
33
+ if self.start_time is None:
34
+ self.start_time = time.time()
35
+
36
+ if self.first_token_time is None and len(text.strip()) > 0:
37
+ self.first_token_time = time.time()
38
+
39
+ self.text += text
40
+
41
+ self.total_tokens += len(text.split())
42
+ self.st_container.markdown("###### " + self.text)
43
+ time.sleep(0.03)
44
+
45
+ if stream_end:
46
+ total_time = time.time() - self.start_time
47
+ first_token_wait_time = self.first_token_time - self.start_time if self.first_token_time else None
48
+ tokens_per_second = self.total_tokens / total_time if total_time > 0 else None
49
+
50
+ df = pd.DataFrame(data={
51
+ "First token": [first_token_wait_time],
52
+ "Total tokens": [self.total_tokens],
53
+ "Time taken": [total_time],
54
+ "Token per second": [tokens_per_second]
55
+ })
56
+
57
+ self.st_info_container.table(df)
58
+
59
+ def generate_summary(model, input_text, generation_config, tokenizer, st_container, st_info_container) -> str:
60
+ try:
61
+ prefix = "Summarize the following conversation: \n###\n"
62
+ suffix = "\n### Summary:"
63
+ target_length = max(1, int(0.15 * len(input_text.split())))
64
+
65
+ input_ids = tokenizer.encode(prefix + input_text + f"The generated summary should be around {target_length} words." + suffix, return_tensors="pt")
66
+
67
+ # Initialize the Streamlit container and streamer
68
+ streamer = StreamlitTextStreamer(tokenizer, st_container, st_info_container, skip_special_tokens=True, decoder_start_token_id=3)
69
+
70
+ model.generate(input_ids, streamer=streamer, do_sample=True, generation_config=generation_config)
71
+
72
+ except Exception as e:
73
+ raise e
74
+
75
+
76
+ with st.sidebar:
77
+ checkpoint = st.selectbox("Model", options=[
78
+ "Choose model",
79
+ "dtruong46me/train-bart-base",
80
+ "dtruong46me/flant5-small",
81
+ "dtruong46me/flant5-base",
82
+ "dtruong46me/flan-t5-s",
83
+ "ntluongg/bart-base-luong"
84
+ ])
85
+ st.button("Model detail", use_container_width=True)
86
+ st.write("-----")
87
+ st.write("**Generate Options:**")
88
+ min_new_tokens = st.number_input("Min new tokens", min_value=1, max_value=64, value=10)
89
+ max_new_tokens = st.number_input("Max new tokens", min_value=64, max_value=128, value=64)
90
+ temperature = st.number_input("Temperature", min_value=0.0, max_value=1.0, value=0.9, step=0.05)
91
+ top_k = st.number_input("Top_k", min_value=1, max_value=50, step=1, value=20)
92
+ top_p = st.number_input("Top_p", min_value=0.01, max_value=1.00, step=0.01, value=1.0)
93
+
94
+
95
+ height = 200
96
+
97
+ input_text = st.text_area("Dialogue", height=height)
98
+
99
+ generation_config = GenerationConfig(
100
+ min_new_tokens=min_new_tokens,
101
+ max_new_tokens=320,
102
+ temperature=temperature,
103
+ top_p=top_p,
104
+ top_k=top_k
105
+ )
106
+
107
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
108
+
109
+ if checkpoint=="Choose model":
110
+ tokenizer = None
111
+ model = None
112
+
113
+ if checkpoint!="Choose model":
114
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
115
+ model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
116
+
117
+
118
+
119
+ if st.button("Submit"):
120
+ st.write("---")
121
+ st.write("## Summary")
122
+
123
+ if checkpoint=="Choose model":
124
+ st.error("Please selece a model!")
125
+
126
+ else:
127
+ if input_text=="":
128
+ st.error("Please enter a dialogue!")
129
+ # generate_summary(model, " ".join(input_text.split()), generation_config, tokenizer)
130
+ st_container = st.empty()
131
+ st_info_container = st.empty()
132
+ generate_summary(model, " ".join(input_text.split()), generation_config, tokenizer, st_container, st_info_container)