paralym committed on
Commit
ff17e6b
1 Parent(s): 738f600

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -53
app.py CHANGED
@@ -232,17 +232,17 @@ def bot(history):
232
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
233
  prompt = our_chatbot.conversation.get_prompt()
234
 
235
- # input_ids = (
236
- # tokenizer_image_token(
237
- # prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
238
- # )
239
- # .unsqueeze(0)
240
- # .to(our_chatbot.model.device)
241
- # )
242
- input_ids = tokenizer_image_token(
243
  prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
244
- ).unsqueeze(0).to(our_chatbot.model.device)
245
- print("### input_id",input_ids)
 
 
 
 
 
 
246
  stop_str = (
247
  our_chatbot.conversation.sep
248
  if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
@@ -252,58 +252,58 @@ def bot(history):
252
  stopping_criteria = KeywordsStoppingCriteria(
253
  keywords, our_chatbot.tokenizer, input_ids
254
  )
255
- # streamer = TextStreamer(
256
- # our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
257
- # )
258
- streamer = TextIteratorStreamer(
259
  our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
260
  )
 
 
 
261
  print(our_chatbot.model.device)
262
  print(input_ids.device)
263
  print(image_tensor.device)
264
  # import pdb;pdb.set_trace()
265
- # with torch.inference_mode():
266
- # output_ids = our_chatbot.model.generate(
267
- # input_ids,
268
- # images=image_tensor,
269
- # do_sample=True,
270
- # temperature=0.2,
271
- # max_new_tokens=1024,
272
- # streamer=streamer,
273
- # use_cache=False,
274
- # stopping_criteria=[stopping_criteria],
275
- # )
276
-
277
- # outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
278
- # if outputs.endswith(stop_str):
279
- # outputs = outputs[: -len(stop_str)]
280
- # our_chatbot.conversation.messages[-1][-1] = outputs
281
-
282
- # history[-1] = [text, outputs]
283
-
284
- # return history
285
- generate_kwargs = dict(
286
- inputs=input_ids,
287
- streamer=streamer,
288
- images=image_tensor,
289
- max_new_tokens=1024,
290
- do_sample=True,
291
- temperature=0.2,
292
- num_beams=1,
293
- use_cache=False,
294
- stopping_criteria=[stopping_criteria],
295
- )
296
 
297
- t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
298
- t.start()
299
 
300
- outputs = []
301
- for text in streamer:
302
- outputs.append(text)
303
- yield "".join(outputs)
304
 
305
- our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
306
- history[-1] = [text, "".join(outputs)]
307
 
308
 
309
  txt = gr.Textbox(
 
232
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
233
  prompt = our_chatbot.conversation.get_prompt()
234
 
235
+ input_ids = (
236
+ tokenizer_image_token(
 
 
 
 
 
 
237
  prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
238
+ )
239
+ .unsqueeze(0)
240
+ .to(our_chatbot.model.device)
241
+ )
242
+ # input_ids = tokenizer_image_token(
243
+ # prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
244
+ # ).unsqueeze(0).to(our_chatbot.model.device)
245
+ # print("### input_id",input_ids)
246
  stop_str = (
247
  our_chatbot.conversation.sep
248
  if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
 
252
  stopping_criteria = KeywordsStoppingCriteria(
253
  keywords, our_chatbot.tokenizer, input_ids
254
  )
255
+ streamer = TextStreamer(
 
 
 
256
  our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
257
  )
258
+ # streamer = TextIteratorStreamer(
259
+ # our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
260
+ # )
261
  print(our_chatbot.model.device)
262
  print(input_ids.device)
263
  print(image_tensor.device)
264
  # import pdb;pdb.set_trace()
265
+ with torch.inference_mode():
266
+ output_ids = our_chatbot.model.generate(
267
+ input_ids,
268
+ images=image_tensor,
269
+ do_sample=True,
270
+ temperature=0.2,
271
+ max_new_tokens=1024,
272
+ streamer=streamer,
273
+ use_cache=False,
274
+ stopping_criteria=[stopping_criteria],
275
+ )
276
+
277
+ outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
278
+ if outputs.endswith(stop_str):
279
+ outputs = outputs[: -len(stop_str)]
280
+ our_chatbot.conversation.messages[-1][-1] = outputs
281
+
282
+ history[-1] = [text, outputs]
283
+
284
+ return history
285
+ # generate_kwargs = dict(
286
+ # inputs=input_ids,
287
+ # streamer=streamer,
288
+ # images=image_tensor,
289
+ # max_new_tokens=1024,
290
+ # do_sample=True,
291
+ # temperature=0.2,
292
+ # num_beams=1,
293
+ # use_cache=False,
294
+ # stopping_criteria=[stopping_criteria],
295
+ # )
296
 
297
+ # t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
298
+ # t.start()
299
 
300
+ # outputs = []
301
+ # for text in streamer:
302
+ # outputs.append(text)
303
+ # yield "".join(outputs)
304
 
305
+ # our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
306
+ # history[-1] = [text, "".join(outputs)]
307
 
308
 
309
  txt = gr.Textbox(