|
import gradio as gr |
|
from transformers import AutoProcessor, AutoModelForVision2Seq |
|
import re |
|
import time |
|
from PIL import Image |
|
import torch |
|
import spaces |
|
import subprocess |
|
|
|
|
|
|
|
# Single source of truth for the checkpoint so processor and model never diverge.
_CHECKPOINT = "HuggingFaceTB/SmolVLM-Instruct"

# Tokenizer + image preprocessor for the checkpoint.
processor = AutoProcessor.from_pretrained(_CHECKPOINT)

# Load the weights in bfloat16, then move the module onto the CUDA device.
model = AutoModelForVision2Seq.from_pretrained(_CHECKPOINT, torch_dtype=torch.bfloat16)
model = model.to("cuda")
|
|
|
def _build_user_message(images, text):
    """Return a one-turn chat: one image placeholder per image, then the text."""
    content = [{"type": "image"} for _ in images]
    content.append({"type": "text", "text": text})
    return [{"role": "user", "content": content}]


def _build_generation_args(decoding_strategy, temperature, max_new_tokens,
                           repetition_penalty, top_p):
    """Translate the UI decoding controls into keyword arguments for generate().

    Raises:
        ValueError: if ``decoding_strategy`` is not "Greedy" or "Top P Sampling".
    """
    generation_args = {
        # Sliders may deliver floats; generate() expects an integer token budget.
        "max_new_tokens": int(max_new_tokens),
        # Applied for every strategy, greedy included.
        "repetition_penalty": repetition_penalty,
    }
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["do_sample"] = True
        generation_args["temperature"] = temperature
        generation_args["top_p"] = top_p
    else:
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        raise ValueError(f"Unsupported decoding strategy: {decoding_strategy!r}")
    return generation_args


@spaces.GPU
def model_inference(
    images, text, assistant_prefix, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    """Answer a text query about zero or more images with SmolVLM.

    Args:
        images: A PIL image, an iterable of PIL images, or None.
        text: The user's query; must contain non-whitespace text.
        assistant_prefix: Optional text prepended to the query
            (e.g. "Let's think step by step."); ignored when empty.
        decoding_strategy: "Greedy" or "Top P Sampling".
        temperature: Sampling temperature; only used for "Top P Sampling".
        max_new_tokens: Upper bound on the number of generated tokens.
        repetition_penalty: Passed to ``generate`` for both strategies.
        top_p: Nucleus threshold; only used for "Top P Sampling".

    Returns:
        The decoded completion, with the prompt tokens stripped off.

    Raises:
        gr.Error: if the query is empty (shown to the user by Gradio).
        ValueError: if ``decoding_strategy`` is unknown.
    """
    # Normalize to a list so the rest of the function handles 0..N images uniformly.
    if images is None:
        images = []
    elif isinstance(images, Image.Image):
        images = [images]
    else:
        images = list(images)

    # gr.Error only surfaces in the UI when it is *raised*, not merely constructed.
    if not text or not text.strip():
        if images:
            raise gr.Error("Please input a text query along with the image(s).")
        raise gr.Error("Please input a query and optionally image(s).")

    # The prefix must be folded into the text *before* the chat messages are built,
    # otherwise it never reaches the prompt.
    if assistant_prefix:
        text = f"{assistant_prefix} {text}"

    messages = _build_user_message(images, text)
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

    # The processor takes one list of images per batch sample; pass None for
    # text-only queries instead of a list containing None.
    inputs = processor(
        text=prompt,
        images=[images] if images else None,
        return_tensors="pt",
    )
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    generation_args = _build_generation_args(
        decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p
    )
    generated_ids = model.generate(**inputs, **generation_args)

    # generate() returns prompt + completion; keep only the newly generated tokens.
    prompt_length = inputs["input_ids"].size(1)
    generated_texts = processor.batch_decode(
        generated_ids[:, prompt_length:], skip_special_tokens=True
    )
    return generated_texts[0]
|
|
|
|
|
# Gradio UI: image + prompt inputs, advanced decoding controls, examples, and
# a submit button wired to model_inference.
with gr.Blocks(fill_height=False) as demo:
    gr.Markdown("## SmolVLM: Small yet Mighty 💫")
    gr.Markdown("Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text or try one of the examples.")

    with gr.Column():
        with gr.Row():
            image_input = gr.Image(label="Upload your Image", type="pil")

            with gr.Column():
                query_input = gr.Textbox(label="Prompt")
                assistant_prefix = gr.Textbox(label="Assistant Prefix", placeholder="Let's think step by step.")
                submit_btn = gr.Button("Submit")

        output = gr.Textbox(label="Output")

        # Each row matches the `inputs` order of gr.Examples below:
        # image, prompt, prefix, strategy, temperature, max tokens, rep. penalty, top_p.
        examples = [
            ["example_images/rococo.jpg", "What art era is this?", "", "Greedy", 0.4, 512, 1.2, 0.8],
            ["example_images/examples_wat_arun.jpg", "I'm planning a visit to this temple, give me travel tips.", "", "Greedy", 0.4, 512, 1.2, 0.8],
            ["example_images/examples_invoice.png", "What is the due date and the invoice date?", "", "Greedy", 0.4, 512, 1.2, 0.8],
            ["example_images/s2w_example.png", "What is this UI about?", "", "Greedy", 0.4, 512, 1.2, 0.8],
            ["example_images/examples_weather_events.png", "Where do the severe droughts happen according to this diagram?", "", "Greedy", 0.4, 512, 1.2, 0.8],
        ]

        with gr.Accordion(label="Advanced Generation Parameters", open=False):
            max_new_tokens = gr.Slider(
                minimum=8,
                maximum=1024,
                value=512,
                step=1,
                interactive=True,
                label="Maximum number of new tokens to generate",
            )
            # Always visible: model_inference passes it to generate() for every strategy.
            repetition_penalty = gr.Slider(
                minimum=0.01,
                maximum=5.0,
                value=1.2,
                step=0.01,
                interactive=True,
                label="Repetition penalty",
                info="1.0 is equivalent to no penalty",
            )
            # Sampling-only controls start hidden because the default strategy is Greedy.
            temperature = gr.Slider(
                minimum=0.0,
                maximum=5.0,
                value=0.4,
                step=0.1,
                interactive=True,
                visible=False,
                label="Sampling temperature",
                info="Higher values will produce more diverse outputs.",
            )
            top_p = gr.Slider(
                minimum=0.01,
                maximum=0.99,
                value=0.8,
                step=0.01,
                interactive=True,
                visible=False,
                label="Top P",
                info="Higher values are equivalent to sampling more low-probability tokens.",
            )
            decoding_strategy = gr.Radio(
                [
                    "Top P Sampling",
                    "Greedy",
                ],
                value="Greedy",
                label="Decoding strategy",
                interactive=True,
                info="Greedy picks the most likely token at each step; Top P Sampling samples using the temperature and Top P settings.",
            )

            def _toggle_sampling_controls(selection):
                """Show temperature/top_p only when a sampling strategy is selected."""
                sampling = selection == "Top P Sampling"
                return gr.update(visible=sampling), gr.update(visible=sampling)

            decoding_strategy.change(
                fn=_toggle_sampling_controls,
                inputs=decoding_strategy,
                outputs=[temperature, top_p],
            )

        gr.Examples(
            examples=examples,
            inputs=[image_input, query_input, assistant_prefix, decoding_strategy, temperature,
                    max_new_tokens, repetition_penalty, top_p],
            outputs=output,
            fn=model_inference,
        )

    submit_btn.click(
        model_inference,
        inputs=[image_input, query_input, assistant_prefix, decoding_strategy, temperature,
                max_new_tokens, repetition_penalty, top_p],
        outputs=output,
    )

demo.launch(debug=True)