# File size: 3,919 Bytes
# Revision gutter: lines alternate between commits ab097bb and 8a25618.
# Line-number gutter: 1-128.
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    LlavaForConditionalGeneration,
)
# Load the model and its multimodal processor once, at import time.
# NOTE(review): confirm that this repository ships Transformers-format weights;
# the converted checkpoint may live under "mistral-community/pixtral-12b".
model_name = "mistral-community/pixtral-12b-240910"

# A bare tokenizer only handles text. The processor pairs the tokenizer with
# the image preprocessor, which is what produces pixel inputs for the model.
processor = AutoProcessor.from_pretrained(model_name)

# Kept so that any code referring to the module-level `tokenizer` still works.
tokenizer = processor.tokenizer

# The vision-language model class is required for image inputs; a causal-LM
# head alone has no path for pixel data.
model = LlavaForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Instruction prepended to every request.
_SYSTEM_PROMPT = (
    "You are a highly observant AI assistant specialized in describing "
    "images accurately and in detail."
)

# Prompt text for each selectable detail level (built once, not per call).
_DETAIL_PROMPTS = {
    "Brief": "Provide a brief description of this image in 2-3 sentences.",
    "Detailed": (
        "Describe this image in detail, including main subjects, colors, "
        "composition, and any notable elements."
    ),
    "Comprehensive": (
        "Provide a comprehensive analysis of this image, including subjects, "
        "colors, composition, mood, potential symbolism, and any other "
        "relevant details you can observe."
    ),
}

# Level used when the caller passes an unknown or missing value.
_DEFAULT_DETAIL_LEVEL = "Detailed"


@spaces.GPU(duration=120)
def generate_description(image, detail_level):
    """Generate a text description of an image at the requested detail level.

    Args:
        image: Path of the image file to describe, or ``None`` when nothing
            was uploaded.
        detail_level: One of ``"Brief"``, ``"Detailed"`` or
            ``"Comprehensive"``. Any other value falls back to ``"Detailed"``
            instead of raising ``KeyError``.

    Returns:
        The generated description with surrounding whitespace removed, or a
        short prompt asking for an upload when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image to generate a description."

    # Image.open keeps the file handle open until the data is loaded;
    # convert() loads the pixels into a new image, so the handle can be
    # released as soon as the conversion is done.
    with Image.open(image) as source_image:
        rgb_image = source_image.convert("RGB")

    instruction = _DETAIL_PROMPTS.get(
        detail_level, _DETAIL_PROMPTS[_DEFAULT_DETAIL_LEVEL]
    )

    # Prompt layout follows the published Pixtral example: the instruction is
    # wrapped in [INST]...[/INST] and [IMG] marks where the image is inserted.
    # The system instruction is folded into the user turn.
    prompt = f"<s>[INST]{_SYSTEM_PROMPT}\n\n{instruction}\n[IMG][/INST]"

    # Floating-point inputs are cast to half precision to match the model
    # weights; integer token ids are only moved to the model's device.
    inputs = processor(
        text=prompt,
        images=[rgb_image],
        return_tensors="pt",
    ).to(model.device, torch.float16)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )

    # generate() returns prompt + completion; keep only the newly produced
    # tokens so the prompt text is not echoed back to the user.
    prompt_length = inputs["input_ids"].shape[1]
    description = processor.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    return description.strip()
# Custom CSS handed to gr.Blocks(css=...) when the interface is built.
# The .header, .input-group, .output-group and .generate-btn selectors match
# the class names used by the HTML banner and the elem_classes arguments in
# the layout code; .container is not referenced by the visible layout.
# NOTE(review): `body` and `.gradio-container` presumably target the page and
# Gradio's root element -- confirm against the rendered DOM.
css = """
body {
background-color: #f0f0f5;
font-family: 'Arial', sans-serif;
}
.container {
max-width: 900px;
margin: auto;
padding: 20px;
}
.gradio-container {
background-color: white;
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.header {
background-color: #4a90e2;
color: white;
padding: 20px;
border-radius: 15px 15px 0 0;
text-align: center;
margin-bottom: 20px;
}
.header h1 {
font-size: 2.5em;
margin-bottom: 10px;
}
.input-group, .output-group {
background-color: #f9f9f9;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
}
.input-group label, .output-group label {
color: #4a90e2;
font-weight: bold;
}
.generate-btn {
background-color: #4a90e2 !important;
color: white !important;
border: none !important;
border-radius: 5px !important;
padding: 10px 20px !important;
font-size: 16px !important;
cursor: pointer !important;
transition: background-color 0.3s ease !important;
}
.generate-btn:hover {
background-color: #3a7bc8 !important;
}
"""
# Gradio interface: a banner, an input group (image, detail level, button)
# and an output group (description text box), styled by the `css` string.
with gr.Blocks(css=css) as iface:
    # Static banner; the "header" class is styled by the custom CSS.
    gr.HTML(
        """
<div class="header">
<h1>Pixtral Image Description Generator</h1>
<p>Upload an image and get a detailed description using the powerful Pixtral-12B model.</p>
</div>
"""
    )
    with gr.Group():
        with gr.Group(elem_classes="input-group"):
            # type="filepath" hands the callback a path string, which is what
            # generate_description opens with PIL.
            image_input = gr.Image(type="filepath", label="Upload an image")
            detail_level = gr.Radio(
                ["Brief", "Detailed", "Comprehensive"],
                label="Description Detail Level",
                value="Detailed",
            )
            generate_btn = gr.Button(
                "Generate Description", elem_classes="generate-btn"
            )
        with gr.Group(elem_classes="output-group"):
            output = gr.Textbox(label="Generated Description", lines=10)

    # Event wiring must happen inside the Blocks context.
    generate_btn.click(
        generate_description,
        inputs=[image_input, detail_level],
        outputs=output,
    )

# Launch the app
iface.launch()