hiyouga committed
Commit 659f477
1 Parent(s): 5f6343c

Update app.py

Files changed (1):
  app.py +14 -10
app.py CHANGED
@@ -30,27 +30,31 @@ model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype="auto", dev
 
 @spaces.GPU
 def stream_chat(message: Dict[str, str], history: list):
-    print(message)
+    # message comes in as a dict, e.g. {'text': 'what is this', 'files': ['image-xxx.jpg']}
+    image = Image.open(message["files"][0])
+    pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]
+
     conversation = []
     for prompt, answer in history:
         conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
 
-    conversation.append({"role": "user", "content": message})
+    conversation.append({"role": "user", "content": message["text"]})
 
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(
-        model.device
-    )
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
+    image_prefix = torch.empty((1, getattr(processor, "image_seq_length")), dtype=input_ids.dtype).fill_(image_token_id)
+    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)
+
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
         input_ids=input_ids,
+        pixel_values=pixel_values,
         streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        do_sample=True,
+        max_new_tokens=256,
     )
-    if temperature == 0:
-        generate_kwargs["do_sample"] = False
 
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
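
For context on the change above: the new code prepends a block of "<image>" placeholder tokens to the tokenized prompt so that, during generation, the model can substitute the projected vision features (passed separately as pixel_values) into those positions. Below is a minimal, self-contained sketch of that prefix construction; the token id, placeholder count, and prompt ids are made-up values for illustration, not the real model configuration.

```python
# Minimal sketch of the image-token prefix construction from the diff above.
# All concrete values are illustrative assumptions, not the real config.
import torch

IMAGE_TOKEN_ID = 32000   # stand-in for tokenizer.convert_tokens_to_ids("<image>")
IMAGE_SEQ_LENGTH = 256   # stand-in for processor.image_seq_length (placeholder slots per image)

# Stand-in for the ids produced by tokenizer.apply_chat_template(..., return_tensors="pt").
text_ids = torch.tensor([[101, 2054, 2003, 2023, 102]])  # shape (1, text_len)

# Fill a (1, IMAGE_SEQ_LENGTH) block with the placeholder id and prepend it to the
# text ids, so the image embeddings have reserved positions ahead of the prompt.
image_prefix = torch.full((1, IMAGE_SEQ_LENGTH), IMAGE_TOKEN_ID, dtype=text_ids.dtype)
input_ids = torch.cat((image_prefix, text_ids), dim=-1)

print(input_ids.shape)  # torch.Size([1, 261]) -> 256 image slots + 5 text tokens
```

The {'text': ..., 'files': [...]} shape noted in the new comment matches what a Gradio multimodal chat input passes to its handler, which is why the function now reads message["files"][0] and message["text"] separately instead of treating message as a plain string.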
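
The streaming path itself is unchanged: model.generate runs in a background thread and TextIteratorStreamer yields decoded text as it is produced. The diff does not show the consumption side, but a typical pattern looks like the sketch below, which assumes the module-level tokenizer and model from app.py are in scope and is not the Space's actual code.

```python
# Sketch of how a TextIteratorStreamer is commonly consumed in a Gradio handler.
# `tokenizer` and `model` are assumed to be the module-level objects from app.py.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(input_ids, pixel_values):
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
    )
    # Run generation in the background so the streamer can be iterated here.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    # Accumulate and yield partial text so the chat UI updates as tokens arrive.
    output = ""
    for new_text in streamer:
        output += new_text
        yield output
```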