yhavinga committed on
Commit 3f553b1
1 Parent(s): 8175a61

Remove unneeded generator settings. Get model generator settings if they exist

Files changed (2)
  1. app.py +35 -133
  2. generator.py +53 -29
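
In short: the sidebar no longer builds the full set of generation parameters. Each Generator now seeds its own gen_kwargs from its model config and task_specific_params, and the sidebar values override them at generate() time. A minimal sketch of that precedence (DEFAULTS, resolve_gen_kwargs and the three input dicts are illustrative names, not part of the commit):

# Illustrative sketch only: mirrors the precedence introduced by this commit.
DEFAULTS = {
    "max_length": 128,
    "num_beams": 6,
    "num_beam_groups": 3,
    "no_repeat_ngram_size": 0,
    "early_stopping": True,
    "num_return_sequences": 1,
    "length_penalty": 1.0,
}


def resolve_gen_kwargs(model_config: dict, task_params: dict, ui_params: dict) -> dict:
    """Later sources win: defaults < model config < task_specific_params < sidebar."""
    kwargs = dict(DEFAULTS)
    # Only keys already tracked in DEFAULTS are taken from the model config,
    # matching the `for key in self.gen_kwargs` loop in Generator.load().
    kwargs.update({k: v for k, v in model_config.items() if k in DEFAULTS})
    kwargs.update({k: v for k, v in task_params.items() if k in DEFAULTS})
    kwargs.update(ui_params)
    return kwargs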
app.py CHANGED
@@ -1,18 +1,8 @@
-import json
-import os
 import time
-from random import randint
+import torch

 import psutil
 import streamlit as st
-import torch
-from transformers import (
-    AutoModelForCausalLM,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
-    pipeline,
-    set_seed,
-)

 from generator import GeneratorFactory

@@ -25,16 +15,19 @@ GENERATOR_LIST = [
         "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
         "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
         "task": TRANSLATION_NL_TO_EN,
+        "split_sentences": False,
     },
     {
         "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
         "desc": "longT5 large nl8 512beta/512l en->nl",
         "task": TRANSLATION_NL_TO_EN,
+        "split_sentences": False,
     },
     {
         "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
         "desc": "T5 small nl24 ccmatrix en->nl",
         "task": TRANSLATION_NL_TO_EN,
+        "split_sentences": True,
     },
 ]

@@ -49,18 +42,15 @@ def main():

     if "generators" not in st.session_state:
         st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
-
     generators = st.session_state["generators"]

     with open("style.css") as f:
         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
-
     st.sidebar.image("babel.png", width=200)
     st.sidebar.markdown(
         """# Babel
Vertaal van en naar Engels"""
     )
-    model_desc = st.sidebar.selectbox("Model", generators.gpt_descs(), index=1)
     st.sidebar.title("Parameters:")
     if "prompt_box" not in st.session_state:
         # Text is from https://www.gutenberg.org/files/35091/35091-h/35091-h.html
@@ -76,135 +66,47 @@ It was a quite young girl, unknown to me, with a hood over her head, and with la
     st.session_state["text"] = st.text_area(
         "Enter text", st.session_state.prompt_box, height=300
     )
-    max_length = st.sidebar.number_input(
-        "Lengte van de tekst",
-        value=200,
-        max_value=4096,
-    )
-    no_repeat_ngram_size = st.sidebar.number_input(
-        "No-repeat NGram size", min_value=1, max_value=5, value=3
+    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=6)
+    num_beam_groups = st.sidebar.number_input(
+        "Num beam groups", min_value=1, max_value=10, value=3
     )
-    repetition_penalty = st.sidebar.number_input(
-        "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
+    length_penalty = st.sidebar.number_input(
+        "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
     )
-    num_return_sequences = st.sidebar.number_input(
-        "Num return sequences", min_value=1, max_value=5, value=1
-    )
-    seed_placeholder = st.sidebar.empty()
-    if "seed" not in st.session_state:
-        print(f"Session state does not contain seed")
-        st.session_state["seed"] = 4162549114
-        print(f"Seed is set to: {st.session_state['seed']}")
-
-    seed = seed_placeholder.number_input(
-        "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
-    )
-
-    def set_random_seed():
-        st.session_state["seed"] = randint(0, 2**32 - 1)
-        seed = seed_placeholder.number_input(
-            "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
-        )
-        print(f"New random seed set to: {seed}")
-
-    if st.button("Set new random seed"):
-        set_random_seed()
-
-    if sampling_mode := st.sidebar.selectbox(
-        "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
-    ):
-        if sampling_mode == "Beam Search":
-            num_beams = st.sidebar.number_input(
-                "Num beams", min_value=1, max_value=10, value=4
-            )
-            length_penalty = st.sidebar.number_input(
-                "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
-            )
-            params = {
-                "max_length": max_length,
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-                "repetition_penalty": repetition_penalty,
-                "num_return_sequences": num_return_sequences,
-                "num_beams": num_beams,
-                "early_stopping": True,
-                "length_penalty": length_penalty,
-            }
-        else:
-            top_k = st.sidebar.number_input(
-                "Top K", min_value=0, max_value=100, value=50
-            )
-            top_p = st.sidebar.number_input(
-                "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
-            )
-            temperature = st.sidebar.number_input(
-                "Temperature", min_value=0.05, max_value=1.0, value=1.0, step=0.05
-            )
-            params = {
-                "max_length": max_length,
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-                "repetition_penalty": repetition_penalty,
-                "num_return_sequences": num_return_sequences,
-                "do_sample": True,
-                "top_k": top_k,
-                "top_p": top_p,
-                "temperature": temperature,
-            }
-
     st.sidebar.markdown(
         """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
         and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
         """
     )

-    def estimate_time():
-        """Estimate the time it takes to generate the text."""
-        estimate = max_length / 18
-        if device == -1:
-            ## cpu
-            estimate = estimate * (1 + 0.7 * (num_return_sequences - 1))
-            if sampling_mode == "Beam Search":
-                estimate = estimate * (1.1 + 0.3 * (num_beams - 1))
-        else:
-            ## gpu
-            estimate = estimate * (1 + 0.1 * (num_return_sequences - 1))
-            estimate = 0.5 + estimate / 5
-            if sampling_mode == "Beam Search":
-                estimate = estimate * (1.0 + 0.1 * (num_beams - 1))
-        return int(estimate)
+    params = {
+        "num_beams": num_beams,
+        "num_beam_groups": num_beam_groups,
+        "length_penalty": length_penalty,
+    }

     if st.button("Run"):
-        estimate = estimate_time()
-
-        with st.spinner(
-            text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
-        ):
-            memory = psutil.virtual_memory()
-
-            for generator in generators:
-                st.subheader(f"Result from {generator}")
-                set_seed(seed)
-                time_start = time.time()
-                result = generator.generate(text=st.session_state.text, **params)
-                time_end = time.time()
-                time_diff = time_end - time_start
-
-                for text in result:
-                    st.write(text.replace("\n", "  \n"))
-                st.write(f"--- generated in {time_diff:.2f} seconds ---")
-
-                info = f"""
-                ---
-                *Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
-                *Text generated using seed {seed}*
-                """
-                st.write(info)
-
-                params["seed"] = seed
-                params["prompt"] = st.session_state.text
-                params["model"] = generator.model_name
-                params_text = json.dumps(params)
-                print(params_text)
-                st.json(params_text)
+        memory = psutil.virtual_memory()
+
+        for generator in generators:
+            st.markdown(f"🧮 **Model `{generator}`**")
+            time_start = time.time()
+            result, params_used = generator.generate(
+                text=st.session_state.text, **params
+            )
+            time_end = time.time()
+            time_diff = time_end - time_start
+
+            st.write(result[0].replace("\n", "  \n"))
+            text_line = ", ".join([f"{k}={v}" for k, v in params_used.items()])
+            st.markdown(f" 🕙 *generated in {time_diff:.2f}s, `{text_line}`*")
+
+            st.write(
+                f"""
+                ---
+                *Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
+                """
+            )


 if __name__ == "__main__":
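
With num_beam_groups defaulting to 3, generate() runs group (diverse) beam search; transformers requires num_beams to be divisible by num_beam_groups, which the sidebar defaults 6 and 3 satisfy. A small standalone sketch of the resulting call, using a public placeholder checkpoint rather than the Space's own models:

# Standalone sketch of the generate() call the app now makes.
# "t5-small" is only a placeholder checkpoint; the Space loads its own models
# through load_model() in generator.py instead.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: The weather is nice.", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_length=128,
    num_beams=6,          # must be divisible by num_beam_groups
    num_beam_groups=3,    # > 1 selects group (diverse) beam search
    length_penalty=1.2,
    early_stopping=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))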
generator.py CHANGED
@@ -1,37 +1,37 @@
 import os
+import re
+
 import streamlit as st
 import torch
 from transformers import (
-    AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
 )

 device = torch.cuda.device_count() - 1

-TRANSLATION_NL_TO_EN = "translation_en_to_nl"
-

-@st.cache(suppress_st_warning=True, allow_output_mutation=True)
-def load_model(model_name, task):
-    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+def get_access_token():
     try:
         if not os.path.exists(".streamlit/secrets.toml"):
             raise FileNotFoundError
         access_token = st.secrets.get("babel")
     except FileNotFoundError:
         access_token = os.environ.get("HF_ACCESS_TOKEN", None)
+    return access_token
+
+
+@st.cache(suppress_st_warning=True, allow_output_mutation=True)
+def load_model(model_name):
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
     tokenizer = AutoTokenizer.from_pretrained(
-        model_name, from_flax=True, use_auth_token=access_token
+        model_name, from_flax=True, use_auth_token=get_access_token()
     )
     if tokenizer.pad_token is None:
         print("Adding pad_token to the tokenizer")
         tokenizer.pad_token = tokenizer.eos_token
-    auto_model_class = (
-        AutoModelForSeq2SeqLM if "translation" in task else AutoModelForCausalLM
-    )
-    model = auto_model_class.from_pretrained(
-        model_name, from_flax=True, use_auth_token=access_token
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        model_name, from_flax=True, use_auth_token=get_access_token()
     )
     if device != -1:
         model.to(f"cuda:{device}")
@@ -39,35 +39,61 @@ def load_model(model_name, task):


 class Generator:
-    def __init__(self, model_name, task, desc):
+    def __init__(self, model_name, task, desc, split_sentences):
         self.model_name = model_name
         self.task = task
         self.desc = desc
         self.tokenizer = None
         self.model = None
         self.prefix = ""
+        self.gen_kwargs = {
+            "max_length": 128,
+            "num_beams": 6,
+            "num_beam_groups": 3,
+            "no_repeat_ngram_size": 0,
+            "early_stopping": True,
+            "num_return_sequences": 1,
+            "length_penalty": 1.0,
+        }
         self.load()

     def load(self):
         if not self.model:
             print(f"Loading model {self.model_name}")
-            self.tokenizer, self.model = load_model(self.model_name, self.task)
-
+            self.tokenizer, self.model = load_model(self.model_name)
+
+            for key in self.gen_kwargs:
+                if key in self.model.config.__dict__:
+                    self.gen_kwargs[key] = self.model.config.__dict__[key]
+                    print(
+                        "Setting",
+                        key,
+                        "to",
+                        self.gen_kwargs[key],
+                        "for model",
+                        self.model_name,
+                    )
             try:
                 if self.task in self.model.config.task_specific_params:
                     task_specific_params = self.model.config.task_specific_params[
                         self.task
                     ]
-                    if "prefix" in task_specific_params:
-                        self.prefix = task_specific_params["prefix"]
+                    self.prefix = (
+                        task_specific_params["prefix"]
+                        if "prefix" in task_specific_params
+                        else ""
+                    )
+                    for key in self.gen_kwargs:
+                        if key in task_specific_params:
+                            self.gen_kwargs[key] = task_specific_params[key]
             except TypeError:
                 pass

-    def generate(self, text: str, **generate_kwargs) -> str:
-        #
-        # import pydevd_pycharm
-        # pydevd_pycharm.settrace('10.1.0.144', port=12345, stdoutToServer=True, stderrToServer=True)
-        #
+    def generate(self, text: str, **generate_kwargs) -> (str, dict):
+        # Replace two or more newlines with a single newline in text
+        text = re.sub(r"\n{2,}", "\n", text)
+
+        generate_kwargs = {**self.gen_kwargs, **generate_kwargs}
         batch_encoded = self.tokenizer(
             self.prefix + text,
             max_length=generate_kwargs["max_length"],
@@ -89,9 +115,7 @@ class Generator:
             pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
             for pred in decoded_preds
         ]
-        return decoded_preds
-
-        # return self.pipeline(text, **generate_kwargs)
+        return decoded_preds, generate_kwargs

     def __str__(self):
         return self.desc
@@ -104,10 +128,10 @@ class GeneratorFactory:
             with st.spinner(text=f"Loading the model {g['desc']} ..."):
                 self.add_generator(**g)

-    def add_generator(self, model_name, task, desc):
+    def add_generator(self, model_name, task, desc, split_sentences):
         # If the generator is not yet present, add it
         if not self.get_generator(model_name=model_name, task=task, desc=desc):
-            g = Generator(model_name, task, desc)
+            g = Generator(model_name, task, desc, split_sentences)
             g.load()
             self.generators.append(g)

@@ -120,5 +144,5 @@
     def __iter__(self):
         return iter(self.generators)

-    def gpt_descs(self):
-        return [g.desc for g in self.generators if g.task == TRANSLATION_NL_TO_EN]
+    def get_descriptions(self, task=None):
+        return [g.desc for g in self.generators if task is None or task == g.task]
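
To see which of these defaults a given checkpoint actually overrides, the same lookup that Generator.load() now performs can be done with AutoConfig alone, without downloading weights. A hedged sketch, assuming the checkpoint below is publicly readable and keeps its generation settings and task_specific_params in config.json:

# Sketch: inspect a checkpoint's generation defaults the way Generator.load() does,
# but via AutoConfig so no model weights are fetched. Checkpoint and task key are
# taken from GENERATOR_LIST and the translation task used by the Space; the repo's
# availability and the presence of task_specific_params are assumptions.
from transformers import AutoConfig

GEN_KEYS = [
    "max_length", "num_beams", "num_beam_groups", "no_repeat_ngram_size",
    "early_stopping", "num_return_sequences", "length_penalty",
]

config = AutoConfig.from_pretrained("yhavinga/t5-small-24L-ccmatrix-multi")
for key in GEN_KEYS:
    if key in config.__dict__:
        print(f"{key} = {config.__dict__[key]}")

task_params = (config.task_specific_params or {}).get("translation_en_to_nl", {})
print("prefix:", task_params.get("prefix", ""))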