import gradio as gr
import numpy as np
from PIL import Image
from pathlib import Path
from typing import Optional

import torch
from transformers import CLIPProcessor, CLIPModel

MODEL_NAME = "facebook/metaclip-b32-400m"

# Use the container's cache directory when it exists; otherwise fall back to
# the default Hugging Face cache location.
cache_path = Path("/app/cache")
if not cache_path.exists():
    cache_path = None

def get_clip_model_and_processor(model_name: str, cache_path: Optional[Path] = None):
    """Load the CLIP model and processor, optionally from a local cache directory."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if cache_path:
        model = CLIPModel.from_pretrained(model_name, cache_dir=str(cache_path)).to(device)
        processor = CLIPProcessor.from_pretrained(model_name, cache_dir=str(cache_path))
    else:
        model = CLIPModel.from_pretrained(model_name).to(device)
        processor = CLIPProcessor.from_pretrained(model_name)
    return model.eval(), processor

def image_to_embedding(img: Optional[np.ndarray] = None, txt: Optional[str] = None) -> np.ndarray:
    """Return a CLIP embedding for the image if one is given, otherwise for the text."""
    if img is None and not txt:
        return []
    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        if img is not None:
            embedding = CLIP_MODEL.get_image_features(
                **CLIP_PROCESSOR(images=[Image.fromarray(img)], return_tensors="pt", padding=True).to(
                    CLIP_MODEL.device
                )
            )
        else:
            embedding = CLIP_MODEL.get_text_features(
                **CLIP_PROCESSOR(text=[txt], return_tensors="pt", padding=True).to(
                    CLIP_MODEL.device
                )
            )
    return embedding.detach().cpu().numpy()
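
# Hypothetical downstream use (not part of this app): CLIP image and text
# embeddings share one space, so relevance can be scored with cosine similarity:
#   img_vec = image_to_embedding(img=some_rgb_array)[0]
#   txt_vec = image_to_embedding(txt="a photo of a dog")[0]
#   score = np.dot(img_vec, txt_vec) / (np.linalg.norm(img_vec) * np.linalg.norm(txt_vec))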

# Load the model and processor once at startup so every request reuses them.
CLIP_MODEL, CLIP_PROCESSOR = get_clip_model_and_processor(MODEL_NAME, cache_path=cache_path)

# No examples are defined, so example caching is omitted from the interface.
demo = gr.Interface(fn=image_to_embedding, inputs=["image", "textbox"], outputs="textbox")
demo.launch(server_name="0.0.0.0")
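
# Quick local sanity check (hypothetical; run without launching the server):
#   dummy = np.zeros((224, 224, 3), dtype=np.uint8)
#   vec = image_to_embedding(img=dummy)
#   print(vec.shape)  # a ViT-B/32 CLIP projection head should give (1, 512)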