"""Gradio demo: text-to-image search over a few scikit-image sample pictures.

At import time the script loads an OpenCLIP ViT-B-32 model, embeds a fixed
set of sample images, and builds a Gradio UI in which a free-text description
is matched against those image embeddings.
"""
import os
from typing import Optional

import gradio as gr
import numpy as np
import open_clip
import skimage
import torch
from PIL import Image

MODEL_NAME = "ViT-B-32"
PRETRAINED_TAG = "laion2b_s34b_b79k"
IMAGE_EXTENSIONS = (".png", ".jpg")

# Run on the GPU when one is present; mixed precision is only enabled there,
# because autocast on "cuda" is a no-op (plus a warning) for CPU tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model, _, preprocess = open_clip.create_model_and_transforms(
    MODEL_NAME, pretrained=PRETRAINED_TAG
)
model = model.to(device)
# model in train mode by default, impacts some models with BatchNorm or
# stochastic depth active
model.eval()
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

# File stems (without extension) of the sample images to index.
target_labels = [
    "page",
    "chelsea",
    "astronaut",
    "rocket",
    "motorcycle_right",
    "camera",
    "horse",
    "coffee",
    "logo",
]

original_images = []  # RGB PIL images, shown in the gallery and returned to the UI
images = []           # the same images after the model's preprocessing transform
file_names = []       # source file names, index-aligned with the two lists above

# sorted(): os.listdir order is arbitrary, which made the gallery order vary.
for filename in sorted(os.listdir(skimage.data_dir)):
    if not filename.endswith(IMAGE_EXTENSIONS):
        continue
    if os.path.splitext(filename)[0] not in target_labels:
        continue
    # convert() returns an independent, fully loaded copy, so the file handle
    # can be closed as soon as the with-block exits.
    with Image.open(os.path.join(skimage.data_dir, filename)) as raw_image:
        image = raw_image.convert("RGB")
    original_images.append(image)
    images.append(preprocess(image))
    file_names.append(filename)

if not images:
    raise RuntimeError(
        f"No sample images matching {target_labels} found in {skimage.data_dir}"
    )

image_input = torch.stack(images).to(device)

with torch.no_grad(), torch.autocast(device_type=device, enabled=use_amp):
    image_features = model.encode_image(image_input).float()
    # Unit-normalise so the dot product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)


def identify_image(input_description: Optional[str]) -> Optional[Image.Image]:
    """Return the indexed image that best matches a text description.

    Args:
        input_description: Free-text query from the UI.

    Returns:
        The entry of ``original_images`` whose embedding has the highest
        cosine similarity with the query embedding, or ``None`` when the
        query is ``None``, empty, or whitespace only.
    """
    if input_description is None or not input_description.strip():
        return None

    text_tokens = tokenizer([input_description]).to(device)
    with torch.no_grad(), torch.autocast(device_type=device, enabled=use_amp):
        text_features = model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # One similarity score per indexed image for the single query; only the
    # arg-max matters, so no logit scaling is applied.
    similarity = (image_features @ text_features.T).squeeze(-1)
    return original_images[int(similarity.argmax())]


with gr.Blocks() as demo:
    # NOTE(review): heading/paragraph tags are a reconstruction; the literals
    # previously spanned several lines and did not parse. Confirm the markup.
    gr.HTML("<h1>Image Search</h1>")
    gr.HTML("<p>Identify the most suitable image for description provided.</p>")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Gallery(
                value=original_images,
                label="Images to search from",
                show_label=True,
                elem_id="gallery",
                columns=[3],
                rows=[3],
                object_fit="contain",
                height="auto",
            )
        with gr.Column(scale=2):
            content = gr.Textbox(label="Enter search text here")
            inputs = [content]
            gr.Examples(
                [
                    "Page of text about segmentation",
                    "Facial photo of a tabby cat",
                    "Portrait of an astronaut with the American flag",
                    "Rocket standing on a launchpad",
                    "Red motorcycle standing in a garage",
                    "Person looking at a camera on a tripod",
                    "Black-and-white silhouette of a horse",
                    "Cup of coffee on a saucer",
                    "A snake in the background",
                ],
                inputs=inputs,
            )
            generate_btn = gr.Button(value="Identify")
            outputs = [
                gr.Image(
                    label="Is this the image you are referring to?",
                    height=512,
                    width=512,
                )
            ]
    generate_btn.click(fn=identify_image, inputs=inputs, outputs=outputs)

# For Colab, launch with demo.launch(debug=True) instead.
if __name__ == "__main__":
    demo.launch()