import os
import numpy as np
import torch
import skimage
from PIL import Image
import open_clip
import gradio as gr

model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
tokenizer = open_clip.get_tokenizer('ViT-B-32')

target_labels = ["page", "chelsea", "astronaut", "rocket",
                 "motorcycle_right", "camera", "horse", "coffee", "logo"]

# Load the matching sample images bundled with scikit-image and preprocess them for CLIP
original_images = []
images = []
file_names = []
for filename in [filename for filename in os.listdir(skimage.data_dir)
                 if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in target_labels:
        continue
    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
    original_images.append(image)
    images.append(preprocess(image))
    file_names.append(filename)

# Encode all candidate images once up front; only the text prompt changes per query
image_input = torch.tensor(np.stack(images))
with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image_input).float()
    image_features /= image_features.norm(dim=-1, keepdim=True)


def identify_image(input_description):
    if input_description is None:
        return None
    text_tokens = tokenizer([input_description])
    with torch.no_grad(), torch.cuda.amp.autocast():
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)
    # Scaled cosine similarities between each image and the prompt (no softmax needed,
    # since we only care about which image scores highest)
    text_probs = 100.0 * image_features @ text_features.T
    top_probs, _ = text_probs.cpu().topk(1, dim=-1)
    return original_images[top_probs.argmax().item()]


with gr.Blocks() as demo:
    gr.HTML("