Iker committed on
Commit 58aa7b8
1 Parent(s): 8c7833a

Use GPU only for generate

Files changed (1)
  1. app_zero.py +43 -38
app_zero.py CHANGED
@@ -118,6 +118,48 @@ def finish_generation(text: str) -> str:
 
 
 @spaces.GPU
+def run_model(mode, title, text):
+    if mode == 0:
+        prompt = summarize_prompt(title, text)
+    elif mode == 50:
+        prompt = summarize_clickbait_short_prompt(title, text)
+    elif mode == 100:
+        prompt = summarize_clickbait_large_prompt(title, text)
+    else:
+        raise ValueError("Mode not supported")
+
+    formatted_prompt = tokenizer.apply_chat_template(
+        [{"role": "user", "content": prompt}],
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+
+    model_inputs = tokenizer(
+        [formatted_prompt], return_tensors="pt", add_special_tokens=False
+    )
+
+    streamer = TextStreamer(
+        tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True
+    )
+
+    model_output = model.generate(
+        **model_inputs.to(model.device),
+        streamer=streamer,
+        generation_config=generation_config,
+        logits_processor=stop_criteria,
+    )
+
+    yield title, streamer, text
+
+    temp = tokenizer.batch_decode(
+        model_output[:, model_inputs["input_ids"].shape[-1] :],
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=True,
+    )[0]
+
+    return temp
+
+
 def generate_text(
     url: str, mode: int, progress=gr.Progress(track_tqdm=False)
 ) -> (str, str):
@@ -181,44 +223,7 @@ def generate_text(
     progress(0.5, desc="🤖 Leyendo noticia")
 
     try:
-        if mode == 0:
-            prompt = summarize_prompt(title, text)
-        elif mode == 50:
-            prompt = summarize_clickbait_short_prompt(title, text)
-        elif mode == 100:
-            prompt = summarize_clickbait_large_prompt(title, text)
-        else:
-            raise ValueError("Mode not supported")
-
-        formatted_prompt = tokenizer.apply_chat_template(
-            [{"role": "user", "content": prompt}],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-
-        model_inputs = tokenizer(
-            [formatted_prompt], return_tensors="pt", add_special_tokens=False
-        )
-
-        streamer = TextStreamer(
-            tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True
-        )
-
-        model_output = model.generate(
-            **model_inputs.to(model.device),
-            streamer=streamer,
-            generation_config=generation_config,
-            logits_processor=stop_criteria,
-        )
-
-        yield title, streamer, text
-
-        temp = tokenizer.batch_decode(
-            model_output[:, model_inputs["input_ids"].shape[-1] :],
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=True,
-        )[0]
-
+        temp = run_model(mode, title, text)
         yield title, temp, text
 
     except Exception as e:
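
For context, a minimal sketch of the ZeroGPU pattern this commit applies: in a Hugging Face Space that uses the `spaces` package, a GPU is attached only to functions decorated with `@spaces.GPU`, and only for the duration of the call. Moving the model call out of `generate_text` into a decorated `run_model` keeps scraping and prompt building on CPU. The model id, prompt, and generation arguments below are placeholders, not the ones used by app_zero.py.

# Sketch only, assuming a ZeroGPU Space with the `spaces` package installed.
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "placeholder-org/placeholder-model"  # placeholder, not this Space's model

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)

@spaces.GPU  # a GPU is allocated only while this function executes
def run_model(prompt: str) -> str:
    # Inside the decorated call, CUDA is available on ZeroGPU hardware;
    # fall back to CPU so the sketch also runs locally without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    new_tokens = output[0, inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

def generate_text(url: str) -> str:
    # CPU-only work, no @spaces.GPU: fetch the article and build the prompt.
    title, text = "Titular", "Cuerpo de la noticia"  # placeholders; the Space scrapes these from the URL
    return run_model(f"Resume la noticia:\n{title}\n\n{text}")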