Remove unneeded generator settings. Get model generator settings if they exist
- app.py +35 -133
- generator.py +53 -29
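
The gist of the change: instead of exposing every generation parameter in the Streamlit sidebar, each Generator now starts from a small set of default generation kwargs, overrides them with whatever the loaded model's config (and its task_specific_params) already defines, and merges the few remaining sidebar values on top at generate() time. A minimal sketch of that override order, with illustrative names that are not part of the commit:

# Illustrative sketch only: defaults < model config < task_specific_params < caller kwargs.
DEFAULT_GEN_KWARGS = {"max_length": 128, "num_beams": 6, "num_beam_groups": 3, "length_penalty": 1.0}

def resolve_gen_kwargs(model_config, task=None, **caller_kwargs):
    kwargs = dict(DEFAULT_GEN_KWARGS)
    # Prefer values the model config already carries (Generator.load() reads config.__dict__)
    for key in kwargs:
        if key in model_config.__dict__:
            kwargs[key] = model_config.__dict__[key]
    # Task-specific settings win over plain config values
    task_params = (getattr(model_config, "task_specific_params", None) or {}).get(task, {}) if task else {}
    for key in kwargs:
        if key in task_params:
            kwargs[key] = task_params[key]
    # Anything passed explicitly (e.g. from the sidebar) wins over everything else
    kwargs.update(caller_kwargs)
    return kwargs
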
app.py
CHANGED
@@ -1,18 +1,8 @@
-import json
-import os
 import time
-
+import torch
 
 import psutil
 import streamlit as st
-import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
-    pipeline,
-    set_seed,
-)
 
 from generator import GeneratorFactory
 
@@ -25,16 +15,19 @@ GENERATOR_LIST = [
         "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
         "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
         "task": TRANSLATION_NL_TO_EN,
+        "split_sentences": False,
     },
     {
         "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
         "desc": "longT5 large nl8 512beta/512l en->nl",
         "task": TRANSLATION_NL_TO_EN,
+        "split_sentences": False,
     },
     {
         "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
         "desc": "T5 small nl24 ccmatrix en->nl",
         "task": TRANSLATION_NL_TO_EN,
+        "split_sentences": True,
     },
 ]
 
@@ -49,18 +42,15 @@ def main():
 
     if "generators" not in st.session_state:
         st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
-
     generators = st.session_state["generators"]
 
     with open("style.css") as f:
         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
-
    st.sidebar.image("babel.png", width=200)
     st.sidebar.markdown(
         """# Babel
 Vertaal van en naar Engels"""
     )
-    model_desc = st.sidebar.selectbox("Model", generators.gpt_descs(), index=1)
     st.sidebar.title("Parameters:")
     if "prompt_box" not in st.session_state:
         # Text is from https://www.gutenberg.org/files/35091/35091-h/35091-h.html
@@ -76,135 +66,47 @@ It was a quite young girl, unknown to me, with a hood over her head, and with la
     st.session_state["text"] = st.text_area(
         "Enter text", st.session_state.prompt_box, height=300
     )
-
-
-        value=
-        max_value=4096,
-    )
-    no_repeat_ngram_size = st.sidebar.number_input(
-        "No-repeat NGram size", min_value=1, max_value=5, value=3
+    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=6)
+    num_beam_groups = st.sidebar.number_input(
+        "Num beam groups", min_value=1, max_value=10, value=3
     )
-
-        "
+    length_penalty = st.sidebar.number_input(
+        "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
     )
-    num_return_sequences = st.sidebar.number_input(
-        "Num return sequences", min_value=1, max_value=5, value=1
-    )
-    seed_placeholder = st.sidebar.empty()
-    if "seed" not in st.session_state:
-        print(f"Session state does not contain seed")
-        st.session_state["seed"] = 4162549114
-        print(f"Seed is set to: {st.session_state['seed']}")
-
-    seed = seed_placeholder.number_input(
-        "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
-    )
-
-    def set_random_seed():
-        st.session_state["seed"] = randint(0, 2**32 - 1)
-        seed = seed_placeholder.number_input(
-            "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
-        )
-        print(f"New random seed set to: {seed}")
-
-    if st.button("Set new random seed"):
-        set_random_seed()
-
-    if sampling_mode := st.sidebar.selectbox(
-        "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
-    ):
-        if sampling_mode == "Beam Search":
-            num_beams = st.sidebar.number_input(
-                "Num beams", min_value=1, max_value=10, value=4
-            )
-            length_penalty = st.sidebar.number_input(
-                "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
-            )
-            params = {
-                "max_length": max_length,
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-                "repetition_penalty": repetition_penalty,
-                "num_return_sequences": num_return_sequences,
-                "num_beams": num_beams,
-                "early_stopping": True,
-                "length_penalty": length_penalty,
-            }
-        else:
-            top_k = st.sidebar.number_input(
-                "Top K", min_value=0, max_value=100, value=50
-            )
-            top_p = st.sidebar.number_input(
-                "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
-            )
-            temperature = st.sidebar.number_input(
-                "Temperature", min_value=0.05, max_value=1.0, value=1.0, step=0.05
-            )
-            params = {
-                "max_length": max_length,
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-                "repetition_penalty": repetition_penalty,
-                "num_return_sequences": num_return_sequences,
-                "do_sample": True,
-                "top_k": top_k,
-                "top_p": top_p,
-                "temperature": temperature,
-            }
-
     st.sidebar.markdown(
         """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
 and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
 """
     )
 
-
-        ""
-
-
-
-        estimate = estimate * (1 + 0.7 * (num_return_sequences - 1))
-        if sampling_mode == "Beam Search":
-            estimate = estimate * (1.1 + 0.3 * (num_beams - 1))
-        else:
-            ## gpu
-            estimate = estimate * (1 + 0.1 * (num_return_sequences - 1))
-            estimate = 0.5 + estimate / 5
-            if sampling_mode == "Beam Search":
-                estimate = estimate * (1.0 + 0.1 * (num_beams - 1))
-        return int(estimate)
+    params = {
+        "num_beams": num_beams,
+        "num_beam_groups": num_beam_groups,
+        "length_penalty": length_penalty,
+    }
 
     if st.button("Run"):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    *Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
-    *Text generated using seed {seed}*
-    """
-        st.write(info)
-
-        params["seed"] = seed
-        params["prompt"] = st.session_state.text
-        params["model"] = generator.model_name
-        params_text = json.dumps(params)
-        print(params_text)
-        st.json(params_text)
+        memory = psutil.virtual_memory()
+
+        for generator in generators:
+            st.markdown(f"🧮 **Model `{generator}`**")
+            time_start = time.time()
+            result, params_used = generator.generate(
+                text=st.session_state.text, **params
+            )
+            time_end = time.time()
+            time_diff = time_end - time_start
+
+            st.write(result[0].replace("\n", " \n"))
+            text_line = ", ".join([f"{k}={v}" for k, v in params_used.items()])
+            st.markdown(f" 🕙 *generated in {time_diff:.2f}s, `{text_line}`*")
+
+            st.write(
+                f"""
+    ---
+    *Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
+    """
+            )
 
 
 if __name__ == "__main__":
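
app.py now forwards only num_beams, num_beam_groups and length_penalty; everything else comes from the Generator defaults set up in generator.py below. For reference, this is roughly what such a parameter dict does once it reaches a transformers generate() call. The sketch uses the public t5-small checkpoint as a stand-in (the Babel models above are loaded through generator.load_model() with from_flax and an auth token instead); with num_beam_groups > 1 transformers switches to group (diverse) beam search, and num_beams must be divisible by num_beam_groups.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Stand-in model so the sketch runs without the from_flax/auth handling in generator.py
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

params = {"num_beams": 6, "num_beam_groups": 3, "length_penalty": 1.2}
batch = tokenizer("translate English to German: The weather is nice today.", return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        batch["input_ids"], attention_mask=batch["attention_mask"], max_length=128, **params
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
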
generator.py
CHANGED
@@ -1,37 +1,37 @@
 import os
+import re
+
 import streamlit as st
 import torch
 from transformers import (
-    AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
 )
 
 device = torch.cuda.device_count() - 1
 
-TRANSLATION_NL_TO_EN = "translation_en_to_nl"
-
 
-
-def load_model(model_name, task):
-    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+def get_access_token():
     try:
         if not os.path.exists(".streamlit/secrets.toml"):
             raise FileNotFoundError
         access_token = st.secrets.get("babel")
     except FileNotFoundError:
         access_token = os.environ.get("HF_ACCESS_TOKEN", None)
+    return access_token
+
+
+@st.cache(suppress_st_warning=True, allow_output_mutation=True)
+def load_model(model_name):
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
     tokenizer = AutoTokenizer.from_pretrained(
-        model_name, from_flax=True, use_auth_token=access_token
+        model_name, from_flax=True, use_auth_token=get_access_token()
     )
     if tokenizer.pad_token is None:
         print("Adding pad_token to the tokenizer")
         tokenizer.pad_token = tokenizer.eos_token
-
-
-    )
-    model = auto_model_class.from_pretrained(
-        model_name, from_flax=True, use_auth_token=access_token
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        model_name, from_flax=True, use_auth_token=get_access_token()
     )
     if device != -1:
         model.to(f"cuda:{device}")
@@ -39,35 +39,61 @@ def load_model(model_name, task):
 
 
 class Generator:
-    def __init__(self, model_name, task, desc):
+    def __init__(self, model_name, task, desc, split_sentences):
         self.model_name = model_name
         self.task = task
         self.desc = desc
         self.tokenizer = None
         self.model = None
         self.prefix = ""
+        self.gen_kwargs = {
+            "max_length": 128,
+            "num_beams": 6,
+            "num_beam_groups": 3,
+            "no_repeat_ngram_size": 0,
+            "early_stopping": True,
+            "num_return_sequences": 1,
+            "length_penalty": 1.0,
+        }
         self.load()
 
     def load(self):
         if not self.model:
             print(f"Loading model {self.model_name}")
-            self.tokenizer, self.model = load_model(self.model_name, self.task)
-
+            self.tokenizer, self.model = load_model(self.model_name)
+
+            for key in self.gen_kwargs:
+                if key in self.model.config.__dict__:
+                    self.gen_kwargs[key] = self.model.config.__dict__[key]
+                    print(
+                        "Setting",
+                        key,
+                        "to",
+                        self.gen_kwargs[key],
+                        "for model",
+                        self.model_name,
+                    )
             try:
                 if self.task in self.model.config.task_specific_params:
                     task_specific_params = self.model.config.task_specific_params[
                         self.task
                     ]
-
-
+                    self.prefix = (
+                        task_specific_params["prefix"]
+                        if "prefix" in task_specific_params
+                        else ""
+                    )
+                    for key in self.gen_kwargs:
+                        if key in task_specific_params:
+                            self.gen_kwargs[key] = task_specific_params[key]
             except TypeError:
                 pass
 
-    def generate(self, text: str, **generate_kwargs) -> str:
-        #
-
-
-
+    def generate(self, text: str, **generate_kwargs) -> (str, dict):
+        # Replace two or more newlines with a single newline in text
+        text = re.sub(r"\n{2,}", "\n", text)
+
+        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
         batch_encoded = self.tokenizer(
             self.prefix + text,
             max_length=generate_kwargs["max_length"],
@@ -89,9 +115,7 @@ class Generator:
             pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
             for pred in decoded_preds
         ]
-        return decoded_preds
-
-        # return self.pipeline(text, **generate_kwargs)
+        return decoded_preds, generate_kwargs
 
     def __str__(self):
         return self.desc
@@ -104,10 +128,10 @@ class GeneratorFactory:
             with st.spinner(text=f"Loading the model {g['desc']} ..."):
                 self.add_generator(**g)
 
-    def add_generator(self, model_name, task, desc):
+    def add_generator(self, model_name, task, desc, split_sentences):
         # If the generator is not yet present, add it
         if not self.get_generator(model_name=model_name, task=task, desc=desc):
-            g = Generator(model_name, task, desc)
+            g = Generator(model_name, task, desc, split_sentences)
             g.load()
             self.generators.append(g)
 
@@ -120,5 +144,5 @@
     def __iter__(self):
         return iter(self.generators)
 
-    def
-        return [g.desc for g in self.generators if
+    def get_descriptions(self, task=None):
+        return [g.desc for g in self.generators if task is None or task == g.task]
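
A short usage sketch of the reworked generator.py interface, under the assumption that the listed model is reachable (it is loaded via st.cache and may need an access token, as shown above). The task string is written out literally here because the TRANSLATION_NL_TO_EN constant was dropped from generator.py:

from generator import GeneratorFactory

TRANSLATION_NL_TO_EN = "translation_en_to_nl"  # value of the removed generator.py constant

generators = GeneratorFactory([
    {
        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
        "desc": "T5 small nl24 ccmatrix en->nl",
        "task": TRANSLATION_NL_TO_EN,
        "split_sentences": True,
    },
])

print(generators.get_descriptions(task=TRANSLATION_NL_TO_EN))  # -> ['T5 small nl24 ccmatrix en->nl']
for generator in generators:
    # generate() now returns both the decoded predictions and the merged generation kwargs
    preds, params_used = generator.generate(text="It was a quiet morning.", num_beams=6, num_beam_groups=3)
    print(preds[0])
    print(params_used)
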