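# Gradio demo: Llama 3.2 11B Vision base model plus a LoRA adapter that
# describes an uploaded image as structured JSON.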
import os
import gradio as gr
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from peft import PeftModel
from huggingface_hub import login
import spaces
import json
# Log in to Hugging Face (required to download the gated Llama 3.2 weights)
if "HF_TOKEN" not in os.environ:
    raise ValueError("Please set the HF_TOKEN environment variable with your Hugging Face token")
login(token=os.environ["HF_TOKEN"])
# Load model and processor (do this outside the inference function to avoid reloading)
base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
lora_weights_path = "taesiri/BungsBunny-LLama-3.2-11B-Vision-Instruct-Medium"
processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
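# Attach the LoRA adapter to the base model, then re-tie the input/output
# embedding weights after wrapping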
model = PeftModel.from_pretrained(model, lora_weights_path)
model.tie_weights()
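# On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration
# of each call; on dedicated GPU hardware the decorator is effectively a no-op.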
@spaces.GPU
def inference(image):
    # Build the chat prompt: a single user turn with the image and the instruction
    messages = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image in JSON"}]}
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device)

    # Run inference
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=2048)

    # Decode the full sequence and keep only the text after the assistant header
    result = processor.decode(output[0], skip_special_tokens=True)
    json_str = result.strip().split("assistant\n")[1].strip()

    try:
        # First parse: the model often emits a JSON-escaped string
        first_parse = json.loads(json_str)
        try:
            # Second parse: unwrap the escaped string into the actual JSON object.
            # If the first parse already produced a dict/list, json.loads raises
            # TypeError rather than JSONDecodeError, so catch both.
            json_object = json.loads(first_parse)
            # Return the object pretty-printed with 2-space indentation
            return json.dumps(json_object, indent=2)
        except (json.JSONDecodeError, TypeError):
            # First parse already yielded the final value; pretty-print if possible
            if isinstance(first_parse, (dict, list)):
                return json.dumps(first_parse, indent=2)
            return first_parse
    except json.JSONDecodeError:
        # Output was not valid JSON at all; return the raw model text
        return json_str
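# Hypothetical local smoke test (not part of the Space UI):
#   inference(Image.open("example.jpg")) returns a pretty-printed JSON string,
#   or the raw model text if JSON parsing fails.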
# Create Gradio interface using Blocks
with gr.Blocks() as demo:
    gr.Markdown("# BugsBunny-LLama-3.2-11B-Base-Medium Demo")
    with gr.Row():
        # Container for the image takes full width
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
                height=500,  # taller preview for easier inspection
            )
    with gr.Row():
        # Container for the text output takes full width
        with gr.Column(scale=1):
            text_output = gr.Textbox(
                label="Response",
                elem_id="response-text",
                lines=25,
                max_lines=25,  # keep max_lines >= lines so the box is not clipped
            )
    # Button to trigger the analysis
    submit_btn = gr.Button("Analyze Image", variant="primary")
    submit_btn.click(fn=inference, inputs=[image_input], outputs=[text_output])
demo.launch()