merve (HF staff) committed
Commit e4c787e
1 Parent(s): 9986eb7

Revert chatbot

Files changed (1):
  1. app.py +128 -99

app.py CHANGED
@@ -1,12 +1,11 @@
  import gradio as gr
- from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
- from threading import Thread
+ from transformers import AutoProcessor, AutoModelForVision2Seq
  import re
  import time
  from PIL import Image
  import torch
  import spaces
- #import subprocess
+ import subprocess
  #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
@@ -16,39 +15,38 @@ model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM-Instruct",
      #_attn_implementation="flash_attention_2"
  ).to("cuda")

- #@spaces.GPU
+ @spaces.GPU
  def model_inference(
-     input_dict, history, decoding_strategy, temperature, max_new_tokens,
+     images, text, assistant_prefix, decoding_strategy, temperature, max_new_tokens,
      repetition_penalty, top_p
- ):
-     text = input_dict["text"]
-     print(input_dict["files"])
-     if len(input_dict["files"]) > 1:
-         images = [Image.open(image["path"]).convert("RGB") for image in input_dict["files"]]
-     elif len(input_dict["files"]) == 1:
-         images = [Image.open(input_dict["files"][0]["path"]).convert("RGB")]
-
-
+ ):
      if text == "" and not images:
          gr.Error("Please input a query and optionally image(s).")

      if text == "" and images:
          gr.Error("Please input a text query along the image(s).")

-
+     if isinstance(images, Image.Image):
+         images = [images]


      resulting_messages = [
          {
              "role": "user",
-             "content": [{"type": "image"} for _ in range(len(images))] + [
+             "content": [{"type": "image"}] + [
                  {"type": "text", "text": text}
              ]
          }
      ]
+
+     if assistant_prefix:
+         text = f"{assistant_prefix} {text}"
+
+
      prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
      inputs = processor(text=prompt, images=[images], return_tensors="pt")
      inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
      generation_args = {
          "max_new_tokens": max_new_tokens,
          "repetition_penalty": repetition_penalty,
@@ -67,88 +65,119 @@ def model_inference(
          generation_args["top_p"] = top_p

      generation_args.update(inputs)
+
      # Generate
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens= True)
-     generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
-     generated_text = ""
-
-     thread = Thread(target=model.generate, kwargs=generation_args)
-     thread.start()
-     thread.join()
-
-     buffer = ""
-     for new_text in streamer:
-         try:
-             print("Streamed text:", new_text)
-             buffer += new_text
-         except Exception as e:
-             print("Error while streaming text:", e)
-
-     for new_text in streamer:
-
-         buffer += new_text
-         generated_text_without_prompt = buffer#[len(ext_buffer):]
-         time.sleep(0.01)
-         yield buffer
-
-
- examples=[
-     [{"text": "What art era do these artpieces belong to?", "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
-     [{"text": "I'm planning a visit to this temple, give me travel tips.", "files": ["example_images/examples_wat_arun.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
-     [{"text": "What is the due date and the invoice date?", "files": ["example_images/examples_invoice.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
-     [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
-     [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
-     [{"text": "What are?", "files": ["example_images/examples_weather_events.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
- ]
- demo = gr.ChatInterface(fn=model_inference, title="SmolVLM: Small yet Mighty 💫",
-     description="Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This checkpoint works best with single turn conversations, so clear the conversation after a single turn.",
-     examples=examples,
-     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
-     additional_inputs=[gr.Radio(["Top P Sampling",
-             "Greedy"],
-             value="Greedy",
-             label="Decoding strategy",
-             #interactive=True,
-             info="Higher values is equivalent to sampling more low-probability tokens.",
-
-         ), gr.Slider(
-             minimum=0.0,
-             maximum=5.0,
-             value=0.4,
-             step=0.1,
-             interactive=True,
-             label="Sampling temperature",
-             info="Higher values will produce more diverse outputs.",
-         ),
-         gr.Slider(
-             minimum=8,
-             maximum=1024,
-             value=512,
-             step=1,
-             interactive=True,
-             label="Maximum number of new tokens to generate",
-         ), gr.Slider(
-             minimum=0.01,
-             maximum=5.0,
-             value=1.2,
-             step=0.01,
-             interactive=True,
-             label="Repetition penalty",
-             info="1.0 is equivalent to no penalty",
-         ),
-         gr.Slider(
-             minimum=0.01,
-             maximum=0.99,
-             value=0.8,
-             step=0.01,
-             interactive=True,
-             label="Top P",
-             info="Higher values is equivalent to sampling more low-probability tokens.",
-         )],cache_examples=False
- )
-
-
-
-
- demo.launch(debug=True)
+     generated_ids = model.generate(**generation_args)
+
+     generated_texts = processor.batch_decode(generated_ids[:, generation_args["input_ids"].size(1):], skip_special_tokens=True)
+     return generated_texts[0]
+
+
+ with gr.Blocks(fill_height=False) as demo:
+     gr.Markdown("## SmolVLM: Small yet Mighty 💫")
+     gr.Markdown("Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text or try one of the examples.")
+     with gr.Column():
+         with gr.Row():
+             image_input = gr.Image(label="Upload your Image", type="pil")
+
+             with gr.Column():
+                 query_input = gr.Textbox(label="Prompt")
+                 assistant_prefix = gr.Textbox(label="Assistant Prefix", placeholder="Let's think step by step.")
+
+         submit_btn = gr.Button("Submit")
+         output = gr.Textbox(label="Output")
+

+         with gr.Accordion(label="Advanced Generation Parameters", open=False):
+             examples=[
+                 ["example_images/rococo.jpg", "What art era is this?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
+                 ["example_images/examples_wat_arun.jpg", "I'm planning a visit to this temple, give me travel tips.", "", "Greedy", 0.4, 512, 1.2, 0.8],
+                 ["example_images/examples_invoice.png", "What is the due date and the invoice date?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
+                 ["example_images/s2w_example.png", "What is this UI about?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
+                 ["example_images/examples_weather_events.png", "Where do the severe droughts happen according to this diagram?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
+             ]
+             # Hyper-parameters for generation
+             max_new_tokens = gr.Slider(
+                 minimum=8,
+                 maximum=1024,
+                 value=512,
+                 step=1,
+                 interactive=True,
+                 label="Maximum number of new tokens to generate",
+             )
+             repetition_penalty = gr.Slider(
+                 minimum=0.01,
+                 maximum=5.0,
+                 value=1.2,
+                 step=0.01,
+                 interactive=True,
+                 label="Repetition penalty",
+                 info="1.0 is equivalent to no penalty",
+             )
+             temperature = gr.Slider(
+                 minimum=0.0,
+                 maximum=5.0,
+                 value=0.4,
+                 step=0.1,
+                 interactive=True,
+                 label="Sampling temperature",
+                 info="Higher values will produce more diverse outputs.",
+             )
+             top_p = gr.Slider(
+                 minimum=0.01,
+                 maximum=0.99,
+                 value=0.8,
+                 step=0.01,
+                 interactive=True,
+                 label="Top P",
+                 info="Higher values is equivalent to sampling more low-probability tokens.",
+             )
+             decoding_strategy = gr.Radio(
+                 [
+                     "Top P Sampling",
+                     "Greedy",
+
+                 ],
+                 value="Top P Sampling",
+                 label="Decoding strategy",
+                 interactive=True,
+                 info="Higher values is equivalent to sampling more low-probability tokens.",
+             )
+             decoding_strategy.change(
+                 fn=lambda selection: gr.Slider(
+                     visible=(
+                         selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
+                     )
+                 ),
+                 inputs=decoding_strategy,
+                 outputs=temperature,
+             )
+
+             decoding_strategy.change(
+                 fn=lambda selection: gr.Slider(
+                     visible=(
+                         selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
+                     )
+                 ),
+                 inputs=decoding_strategy,
+                 outputs=repetition_penalty,
+             )
+             decoding_strategy.change(
+                 fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
+                 inputs=decoding_strategy,
+                 outputs=top_p,
+             )
+             gr.Examples(
+                 examples = examples,
+                 inputs=[image_input, query_input, assistant_prefix, decoding_strategy, temperature,
+                         max_new_tokens, repetition_penalty, top_p],
+                 outputs=output,
+                 fn=model_inference
+             )
+
+
+     submit_btn.click(model_inference, inputs = [image_input, query_input, assistant_prefix, decoding_strategy, temperature,
+                                                 max_new_tokens, repetition_penalty, top_p], outputs=output)
+
+
+ demo.launch(debug=True)
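
Note on the removed path: the chatbot version drove generation through transformers' TextIteratorStreamer on a background thread, but as committed it called thread.join() before iterating the streamer, so nothing was yielded incrementally. A minimal sketch of the usual streaming pattern is below; it reuses the processor, model, and prepared inputs from this file, while the helper name and max_new_tokens value are illustrative assumptions rather than part of the commit.

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(inputs, max_new_tokens=512):
    # Decode tokens as they are produced instead of waiting for generate() to return.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens))
    thread.start()
    buffer = ""
    for new_text in streamer:   # blocks until the next chunk arrives, stops when generation ends
        buffer += new_text
        yield buffer            # Gradio re-renders the partial answer on every yield
    thread.join()               # join only after the streamer is exhausted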
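After the revert, model_inference returns a single string instead of streaming, so it can also be exercised outside the Gradio UI. A minimal sketch, reusing an example image path and the default slider values from the code above; the direct call itself is illustrative and not part of the commit.

from PIL import Image

image = Image.open("example_images/examples_invoice.png").convert("RGB")
answer = model_inference(
    images=image,                 # a single PIL image is wrapped into a list inside the function
    text="What is the due date and the invoice date?",
    assistant_prefix="",          # optional prefix such as "Let's think step by step."
    decoding_strategy="Greedy",
    temperature=0.4,
    max_new_tokens=512,
    repetition_penalty=1.2,
    top_p=0.8,
)
print(answer)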