# Hugging Face Space (status: Paused)
from __future__ import annotations | |
import os | |
import string | |
import gradio as gr | |
import PIL.Image | |
import torch | |
from transformers import BitsAndBytesConfig, pipeline | |
import re | |
DESCRIPTION = "# LLaVA 🌋"
model_id = "llava-hf/llava-1.5-7b-hf"

# bitsandbytes settings: weights stored in 4-bit, computation done in float16.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    load_in_4bit=True,
)

# The pipeline is created once at import time and shared by every request.
_pipe_model_kwargs = {"quantization_config": quantization_config}
pipe = pipeline(
    "image-to-text",
    model=model_id,
    model_kwargs=_pipe_model_kwargs,
)
def extract_response_pairs(text):
    """Split a flat ``USER: ... ASSISTANT: ...`` transcript into chat pairs.

    Each pair starts at a ``USER:`` marker. The user part runs up to the next
    ``ASSISTANT:`` marker, and the assistant part runs up to the next ``USER:``
    marker or the end of the text. Both parts are whitespace-stripped; the user
    part keeps its ``USER:`` prefix.

    Args:
        text: Transcript string, e.g. several pipeline outputs joined by spaces.

    Returns:
        List of ``(user, assistant)`` string tuples in transcript order; empty
        when no ``USER: ... ASSISTANT:`` pair is present.
    """
    # The terminator is a lookahead so the following turn's "USER:" marker is
    # left in place for the next match; consuming it made findall skip every
    # second turn.
    pattern = re.compile(r"(USER:.*?)ASSISTANT:(.*?)(?=USER:|$)", re.DOTALL)
    return [
        (user.strip(), assistant.strip())
        for user, assistant in pattern.findall(text)
    ]
def postprocess_output(output: str) -> str:
    """Return *output* ending in punctuation, appending a period if needed.

    Empty (falsy) input is returned unchanged. Only the very last character is
    inspected, so trailing whitespace counts as "not punctuation".
    """
    if not output:
        return output
    final_char = output[-1]
    return output if final_char in string.punctuation else f"{output}."
def chat(image, text, temperature, length_penalty,
         repetition_penalty, max_length, min_length, num_beams, top_p,
         history_chat):
    """Run one LLaVA turn on *image* and *text* and return the updated chat.

    Args:
        image: Image for this turn (the UI's ``gr.Image`` is declared with
            ``type="pil"``).
        text: The user's message for this turn.
        temperature, length_penalty, repetition_penalty, max_length,
        min_length, num_beams, top_p: Forwarded to generation through the
            pipeline's ``generate_kwargs``.
        history_chat: List of post-processed pipeline outputs from earlier
            turns (the ``gr.State`` value). It is not modified in place.

    Returns:
        ``(pairs, history)``: ``pairs`` is the list of ``(user, assistant)``
        tuples parsed from the whole history (for ``gr.Chatbot``), and
        ``history`` is a new list with this turn's output appended.
    """
    # NOTE: only the current turn is sent to the model; earlier entries of
    # history_chat are used for display only.
    prompt = f"USER: <image>\n{text}\nASSISTANT:"
    # NOTE(review): transformers' generate() ignores temperature/top_p unless
    # do_sample=True is also passed — confirm whether sampling was intended.
    generate_kwargs = {
        "temperature": temperature,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        # Slider values may arrive as floats; lengths and beam counts must be ints.
        "max_length": int(max_length),
        "min_length": int(min_length),
        "num_beams": int(num_beams),
        "top_p": top_p,
    }
    outputs = pipe(image, prompt=prompt, generate_kwargs=generate_kwargs)
    output = postprocess_output(outputs[0]["generated_text"])
    # Build a fresh list instead of mutating the caller's state object.
    history_chat = [*(history_chat or []), output]
    chat_val = extract_response_pairs(" ".join(history_chat))
    return chat_val, history_chat
# CSS rules for an element with elem_id="mkd": fixed height, scrollable, bordered.
# NOTE(review): this string is not passed to gr.Blocks (which receives
# css="style.css") — confirm whether it is still needed.
css = "\n".join(
    [
        "",
        "#mkd {",
        "height: 500px;",
        "overflow: auto;",
        "border: 1px solid #ccc;",
        "}",
        "",
    ]
)
# NOTE(review): the module-level `css` string (#mkd rules) is not used here and
# no component sets elem_id="mkd"; confirm whether css=css was intended.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("**LLaVA, one of the greatest multimodal chat models is now available in transformers with 4-bit quantization! ⚡️ **")
    gr.Markdown("**Try it in this demo 🤗 **")
    chatbot = gr.Chatbot(label="Chat", show_label=False)
    gr.Markdown("Input image and text and start chatting 👇")
    with gr.Row():
        image = gr.Image(type="pil")
        text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
    # Per-session list of raw model outputs; chat() returns an updated copy.
    history_chat = gr.State(value=[])
    with gr.Row():
        clear_chat_button = gr.Button("Clear")
        chat_button = gr.Button("Submit", variant="primary")
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(
            label="Temperature",
            info="Used with nucleus sampling.",
            minimum=0.5,
            maximum=1.0,
            step=0.1,
            value=1.0,
        )
        length_penalty = gr.Slider(
            label="Length Penalty",
            info="Set to larger for longer sequence, used with beam search.",
            minimum=-1.0,
            maximum=2.0,
            step=0.2,
            value=1.0,
        )
        repetition_penalty = gr.Slider(
            label="Repetition Penalty",
            info="Larger value prevents repetition.",
            minimum=1.0,
            maximum=5.0,
            step=0.5,
            value=1.5,
        )
        max_length = gr.Slider(
            label="Max Length",
            minimum=1,
            maximum=512,
            step=1,
            value=50,
        )
        min_length = gr.Slider(
            label="Minimum Length",
            minimum=1,
            maximum=100,
            step=1,
            value=1,
        )
        # chat() takes num_beams positionally; without this control every call
        # received one argument too few.
        num_beams = gr.Slider(
            label="Number of Beams",
            info="Beam search width; 1 disables beam search.",
            minimum=1,
            maximum=5,
            step=1,
            value=1,
        )
        top_p = gr.Slider(
            label="Top P",
            info="Used with nucleus sampling.",
            minimum=0.5,
            maximum=1.0,
            step=0.1,
            value=0.9,
        )

    # Order must match chat()'s positional parameters exactly.
    chat_inputs = [
        image,
        text_input,
        temperature,
        length_penalty,
        repetition_penalty,
        max_length,
        min_length,
        num_beams,
        top_p,
        history_chat,
    ]
    chat_output = [
        chatbot,
        history_chat,
    ]

    chat_button.click(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
        api_name="Chat",
    )
    text_input.submit(
        fn=chat,
        inputs=chat_inputs,
        outputs=chat_output,
    ).success(
        # After a successful turn, clear only the textbox (one output for one
        # returned value).
        fn=lambda: "",
        outputs=text_input,
        queue=False,
        api_name=False,
    )
    clear_chat_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[
            chatbot,
            history_chat,
        ],
        queue=False,
        api_name="clear",
    )
    # A new image starts a new conversation: reset display and history.
    image.change(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[
            chatbot,
            history_chat,
        ],
        queue=False,
    )
if __name__ == "__main__":
    # Cap the pending-event queue at 10, then start the web server.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch()