# Page header captured along with this file (not part of the program):
# taesiri's picture
# backup
# f0a2db5
# raw
# history blame
# 7.19 kB
import base64
import io
import json
import os

import gradio as gr
import matplotlib.pyplot as plt
import spaces
import torch
from huggingface_hub import login
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
def check_environment():
    """Ensure every required environment variable is set to a non-blank value.

    Raises:
        ValueError: naming each required variable that is unset, empty, or
            whitespace-only, with a hint on how to provide HF_TOKEN.
    """
    required_vars = ["HF_TOKEN"]
    # A plain `var in os.environ` test accepts HF_TOKEN="" and lets an empty
    # token through, so treat blank values the same as missing ones.
    missing_vars = [
        var for var in required_vars if not os.environ.get(var, "").strip()
    ]
    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}\n"
            "Please set the HF_TOKEN environment variable with your Hugging Face token"
        )
# Login to Hugging Face
# Validate first so a missing HF_TOKEN fails with a descriptive ValueError
# rather than a bare KeyError from os.environ["HF_TOKEN"] on the next line.
check_environment()
# NOTE(review): add_to_git_credential=True also asks huggingface_hub to store
# the token in the git credential helper; likely unnecessary for a server
# process — confirm it is intended.
login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)
# Load model and processor (do this outside the inference function to avoid reloading)
base_model_path = (
    "taesiri/BugsBunny-LLama-3.2-11B-Vision-BaseCaptioner-Medium-FullModel"
)
processor = AutoProcessor.from_pretrained(base_model_path)
# Weights are loaded in bfloat16 and placed on the CUDA device.
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
# NOTE(review): leftover from a PEFT/LoRA setup; neither PeftModel nor
# lora_weights_path is imported or defined in this file.
# model = PeftModel.from_pretrained(model, lora_weights_path)
# NOTE(review): re-ties shared embedding weights; may be redundant after
# from_pretrained — confirm before removing.
model.tie_weights()
def describe_image_in_JSON(json_string):
    """Decode the model's JSON answer, whether it is encoded once or twice.

    The pipeline was written for a doubly-encoded payload: a JSON string
    literal whose contents are themselves JSON. A second decode is applied
    only when the first one yields a ``str``. An already-decoded object
    (singly-encoded JSON) is returned as-is instead of raising ``TypeError``.

    Args:
        json_string: raw text extracted from the model's reply.

    Returns:
        The decoded value (normally a dict). On failure, a string starting
        with ``"Error parsing JSON: "``. Callers should therefore check the
        result's type, not its truthiness.
    """
    try:
        decoded = json.loads(json_string)
        if isinstance(decoded, str):
            # Doubly-encoded: the first pass only unwrapped the string literal.
            decoded = json.loads(decoded)
        return decoded
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError covers non-text input such as None.
        return f"Error parsing JSON: {str(e)}"
def create_color_palette_image(colors):
    """Render a horizontal strip of colour swatches, one unit wide per colour.

    Args:
        colors: list of colour strings. Every entry must be a ``str`` starting
            with ``"#"`` (e.g. hex codes); names such as ``"red"`` are rejected.

    Returns:
        A matplotlib ``Figure``, or None if ``colors`` is empty or not a list,
        any entry fails validation, or drawing raises.
    """
    if not colors or not isinstance(colors, list):
        return None
    if not all(isinstance(color, str) and color.startswith("#") for color in colors):
        return None
    try:
        # Build the figure without pyplot. plt.subplots() registers every figure
        # in pyplot's global figure manager, and nothing here ever closes it,
        # so each call leaked one figure. A bare Figure is simply
        # garbage-collected once the caller is done with it.
        fig = Figure(figsize=(10, 2))
        ax = fig.subplots()
        for i, color in enumerate(colors):
            ax.add_patch(Rectangle((i, 0), 1, 1, facecolor=color))
        # Show exactly the swatches, without tick marks.
        ax.set_xlim(0, len(colors))
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])
        return fig  # Returned directly for a gr.Plot component
    except Exception as e:
        # Best effort: a bad colour value yields no plot instead of failing the request.
        print(f"Error creating color palette: {e}")
        return None
def _error_outputs(message, error=None, raw=None, status="Error"):
    """Return the 10 values ``inference`` yields when it cannot produce a result.

    The order matches the ``outputs`` list wired to ``submit_btn.click``:
    description, scene, characters, objects, textures, lighting, colour
    palette, raw output, error box, status.

    Args:
        message: short text shown in the textbox outputs.
        error: text for the error box; defaults to ``message``.
        raw: text for the raw-output tab; defaults to ``message``.
        status: text for the status line.
    """
    return (
        message,  # description (Textbox)
        message,  # scene description (Textbox)
        # The JSON slots get None: the success path feeds them JSON text
        # (json.dumps), and a plain message is not valid JSON.
        None,  # characters (gr.JSON)
        None,  # objects (gr.JSON)
        None,  # texture details (gr.JSON)
        message,  # lighting details (Textbox)
        None,  # color palette (gr.Plot) -- expects a figure, not text
        message if raw is None else raw,  # raw output (Textbox)
        message if error is None else error,  # error box (Textbox)
        status,  # status (Textbox)
    )


@spaces.GPU
def inference(image):
    """Caption ``image`` with the model and split its JSON answer into UI fields.

    Args:
        image: a PIL image, or anything ``PIL.Image.fromarray`` accepts;
            None when nothing was uploaded.

    Returns:
        Ten values, in the order of the ``outputs`` wired to
        ``submit_btn.click``. Every path returns exactly ten values (failures
        go through ``_error_outputs``) so Gradio can map them onto the
        components.
    """
    if image is None:
        return _error_outputs("Please provide an image")

    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return _error_outputs("Invalid image format")

    # Prepare input
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe the image in JSON"},
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    try:
        # Move inputs to the same device as the model
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=2048)
    except Exception as e:
        print(f"Inference error: {e}")
        return _error_outputs("Error during inference")
    finally:
        # Clear the CUDA cache whether generation succeeded or failed.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Decode output
    result = processor.decode(output[0], skip_special_tokens=True)
    print("DEBUG: Full decoded output:", result)

    # The decoded text is expected to contain the chat template's
    # "assistant\n" header; the answer is everything after its FIRST
    # occurrence. partition() splits once, so the marker appearing again
    # inside the answer no longer truncates it.
    _, marker, answer = result.strip().partition("assistant\n")
    if not marker:
        print("DEBUG: Error splitting response:", "'assistant\\n' marker not found")
        return _error_outputs(
            "Error extracting JSON from response", error="Failed to extract JSON"
        )
    json_str = answer.strip()
    print("DEBUG: Extracted JSON string after split:", json_str)

    parsed_json = describe_image_in_JSON(json_str)
    # describe_image_in_JSON reports failure with an error *string*, which is
    # truthy, so require an actual non-empty dict before calling .get on it.
    if not isinstance(parsed_json, dict) or not parsed_json:
        return _error_outputs(
            "Error parsing response", error="Failed to parse JSON", raw=json_str
        )

    # Create color palette visualization
    colors = parsed_json.get("color_palette", [])
    color_image = create_color_palette_image(colors)

    # Convert lists to JSON text for the Gradio JSON components
    character_list = json.dumps(parsed_json.get("character_list", []))
    object_list = json.dumps(parsed_json.get("object_list", []))
    texture_details = json.dumps(parsed_json.get("texture_details", []))

    return (
        parsed_json.get("description", "Not available"),
        parsed_json.get("scene_description", "Not available"),
        character_list,
        object_list,
        texture_details,
        parsed_json.get("lighting_details", "Not available"),
        color_image,
        json_str,
        "",  # Error box
        "Analysis complete",  # Status
    )
# Update Gradio interface
# Layout: an image upload and "Analyze Image" button beside two result tabs
# (structured fields and the raw JSON text), followed by a status line.
with gr.Blocks() as demo:
    gr.Markdown("# BugsBunny-LLama-3.2-11B-Base-Medium Demo")

    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" makes Gradio pass a PIL image to the click handler.
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
            )
            submit_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Tabs():
            with gr.Tab("Structured Results"):
                with gr.Column(scale=1):
                    description_output = gr.Textbox(
                        label="Description",
                        lines=4,
                    )
                    scene_output = gr.Textbox(
                        label="Scene Description",
                        lines=2,
                    )
                    # The three gr.JSON components receive json.dumps text from inference().
                    characters_output = gr.JSON(
                        label="Characters",
                    )
                    objects_output = gr.JSON(
                        label="Objects",
                    )
                    textures_output = gr.JSON(
                        label="Texture Details",
                    )
                    lighting_output = gr.Textbox(
                        label="Lighting Details",
                        lines=2,
                    )
                    # Receives the matplotlib figure from create_color_palette_image (or None).
                    color_palette_output = gr.Plot(
                        label="Color Palette",
                    )

            with gr.Tab("Raw Output"):
                raw_output = gr.Textbox(
                    label="Raw JSON Response",
                    lines=25,
                    max_lines=30,
                )

    # NOTE(review): created with visible=False and never made visible, so any
    # message inference() writes here is not shown to users.
    error_box = gr.Textbox(label="Error Messages", visible=False)

    with gr.Row():
        status_text = gr.Textbox(label="Status", value="Ready", interactive=False)

    # The order of `outputs` must match the order of the 10 values inference()
    # returns. api_name exposes this handler as the "analyze" API endpoint.
    submit_btn.click(
        fn=inference,
        inputs=[image_input],
        outputs=[
            description_output,
            scene_output,
            characters_output,
            objects_output,
            textures_output,
            lighting_output,
            color_palette_output,
            raw_output,
            error_box,
            status_text,
        ],
        api_name="analyze",
    )

# NOTE(review): share=True requests a public share link; hosted environments
# may ignore it — confirm it is intended.
demo.launch(share=True)