import os
import gradio as gr
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from peft import PeftModel
from huggingface_hub import login
import spaces
import json


# Login to Hugging Face
if "HF_TOKEN" not in os.environ:
    raise ValueError("Please set the HF_TOKEN environment variable with your Hugging Face token")
login(token=os.environ["HF_TOKEN"])
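# A quick way to provide the token when running locally (hypothetical example;
# any valid Hugging Face access token works):
#   export HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   python app.py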

# Load model and processor (do this outside the inference function to avoid reloading)
base_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
lora_weights_path = "taesiri/BungsBunny-LLama-3.2-11B-Vision-Instruct-Medium"

processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
model = PeftModel.from_pretrained(model, lora_weights_path)
model.tie_weights()
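# Optional sketch (left commented out, not part of the original flow): merging
# the LoRA adapter into the base weights with PEFT's merge_and_unload() removes
# the adapter indirection at inference time, at the cost of keeping only the
# merged weights in memory.
# model = model.merge_and_unload()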


@spaces.GPU
def inference(image):
    # Prepare input
    messages = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image in JSON"}]}
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device)
    
    # Run inference
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=2048)
    
    # Decode the full sequence and keep only the text after the "assistant\n"
    # header, i.e. the model's reply
    result = processor.decode(output[0], skip_special_tokens=True)
    json_str = result.strip().split("assistant\n")[1].strip()
    
    try:
        # First JSON parse to handle escaped JSON string
        first_parse = json.loads(json_str)
        
        try:
            # Second JSON parse to get the actual JSON object
            json_object = json.loads(first_parse)
            # Return indented JSON string with 2 spaces
            return json.dumps(json_object, indent=2)
        except (json.JSONDecodeError, TypeError):
            # TypeError covers the case where the first parse already produced a
            # dict/list (json.loads only accepts strings); return it indented
            if isinstance(first_parse, (dict, list)):
                return json.dumps(first_parse, indent=2)
            return first_parse
            
    except json.JSONDecodeError:
        # If both JSON parses fail, return original string
        return json_str

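# A minimal local sanity check, kept as a comment (hypothetical example; the
# image path is a placeholder, not part of the original app):
#   print(inference(Image.open("sample.png")))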

# Create Gradio interface using Blocks
with gr.Blocks() as demo:
    gr.Markdown("# BugsBunny-LLama-3.2-11B-Base-Medium Demo")

    with gr.Row():
        # Container for the image takes full width
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
                height=500,  # Increased height for larger display
            )

    with gr.Row():
        # Container for the text output takes full width
        with gr.Column(scale=1):
            text_output = gr.Textbox(
                label="Response",
                elem_id="response-text",
                lines=25,
                max_lines=25,  # keep max_lines >= lines so the box is not capped below its height
            )

    # Button to trigger the analysis
    submit_btn = gr.Button("Analyze Image", variant="primary")
    submit_btn.click(fn=inference, inputs=[image_input], outputs=[text_output])


demo.launch()
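# Launch notes (assumptions about deployment, not part of the original code):
# on a Hugging Face Space the bare launch() above is enough; for local runs a
# public link or a fixed port can be requested instead, e.g.
#   demo.launch(share=True)
#   demo.launch(server_name="0.0.0.0", server_port=7860)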