Pixtral / app.py
sagar007's picture
Update app.py
8a25618 verified
import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    LlavaForConditionalGeneration,
)
from PIL import Image
# Load the Pixtral model and its multimodal processor.
# Pixtral is a vision-language (LLaVA-architecture) model: it must be loaded with
# LlavaForConditionalGeneration, and images must be encoded by AutoProcessor.
# A plain tokenizer cannot turn an image into model inputs.
# "mistral-community/pixtral-12b" is the transformers-format checkpoint; the
# "-240910" repo appears to hold raw mistral-inference weights
# (NOTE(review): confirm if a specific revision must be pinned).
model_name = "mistral-community/pixtral-12b"
processor = AutoProcessor.from_pretrained(model_name)
# Kept so any code that referenced `tokenizer` directly keeps working.
tokenizer = processor.tokenizer
model = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Instruction text for each detail level; keys match the UI radio choices.
_DETAIL_PROMPTS = {
    "Brief": "Provide a brief description of this image in 2-3 sentences.",
    "Detailed": "Describe this image in detail, including main subjects, colors, composition, and any notable elements.",
    "Comprehensive": "Provide a comprehensive analysis of this image, including subjects, colors, composition, mood, potential symbolism, and any other relevant details you can observe.",
}
_SYSTEM_PROMPT = "You are a highly observant AI assistant specialized in describing images accurately and in detail."


@spaces.GPU(duration=120)
def generate_description(image, detail_level):
    """Generate a natural-language description of an uploaded image.

    Args:
        image: Filesystem path of the uploaded image (the UI uses
            ``gr.Image(type="filepath")``), or ``None`` if nothing was uploaded.
        detail_level: ``"Brief"``, ``"Detailed"`` or ``"Comprehensive"``.
            Unrecognised values fall back to ``"Detailed"`` instead of raising.

    Returns:
        The generated description with surrounding whitespace stripped, or an
        instruction to upload an image when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image to generate a description."

    # The context manager closes the file handle; convert() returns an
    # independent, fully loaded RGB copy.
    with Image.open(image) as img:
        rgb_image = img.convert("RGB")

    instruction = _DETAIL_PROMPTS.get(detail_level, _DETAIL_PROMPTS["Detailed"])
    # Pixtral prompt format from the model card. The [IMG] placeholder marks where
    # the processor splices in the image tokens. Without it, the model never sees
    # the picture.
    prompt = f"<s>[INST]{_SYSTEM_PROMPT}\n\n{instruction}\n[IMG][/INST]"

    inputs = processor(text=prompt, images=[rgb_image], return_tensors="pt")
    # Move everything to the model's device. Only floating-point tensors
    # (pixel_values) are cast to the fp16 model dtype; integer token ids keep
    # their dtype.
    inputs = inputs.to(model.device, dtype=model.dtype)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )

    # generate() echoes the prompt; drop it so only new text is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    description = processor.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return description.strip()
# Custom CSS injected into the page via gr.Blocks(css=css) in the interface below.
# Selectors correspond to classes used by the interface:
#   .header                      - the banner <div> rendered by gr.HTML
#   .input-group / .output-group - the two gr.Group panels (elem_classes)
#   .generate-btn                - the "Generate Description" button (elem_classes);
#                                  "!important" is used so these rules override the
#                                  theme's default button styling (presumably; verify
#                                  against the Gradio theme in use)
# .container is not referenced by any elem_classes in this file.
# NOTE(review): .gradio-container presumably targets Gradio's root wrapper - confirm.
css = """
body {
background-color: #f0f0f5;
font-family: 'Arial', sans-serif;
}
.container {
max-width: 900px;
margin: auto;
padding: 20px;
}
.gradio-container {
background-color: white;
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.header {
background-color: #4a90e2;
color: white;
padding: 20px;
border-radius: 15px 15px 0 0;
text-align: center;
margin-bottom: 20px;
}
.header h1 {
font-size: 2.5em;
margin-bottom: 10px;
}
.input-group, .output-group {
background-color: #f9f9f9;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
}
.input-group label, .output-group label {
color: #4a90e2;
font-weight: bold;
}
.generate-btn {
background-color: #4a90e2 !important;
color: white !important;
border: none !important;
border-radius: 5px !important;
padding: 10px 20px !important;
font-size: 16px !important;
cursor: pointer !important;
transition: background-color 0.3s ease !important;
}
.generate-btn:hover {
background-color: #3a7bc8 !important;
}
"""
# Banner markup shown at the top of the page (styled by the .header CSS rule).
_HEADER_HTML = """
<div class="header">
<h1>Pixtral Image Description Generator</h1>
<p>Upload an image and get a detailed description using the powerful Pixtral-12B model.</p>
</div>
"""

# Gradio UI: banner, an input panel (image + detail level + button), and an
# output panel whose textbox receives generate_description's result.
with gr.Blocks(css=css) as iface:
    gr.HTML(_HEADER_HTML)

    with gr.Group():
        # Input panel: the image reaches the handler as a file path.
        with gr.Group(elem_classes="input-group"):
            image_input = gr.Image(label="Upload an image", type="filepath")
            detail_level = gr.Radio(
                choices=["Brief", "Detailed", "Comprehensive"],
                value="Detailed",
                label="Description Detail Level",
            )
            generate_btn = gr.Button("Generate Description", elem_classes="generate-btn")

        # Output panel.
        with gr.Group(elem_classes="output-group"):
            output = gr.Textbox(lines=10, label="Generated Description")

    # Wire the button: (image path, detail level) -> description text.
    generate_btn.click(
        fn=generate_description,
        inputs=[image_input, detail_level],
        outputs=output,
    )
# Launch the app only when executed as a script (e.g. `python app.py`), so that
# importing this module for tests or tooling does not start a web server.
if __name__ == "__main__":
    iface.launch()