File size: 3,919 Bytes
ab097bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a25618
 
 
 
 
 
 
 
 
 
 
 
 
 
ab097bb
8a25618
ab097bb
 
 
 
8a25618
ab097bb
 
 
 
8a25618
ab097bb
8a25618
 
ab097bb
 
 
8a25618
 
ab097bb
 
 
 
8a25618
ab097bb
 
 
 
 
 
 
 
8a25618
ab097bb
 
 
 
8a25618
 
ab097bb
 
 
 
 
 
 
 
 
 
8a25618
ab097bb
 
 
 
 
8a25618
ab097bb
 
 
8a25618
ab097bb
 
 
 
 
 
 
 
 
8a25618
ab097bb
 
 
 
 
 
 
 
8a25618
 
ab097bb
 
 
 
 
 
8a25618
 
 
ab097bb
 
8a25618
ab097bb
8a25618
ab097bb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from PIL import Image

# Checkpoint identifier shared by the tokenizer and the model weights.
model_name = "mistral-community/pixtral-12b-240910"

# Both objects are created once at import time and reused by every request.
# NOTE(review): this looks like a vision-language checkpoint; confirm against
# the model card that the tokenizer / causal-LM auto classes are the right
# loaders for it.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# float16 weights, with device placement delegated to the loader.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

# Instruction sent to the model for each selectable detail level. Built once
# at import time instead of on every request.
_DETAIL_PROMPTS = {
    "Brief": "Provide a brief description of this image in 2-3 sentences.",
    "Detailed": "Describe this image in detail, including main subjects, colors, composition, and any notable elements.",
    "Comprehensive": "Provide a comprehensive analysis of this image, including subjects, colors, composition, mood, potential symbolism, and any other relevant details you can observe.",
}


@spaces.GPU(duration=120)
def generate_description(image, detail_level):
    """Generate a text description of an uploaded image.

    Args:
        image: Filesystem path of the image to describe, or ``None`` when
            nothing has been uploaded.
        detail_level: One of ``"Brief"``, ``"Detailed"`` or
            ``"Comprehensive"``; selects the instruction given to the model.

    Returns:
        The generated description with surrounding whitespace stripped, or a
        short guidance message when the image is missing or the detail level
        is not recognised.
    """
    if image is None:
        return "Please upload an image to generate a description."

    # Validate the cheap input first so a bad value never costs an image
    # decode or a generation call, and never surfaces as a bare KeyError.
    prompt = _DETAIL_PROMPTS.get(detail_level)
    if prompt is None:
        valid_levels = ", ".join(_DETAIL_PROMPTS)
        return f"Unknown detail level {detail_level!r}. Choose one of: {valid_levels}."

    # convert() returns a new in-memory image, so the underlying file handle
    # can be released as soon as the with-block exits.
    with Image.open(image) as source:
        image = source.convert("RGB")

    messages = [
        {"role": "system", "content": "You are a highly observant AI assistant specialized in describing images accurately and in detail."},
        {"role": "user", "content": prompt}
    ]
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    # NOTE(review): `images=` is passed to a plain tokenizer here. Text
    # tokenizers normally do not consume images, so the picture may never
    # reach the model; a multimodal processor is presumably required.
    # TODO confirm against the model card.
    inputs = tokenizer(formatted_prompt, images=[image], return_tensors="pt", padding=True).to(model.device)

    # Sampling-based generation; gradients are not needed for inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )

    # Skip the first `input_ids.shape[1]` positions (the prompt length) so
    # only the continuation is decoded.
    prompt_length = inputs['input_ids'].shape[1]
    description = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return description.strip()

# Custom stylesheet handed to `gr.Blocks(css=...)` below.
# The `.header`, `.input-group`, `.output-group` and `.generate-btn` selectors
# match the class names used by the header HTML and the `elem_classes`
# arguments in the interface definition.
css = """
body {
    background-color: #f0f0f5;
    font-family: 'Arial', sans-serif;
}
.container {
    max-width: 900px;
    margin: auto;
    padding: 20px;
}
.gradio-container {
    background-color: white;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.header {
    background-color: #4a90e2;
    color: white;
    padding: 20px;
    border-radius: 15px 15px 0 0;
    text-align: center;
    margin-bottom: 20px;
}
.header h1 {
    font-size: 2.5em;
    margin-bottom: 10px;
}
.input-group, .output-group {
    background-color: #f9f9f9;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.input-group label, .output-group label {
    color: #4a90e2;
    font-weight: bold;
}
.generate-btn {
    background-color: #4a90e2 !important;
    color: white !important;
    border: none !important;
    border-radius: 5px !important;
    padding: 10px 20px !important;
    font-size: 16px !important;
    cursor: pointer !important;
    transition: background-color 0.3s ease !important;
}
.generate-btn:hover {
    background-color: #3a7bc8 !important;
}
"""

# Build the page: a banner, an input panel (image, detail level, button) and
# an output panel holding the generated text.
with gr.Blocks(css=css) as iface:
    # Static banner styled by the `.header` CSS rule.
    gr.HTML(
        value="""
        <div class="header">
            <h1>Pixtral Image Description Generator</h1>
            <p>Upload an image and get a detailed description using the powerful Pixtral-12B model.</p>
        </div>
        """
    )

    with gr.Group():
        with gr.Group(elem_classes="input-group"):
            # The image is handed to the callback as a file path.
            image_input = gr.Image(label="Upload an image", type="filepath")
            detail_level = gr.Radio(
                choices=["Brief", "Detailed", "Comprehensive"],
                value="Detailed",
                label="Description Detail Level",
            )
            generate_btn = gr.Button(value="Generate Description", elem_classes="generate-btn")

        with gr.Group(elem_classes="output-group"):
            output = gr.Textbox(lines=10, label="Generated Description")

    # Clicking the button runs the generator and writes its result to the textbox.
    generate_btn.click(
        fn=generate_description,
        inputs=[image_input, detail_level],
        outputs=output,
    )

# Start serving the interface.
iface.launch()