# MetaCLIP / app.py
# Author: SkalskiP
# Last commit: fcb4afd ("debug")
from typing import List
import gradio as gr
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
# Labels file read at startup; each line is one class name that is later used
# as a text prompt for zero-shot classification.
IMAGENET_CLASSES_FILE: str = "imagenet-classes.txt"

# Sample images offered as clickable examples in the UI.
EXAMPLES: List[str] = ["dog.jpeg", "car.png"]

# Markdown header rendered at the top of the demo page.
MARKDOWN: str = (
    "\n"
    "# Zero-Shot Image Classification with MetaCLIP\n"
    "This is the demo for a zero-shot image classification model based on\n"
    "[MetaCLIP](https://github.com/facebookresearch/MetaCLIP), described in the paper\n"
    "[Demystifying CLIP Data](https://arxiv.org/abs/2309.16671) that formalizes CLIP data\n"
    "curation as a simple algorithm.\n"
)
def load_text_lines(file_path: str, encoding: str = "utf-8") -> List[str]:
    """Read a text file and return its lines with trailing whitespace removed.

    Args:
        file_path: Path of the file to read.
        encoding: Codec used to decode the file. Defaults to UTF-8 so the
            result does not depend on the platform's locale encoding (the
            implicit default of ``open()``).

    Returns:
        One string per line, in file order, with trailing whitespace
        (including the line terminator) stripped. Blank lines are kept as
        empty strings; an empty file yields an empty list.
    """
    with open(file_path, "r", encoding=encoding) as file:
        # Iterating the file object streams lines without building an
        # intermediate list via readlines().
        return [line.rstrip() for line in file]
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The model weights and the processor (which prepares both the text prompts
# and the image for the model) come from the same checkpoint.
_METACLIP_CHECKPOINT = "facebook/metaclip-b32-400m"
model = CLIPModel.from_pretrained(_METACLIP_CHECKPOINT)
model = model.to(device)
processor = CLIPProcessor.from_pretrained(_METACLIP_CHECKPOINT)

# Candidate labels, one per line of the labels file.
imagenet_classes = load_text_lines(IMAGENET_CLASSES_FILE)
def classify_image(input_image) -> str:
    """Return the label from ``imagenet_classes`` that best matches an image.

    Every entry of ``imagenet_classes`` is passed to the processor as a text
    prompt together with the image; the prompt with the highest image-text
    logit is returned.

    Args:
        input_image: The image from the Gradio ``gr.Image`` component
            (configured with ``type='pil'``), or ``None`` when the user
            presses Submit without uploading anything.

    Returns:
        The best-scoring class label.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of a model stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before submitting.")

    inputs = processor(
        text=imagenet_classes,
        images=input_image,
        return_tensors="pt",
        padding=True,
    ).to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    # NOTE(review): all class prompts are re-tokenized and re-encoded on every
    # call; caching the text features once would make requests much cheaper.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image holds one row per image with a score per prompt along
    # dim 1. Softmax is monotonic, so the arg-max over the raw logits selects
    # the same class as the arg-max over probabilities.
    class_index = outputs.logits_per_image.argmax(dim=-1)[0].item()
    return imagenet_classes[class_index]
# Gradio UI: an image input next to a text output, a Submit button wired to
# classify_image, and clickable example images below.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        image = gr.Image(image_mode='RGB', type='pil')
        output_text = gr.Textbox(label="Output")
    submit_button = gr.Button("Submit")
    submit_button.click(classify_image, inputs=[image], outputs=[output_text])
    gr.Examples(
        examples=EXAMPLES,
        fn=classify_image,
        inputs=[image],
        outputs=[output_text],
        # Store the example outputs on the server so clicking an example
        # shows its saved prediction instead of re-running the model.
        cache_examples=True,
        run_on_click=True,
    )

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch a blocking web server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=False)