qnguyen3 commited on
Commit
fccbf81
·
verified ·
1 Parent(s): 69f9849

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -14
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, TextIteratorStreamer
3
  from threading import Thread
4
  import re
5
  import time
@@ -7,9 +7,15 @@ from PIL import Image
7
  import torch
8
  import spaces
9
 
10
- processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
 
 
11
 
12
- model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
 
 
 
 
13
  model.to("cuda:0")
14
 
15
  @spaces.GPU
@@ -18,27 +24,46 @@ def bot_streaming(message, history):
18
  if message["files"]:
19
  image = message["files"][-1]["path"]
20
  else:
21
- for hist in history:
22
  if type(hist[0])==tuple:
23
  image = hist[0][0]
 
24
 
25
- if len(history) > 0 and image:
26
- chat_history.append({"role": "user", "content": f'<image>\n{message['text']}'})
27
- for human, assistant in history[1:]:
 
28
  chat_history.append({"role": "user", "content": human })
29
  chat_history.append({"role": "assistant", "content": assistant })
30
-
31
- if image is None:
32
- gr.Error("You need to upload an image for LLaVA to work.")
 
 
 
 
 
 
 
 
 
 
33
  prompt=f"[INST] <image>\n{message['text']} [/INST]"
34
  image = Image.open(image).convert("RGB")
35
- inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
36
- streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True})
37
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=100)
 
 
 
 
 
 
 
38
  generated_text = ""
39
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
40
  thread.start()
41
- text_prompt =f"[INST] \n{message['text']} [/INST]"
42
 
43
  buffer = ""
44
  for new_text in streamer:
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from threading import Thread
4
  import re
5
  import time
 
7
  import torch
8
  import spaces
9
 
10
+ tokenizer = AutoTokenizer.from_pretrained(
11
+ 'qnguyen3/nanoLLaVA',
12
+ trust_remote_code=True)
13
 
14
+ model = AutoModelForCausalLM.from_pretrained(
15
+ 'qnguyen3/nanoLLaVA',
16
+ torch_dtype=torch.float16,
17
+ device_map='auto',
18
+ trust_remote_code=True)
19
  model.to("cuda:0")
20
 
21
  @spaces.GPU
 
24
  if message["files"]:
25
  image = message["files"][-1]["path"]
26
  else:
27
+ for i, hist in enumerate(history):
28
  if type(hist[0])==tuple:
29
  image = hist[0][0]
30
+ image_turn = i
31
 
32
+ if len(history) > 0 and image is not None:
33
+ chat_history.append({"role": "user", "content": f'<image>\n{history[1][0]}'})
34
+ chat_history.append({"role": "assistant", "content": history[1][1] })
35
+ for human, assistant in history[2:]:
36
  chat_history.append({"role": "user", "content": human })
37
  chat_history.append({"role": "assistant", "content": assistant })
38
+ chat_history.append({"role": "user", "content": message['text']})
39
+ elif len(history) > 0 and image is None:
40
+ for human, assistant in history:
41
+ chat_history.append({"role": "user", "content": human })
42
+ chat_history.append({"role": "assistant", "content": assistant })
43
+ chat_history.append({"role": "user", "content": message['text']})
44
+ elif len(history) == 0 and image is not None:
45
+ chat_history.append({"role": "user", "content": f'<image>\n{message['text']}'})
46
+ elif len(history) == 0 and image is None:
47
+ chat_history.append({"role": "user", "content": message['text'] })
48
+
49
+ # if image is None:
50
+ # gr.Error("You need to upload an image for LLaVA to work.")
51
  prompt=f"[INST] <image>\n{message['text']} [/INST]"
52
  image = Image.open(image).convert("RGB")
53
+ text = tokenizer.apply_chat_template(
54
+ messages,
55
+ tokenize=False,
56
+ add_generation_prompt=True)
57
+ text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
58
+ input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
59
+ streamer = TextIteratorStreamer(input_ids, **{"skip_special_tokens": True})
60
+ image = Image.open(image)
61
+ image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
62
+ generation_kwargs = dict(inputs, images=image_tensor, streamer=streamer, max_new_tokens=100)
63
  generated_text = ""
64
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
65
  thread.start()
66
+ text_prompt =f"<|im_start|>user\n{message['text']}<|im_end|>"
67
 
68
  buffer = ""
69
  for new_text in streamer: