Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -232,17 +232,17 @@ def bot(history):
|
|
232 |
our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
|
233 |
prompt = our_chatbot.conversation.get_prompt()
|
234 |
|
235 |
-
|
236 |
-
|
237 |
-
# prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
238 |
-
# )
|
239 |
-
# .unsqueeze(0)
|
240 |
-
# .to(our_chatbot.model.device)
|
241 |
-
# )
|
242 |
-
input_ids = tokenizer_image_token(
|
243 |
prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
244 |
-
)
|
245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
stop_str = (
|
247 |
our_chatbot.conversation.sep
|
248 |
if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
|
@@ -252,58 +252,58 @@ def bot(history):
|
|
252 |
stopping_criteria = KeywordsStoppingCriteria(
|
253 |
keywords, our_chatbot.tokenizer, input_ids
|
254 |
)
|
255 |
-
|
256 |
-
# our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
|
257 |
-
# )
|
258 |
-
streamer = TextIteratorStreamer(
|
259 |
our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
|
260 |
)
|
|
|
|
|
|
|
261 |
print(our_chatbot.model.device)
|
262 |
print(input_ids.device)
|
263 |
print(image_tensor.device)
|
264 |
# import pdb;pdb.set_trace()
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
generate_kwargs = dict(
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
)
|
296 |
|
297 |
-
t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
|
298 |
-
t.start()
|
299 |
|
300 |
-
outputs = []
|
301 |
-
for text in streamer:
|
302 |
-
|
303 |
-
|
304 |
|
305 |
-
our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
|
306 |
-
history[-1] = [text, "".join(outputs)]
|
307 |
|
308 |
|
309 |
txt = gr.Textbox(
|
|
|
232 |
our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
|
233 |
prompt = our_chatbot.conversation.get_prompt()
|
234 |
|
235 |
+
input_ids = (
|
236 |
+
tokenizer_image_token(
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
238 |
+
)
|
239 |
+
.unsqueeze(0)
|
240 |
+
.to(our_chatbot.model.device)
|
241 |
+
)
|
242 |
+
# input_ids = tokenizer_image_token(
|
243 |
+
# prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
244 |
+
# ).unsqueeze(0).to(our_chatbot.model.device)
|
245 |
+
# print("### input_id",input_ids)
|
246 |
stop_str = (
|
247 |
our_chatbot.conversation.sep
|
248 |
if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
|
|
|
252 |
stopping_criteria = KeywordsStoppingCriteria(
|
253 |
keywords, our_chatbot.tokenizer, input_ids
|
254 |
)
|
255 |
+
streamer = TextStreamer(
|
|
|
|
|
|
|
256 |
our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
|
257 |
)
|
258 |
+
# streamer = TextIteratorStreamer(
|
259 |
+
# our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
|
260 |
+
# )
|
261 |
print(our_chatbot.model.device)
|
262 |
print(input_ids.device)
|
263 |
print(image_tensor.device)
|
264 |
# import pdb;pdb.set_trace()
|
265 |
+
with torch.inference_mode():
|
266 |
+
output_ids = our_chatbot.model.generate(
|
267 |
+
input_ids,
|
268 |
+
images=image_tensor,
|
269 |
+
do_sample=True,
|
270 |
+
temperature=0.2,
|
271 |
+
max_new_tokens=1024,
|
272 |
+
streamer=streamer,
|
273 |
+
use_cache=False,
|
274 |
+
stopping_criteria=[stopping_criteria],
|
275 |
+
)
|
276 |
+
|
277 |
+
outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
|
278 |
+
if outputs.endswith(stop_str):
|
279 |
+
outputs = outputs[: -len(stop_str)]
|
280 |
+
our_chatbot.conversation.messages[-1][-1] = outputs
|
281 |
+
|
282 |
+
history[-1] = [text, outputs]
|
283 |
+
|
284 |
+
return history
|
285 |
+
# generate_kwargs = dict(
|
286 |
+
# inputs=input_ids,
|
287 |
+
# streamer=streamer,
|
288 |
+
# images=image_tensor,
|
289 |
+
# max_new_tokens=1024,
|
290 |
+
# do_sample=True,
|
291 |
+
# temperature=0.2,
|
292 |
+
# num_beams=1,
|
293 |
+
# use_cache=False,
|
294 |
+
# stopping_criteria=[stopping_criteria],
|
295 |
+
# )
|
296 |
|
297 |
+
# t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
|
298 |
+
# t.start()
|
299 |
|
300 |
+
# outputs = []
|
301 |
+
# for text in streamer:
|
302 |
+
# outputs.append(text)
|
303 |
+
# yield "".join(outputs)
|
304 |
|
305 |
+
# our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
|
306 |
+
# history[-1] = [text, "".join(outputs)]
|
307 |
|
308 |
|
309 |
txt = gr.Textbox(
|