# llavatesrun / app.py
# teganmosi — "Update app.py" (commit 41b85fd)
import gradio as gr
import textwrap
from io import BytesIO
import requests
import torch
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
KeywordsStoppingCriteria,
get_model_name_from_path,
process_images,
tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from PIL import Image
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
KeywordsStoppingCriteria,
get_model_name_from_path,
process_images,
tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from PIL import Image
import torch
# Turn off PyTorch's default initialisation before the model is built
# (presumably to speed up loading, since checkpoint weights are loaded next —
# see llava.utils.disable_torch_init).
disable_torch_init()

# Model repo id / path handed to the llava loader.
MODEL = "4bit/llava-v1.5-13b-3GB"
model_name = get_model_name_from_path(MODEL)

# The loader returns four objects; tokenizer, model and image_processor are
# used as module-level globals by the functions below.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL,
    model_name=model_name,
    model_base=None,
    load_4bit=True,
)
def create_prompt(prompt: str, conv_mode: str = "llava_v0"):
    """Build a single-turn LLaVA conversation prompt that includes one image.

    Args:
        prompt: The user instruction, e.g. ``"Describe the image"``.
        conv_mode: Key into ``conv_templates`` selecting the chat template.
            Defaults to ``"llava_v0"``, the value that used to be hard-coded.
            NOTE(review): LLaVA's own CLI selects ``"llava_v1"`` for model
            names containing "v1" (such as this v1.5 checkpoint) — confirm
            which template gives better output before changing the default.

    Returns:
        A ``(prompt_text, conv)`` tuple: the rendered prompt string and the
        conversation object, whose separator settings callers use to build a
        stopping criterion.
    """
    conv = conv_templates[conv_mode].copy()
    user_role = conv.roles[0]
    assistant_role = conv.roles[1]
    # The image placeholder goes on its own line before the text; it is mapped
    # to IMAGE_TOKEN_INDEX later by tokenizer_image_token.
    conv.append_message(user_role, DEFAULT_IMAGE_TOKEN + "\n" + prompt)
    # An empty assistant turn makes the rendered prompt end where the model
    # should start answering.
    conv.append_message(assistant_role, None)
    return conv.get_prompt(), conv
def process_image(image):
    """Turn a PIL image into the pixel tensor consumed by the LLaVA model.

    The image is converted to RGB, padded to a square
    (``image_aspect_ratio="pad"``) and run through the loaded image processor.
    The result is moved to the model's device as float16.

    Args:
        image: A ``PIL.Image.Image``.

    Returns:
        The tensor produced by ``process_images`` for a one-image batch, on
        ``model.device`` with dtype ``torch.float16``.
    """
    from types import SimpleNamespace

    # llava's process_images reads the setting with
    # getattr(model_cfg, "image_aspect_ratio", None). A plain dict has no such
    # attribute, so the previous {"image_aspect_ratio": "pad"} silently fell
    # back to "no padding". An attribute object makes "pad" take effect.
    model_cfg = SimpleNamespace(image_aspect_ratio="pad")
    # The pad path fills with an RGB colour derived from the processor's mean,
    # so normalise modes such as RGBA, L or P to RGB first.
    image_tensor = process_images([image.convert("RGB")], image_processor, model_cfg)
    return image_tensor.to(model.device, dtype=torch.float16)
def describe_image(image_file):
    """Generate a free-text description of an image with the loaded LLaVA model.

    Args:
        image_file: The image to describe. The Gradio input uses
            ``type="pil"``, so this is normally a ``PIL.Image.Image``. A file
            path or binary file object is also accepted and opened with
            ``Image.open``. ``None`` (e.g. the image was cleared in the live
            UI) yields an empty string.

    Returns:
        The generated text, decoded with special tokens removed and with
        surrounding whitespace stripped.
    """
    if image_file is None:
        return ""

    # gr.Image(type="pil") already delivers a PIL image; calling Image.open()
    # on it fails, so only open paths/file objects.
    if isinstance(image_file, Image.Image):
        image = image_file
    else:
        image = Image.open(image_file)
    # The former `image.resize((500, 500))` discarded its return value (PIL
    # returns a new image), so it never had an effect. It was removed, and
    # resizing is left to the image processor.
    processed_image = process_image(image)

    # Keep the conversation object: its separators define the stop string.
    prompt, conv = create_prompt("Describe the image")
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    # TWO-style templates end an assistant turn with sep2; others use sep.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        keywords=[stop_str], tokenizer=tokenizer, input_ids=input_ids
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=processed_image,
            do_sample=True,
            temperature=0.01,
            max_new_tokens=512,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    generated = output_ids[0]
    prompt_len = input_ids.shape[1]
    # Depending on the installed llava version, generate() either echoes the
    # prompt ids (including the undecodable IMAGE_TOKEN_INDEX placeholder)
    # before the new tokens, or returns only the new tokens. Strip the prompt
    # only when it is actually present, so the answer is never truncated.
    if generated.shape[0] >= prompt_len and torch.equal(
        generated[:prompt_len], input_ids[0]
    ):
        generated = generated[prompt_len:]

    description = tokenizer.decode(generated, skip_special_tokens=True).strip()
    return description
# Gradio UI: one image in, the generated description out.
# `capture_session` was dropped: it was a Gradio 2.x-only option that newer
# Gradio releases no longer support (unused-kwarg warning or rejection).
iface = gr.Interface(
    fn=describe_image,
    inputs=gr.Image(type="pil", label="Image"),  # Specify the label for the input
    outputs=gr.Textbox(),
    # Re-run describe_image automatically whenever the input changes.
    live=True,
)
# Launch the Gradio interface
iface.launch()