import os

import gradio as gr
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image  # Gradio hands inputs to the function as PIL images when type="pil"
EXAMPLES_DIR = 'examples'
# BLIP-2 conditions on a plain-text prefix; it has no special "<image>" token
DEFAULT_PROMPT = "a photo of"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load BLIP-2 in half precision on GPU; fall back to float32 on CPU, where
# float16 inference is not supported. Note that device_map="auto" conflicts
# with an explicit .to(device) (accelerate refuses to move a dispatched
# model), so the model is placed on a single device instead.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = Blip2ForConditionalGeneration.from_pretrained(
    'Salesforce/blip2-flan-t5-xl', torch_dtype=dtype
)
model.to(device)
model.eval()
# Initialize the processor (image preprocessing and prompt tokenization)
processor = Blip2Processor.from_pretrained('Salesforce/blip2-flan-t5-xl')
# Set up example images, pairing each file with the default prompt
examples = []
if os.path.isdir(EXAMPLES_DIR):
    for file in os.listdir(EXAMPLES_DIR):
        path = os.path.join(EXAMPLES_DIR, file)
        examples.append([path, DEFAULT_PROMPT])
def predict_caption(image, prompt):
    assert isinstance(prompt, str)
    # Preprocess the image and prompt; cast pixel values to the model's dtype
    # (BatchFeature.to only casts floating-point tensors, so input_ids stay integer)
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, dtype)
    # Generate the caption
    generated_ids = model.generate(**inputs, max_new_tokens=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return caption
iface = gr.Interface(
    fn=predict_caption,
    inputs=[gr.Image(type="pil"), gr.Textbox(value=DEFAULT_PROMPT, label="Prompt")],
    examples=examples,
    outputs="text",
)

iface.launch(debug=True)
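
# Optional local smoke test: a minimal sketch for sanity-checking the model
# before relying on the UI. This is not part of the original Space, and the
# image path below is a hypothetical placeholder; point it at any local file.
#
#   img = Image.open('examples/sample.jpg')
#   print(predict_caption(img, DEFAULT_PROMPT))
#
# Assumed dependencies for the Space: torch, transformers, gradio, Pillow.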