Spaces:
Running on Zero
Update app.py
Browse files
app.py
CHANGED
@spaces.GPU
def stream_chat(message: Dict[str, str], history: list):
    """Stream a text-only chat completion for *message* given prior *history*.

    (Pre-image-support version, reconstructed from the deleted side of the diff.)

    Parameters:
        message: the incoming user turn (appended to the conversation as-is).
        history: list of (prompt, answer) pairs from previous turns.

    Side effects: starts a background generation thread feeding ``streamer``;
    the caller is expected to iterate the streamer (continuation not visible
    in this chunk).
    """
    # Rebuild the chat-template conversation from the UI history.
    conversation = []
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])

    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(
        model.device
    )
    # skip_prompt=True so only newly generated tokens are yielded to the UI.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        # NOTE(review): the original value after "max_new_tokens=" was lost in
        # the paste (truncated line); 1024 is a placeholder — confirm against
        # the repository history.
        max_new_tokens=1024,
        # NOTE(review): `temperature` is not defined in this chunk — it was
        # presumably a function parameter or module-level setting removed by
        # the garbling; verify against the full file.
        temperature=temperature,
        do_sample=True,
    )
    # temperature == 0 means greedy decoding; HF generate requires
    # do_sample=False in that case (temperature=0 would divide by zero).
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    # Run generation off the request thread so tokens can be streamed.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
|
|
@spaces.GPU
def stream_chat(message: Dict[str, str], history: list):
    """Stream a vision-chat completion for *message* given prior *history*.

    (Post-change, image-capable version, reconstructed from the added side of
    the diff.)

    Parameters:
        message: gradio multimodal payload, e.g.
            {'text': 'what is this', 'files': ['image-xxx.jpg']}
            — assumes at least one entry in message["files"]; TODO confirm
            the caller guarantees this.
        history: list of (prompt, answer) pairs from previous turns.

    Side effects: starts a background generation thread feeding ``streamer``;
    the caller is expected to iterate the streamer (continuation not visible
    in this chunk).
    """
    # Preprocess the attached image into model-ready pixel values.
    image = Image.open(message["files"][0])
    pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

    # Rebuild the chat-template conversation from the UI history.
    conversation = []
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])

    conversation.append({"role": "user", "content": message["text"]})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    # Prepend one <image> placeholder token per image patch so the model can
    # align pixel_values with the prompt (length comes from the processor).
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_prefix = torch.empty((1, getattr(processor, "image_seq_length")), dtype=input_ids.dtype).fill_(image_token_id)
    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

    # skip_prompt=True so only newly generated tokens are yielded to the UI.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        # BUG FIX: the committed diff read "pixel_values=pixel_values" with no
        # trailing comma before "streamer=streamer," — a SyntaxError. Comma
        # restored.
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
    )

    # Run generation off the request thread so tokens can be streamed.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
|