import base64
import math
import os
import urllib.parse
from io import BytesIO

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
from torchvision import transforms
# from diffusers import StableDiffusionPipeline, StableDiffusionImageVariationPipeline, DiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from clip_retrieval.clip_client import ClipClient, Modality
from clip_retrieval.load_clip import load_clip, get_tokenizer

# clip_model_id = "openai/clip-vit-large-patch14-336"
# clip_retrieval_indice_name, clip_model_id ="laion5B-L-14", "/laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
clip_retrieval_service_url = "https://knn.laion.ai/knn-service"
# available models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
# clip_model="ViT-B/32"
clip_model = "ViT-L/14"
# Index name queried on the kNN service (passed to ClipClient as indice_name).
clip_model_id = "laion5B-L-14"

# Number of input tabs; each tab contributes one image or text embedding.
max_tabs = 10
# Seconds to wait for each result image download before giving up on it.
image_download_timeout = 10

# Per-tab gradio components; the placeholders are replaced when the UI is built below.
input_images = [None] * max_tabs
input_prompts = [None] * max_tabs
embedding_plots = [None] * max_tabs
embedding_powers = [1.] * max_tabs
embedding_base64s = [None] * max_tabs


def image_to_embedding(input_im):
    """Encode an image array (as delivered by gr.Image) into an L2-normalised CLIP embedding.

    Uses the module-level CLIP ``model``, ``preprocess`` and ``device``.
    Returns a float32 numpy array; a single image is batched with ``unsqueeze(0)``,
    so the result is presumably shaped (1, D).
    """
    pil_image = Image.fromarray(input_im)
    batch = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_embeddings = model.encode_image(batch)
        image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
    return image_embeddings.cpu().to(torch.float32).detach().numpy()


def prompt_to_embedding(prompt):
    """Encode a text prompt into an L2-normalised CLIP embedding (float32 numpy array).

    Uses the module-level CLIP ``model``, ``tokenizer`` and ``device``.
    """
    tokens = tokenizer([prompt]).to(device)
    with torch.no_grad():
        prompt_embeddings = model.encode_text(tokens)
        prompt_embeddings /= prompt_embeddings.norm(dim=-1, keepdim=True)
    return prompt_embeddings.cpu().to(torch.float32).detach().numpy()


def embedding_to_image(embeddings):
    """Render an embedding as a square greyscale PIL image (debug/visualisation helper).

    The values are flattened and zero-padded to the next square size. They are then
    min-max scaled to 0..255, because an 8-bit "L" image cannot hold float data directly.
    An empty input yields a 1x1 black image.
    """
    values = np.asarray(embeddings, dtype=np.float32).ravel()
    size = max(1, math.ceil(math.sqrt(values.shape[0])))
    square = np.pad(values, (0, size**2 - values.shape[0]), 'constant').reshape(size, size)
    low, high = float(square.min()), float(square.max())
    span = (high - low) or 1.0  # avoid dividing by zero for constant input
    pixels = ((square - low) / span * 255).astype(np.uint8)
    return Image.fromarray(pixels)  # 2-D uint8 -> mode "L"


def embedding_to_base64(embeddings):
    """Serialise an embedding as URL-safe base64 of its raw float32 bytes (C order)."""
    raw = np.asarray(embeddings, dtype=np.float32).tobytes()
    return base64.urlsafe_b64encode(raw).decode()


def base64_to_embedding(embeddings_b64):
    """Inverse of ``embedding_to_base64``: decode to a flat, read-only float32 numpy array."""
    raw = base64.urlsafe_b64decode(embeddings_b64)
    return np.frombuffer(raw, dtype=np.float32)


def safe_url(url):
    """Percent-encode ``url`` (keeping ':' and '/') and cut it right after the first '.jpg'.

    The cut repairs result URLs that contain two concatenated .jpg filenames. Note that it
    applies whenever '.jpg' occurs, so it also drops anything following the first '.jpg'.
    """
    url = urllib.parse.quote(url, safe=':/')
    if '.jpg' in url:
        url = url.split('.jpg', 1)[0] + '.jpg'
    return url


def main(
    # input_im,
    embeddings,
    n_samples=4,
):
    """Search LAION-5B for the images closest to a base64 embedding.

    Args:
        embeddings: embedding produced by ``embedding_to_base64``; '' when no input is set.
        n_samples: maximum number of images to return.

    Returns:
        A list of (PIL.Image, title) tuples for gr.Gallery. Each title is the similarity
        followed by the caption. Results whose image cannot be downloaded or decoded are skipped.
    """
    if not embeddings:
        # Nothing to search for yet; an empty query would make the kNN service error out.
        return []
    query = base64_to_embedding(embeddings).tolist()
    results = clip_retrieval_client.query(embedding_input=query)

    images = []
    for result in results:
        if len(images) >= n_samples:
            break
        url = safe_url(result["url"])
        similarity = float("{:.4f}".format(result["similarity"]))
        title = str(similarity) + ' ' + (result.get("caption") or '')
        # Returning (url, title) and letting the gallery fetch it would be simpler, but a url
        # that returns an error crashes the page, so each image is downloaded and validated here.
        try:
            response = requests.get(url, timeout=image_download_timeout)
            if not response.ok:
                continue
            image = Image.open(BytesIO(response.content))
            image.load()  # force decoding now so truncated/corrupt files are skipped here
            if image.mode != 'RGB':
                image = image.convert('RGB')
            images.append((image, title))
        except Exception as e:
            # Best effort: one bad result must not abort the whole search.
            print(e)
    return images


def on_image_load_update_embeddings(image_data):
    """Recompute a tab's base64 embedding from its image; a cleared image yields ''."""
    if image_data is None:
        return gr.Text.update('')
    embeddings = image_to_embedding(image_data)
    return gr.Text.update(embedding_to_base64(embeddings))


def on_prompt_change_update_embeddings(prompt):
    """Recompute a tab's base64 embedding from its text prompt.

    NOTE: unlike a cleared image, an empty/None prompt is still encoded (as the empty
    string), so the tab keeps contributing to the combined embedding.
    """
    embeddings = prompt_to_embedding(prompt or '')
    return gr.Text.update(embedding_to_base64(embeddings))


def update_average_embeddings(embedding_base64s_state, embedding_powers):
    """Combine the per-tab embeddings into a single L2-normalised embedding.

    Each non-empty tab embedding is multiplied by its power, and the results are summed
    (not averaged; see the TODO below) and then normalised.

    Returns:
        The combined embedding as a base64 string. Returns '' when no tab has an
        embedding, or when the weighted sum is the zero vector (e.g. all powers are 0).
        The result is a plain string because callers wrap it in gr.Text.update themselves.
    """
    final_embedding = None
    num_embeddings = 0
    for embedding_base64, power in zip(embedding_base64s_state, embedding_powers):
        if not embedding_base64:
            continue
        embedding = base64_to_embedding(embedding_base64) * power
        final_embedding = embedding if final_embedding is None else final_embedding + embedding
        num_embeddings += 1
    if final_embedding is None:
        return ''
    # TODO toggle this to support average or sum
    # final_embedding = final_embedding / num_embeddings
    norm = np.linalg.norm(final_embedding)
    if norm == 0:
        # Normalising a zero vector would produce NaNs.
        return ''
    return embedding_to_base64(final_embedding / norm)


def on_power_change_update_average_embeddings(embedding_base64s_state, embedding_power_state, power, idx):
    """Store tab ``idx``'s new power in the session state and recompute the combined embedding."""
    embedding_power_state[idx] = power
    embeddings_b64 = update_average_embeddings(embedding_base64s_state, embedding_power_state)
    return gr.Text.update(embeddings_b64)


def on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_power_state, embedding_base64, idx):
    """Store tab ``idx``'s new embedding ('' -> None) in the session state and recompute the combined embedding."""
    embedding_base64s_state[idx] = embedding_base64 if embedding_base64 != '' else None
    embeddings_b64 = update_average_embeddings(embedding_base64s_state, embedding_power_state)
    return gr.Text.update(embeddings_b64)


def on_embeddings_changed_update_plot(embeddings_b64):
    """Redraw a line plot of embedding value against index; an empty/None input clears the plot."""
    if not embeddings_b64:
        embeddings = []
        width = 0
    else:
        embeddings = base64_to_embedding(embeddings_b64)
        width = embeddings.shape[0]
    data = pd.DataFrame({
        'embedding': embeddings,
        'index': list(range(len(embeddings)))})
    return gr.LinePlot.update(
        data,
        x="index",
        y="embedding",
        title="Embeddings",
        tooltip=['index', 'embedding'],
        width=width)


def on_example_image_click_set_image(input_image, image_url):
    """Set ``input_image``'s value to ``image_url`` (currently not wired to any event)."""
    input_image.value = image_url


# device = torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model, preprocess = load_clip(clip_model, use_jit=True, device=device)
tokenizer = get_tokenizer(clip_model)

clip_retrieval_client = ClipClient(
    url=clip_retrieval_service_url,
    indice_name=clip_model_id,
    use_safety_model=False,
    use_violence_detector=False,
)

# Header illustration rows: [input 1, input 2, result] image filenames inside image_folder.
examples = [
    ["SohoJoeEth.jpeg", "Ray-Liotta-Goodfellas.jpg", "SohoJoeEth + Ray.jpeg"],
    # ["SohoJoeEth.jpeg", "Donkey.jpg", "SohoJoeEth + Donkey.jpeg"],
    # ["SohoJoeEth.jpeg", "Snoop Dogg.jpg", "SohoJoeEth + Snoop Dogg.jpeg"],
]
tile_size = 100
image_folder = "images"

# Per-input example sets. Values ending in .jpg/.jpeg are image files inside image_folder;
# every other value is a text prompt.
tabbed_examples = {
    "CoCo": {
        "452650": "452650.jpeg",
        "Prompt 1": "a college dorm with a desk and bunk beds",
        "371739": "371739.jpeg",
        "Prompt 2": "a large banana is placed before a stuffed monkey.",
        "557922": "557922.jpeg",
        "Prompt 3": "a person sitting on a bench using a cell phone",
        "540554": "540554.jpeg",
        "Prompt 4": "two trains are coming down the tracks, a steam engine and a modern train.",
    },
    "Transforms": {
        "ColorWheel001": "ColorWheel001.jpg",
        "ColorWheel001 BW": "ColorWheel001 BW.jpg",
        "ColorWheel002": "ColorWheel002.jpg",
        "ColorWheel002 BW": "ColorWheel002 BW.jpg",
    },
    "Portraits": {
        "Snoop": "Snoop Dogg.jpg",
        "Snoop Prompt": "Snoop Dogg",
        "Ray": "Ray-Liotta-Goodfellas.jpg",
        "Ray Prompt": "Ray Liotta, Goodfellas",
        "Anya": "Anya Taylor-Joy 003.jpg",
        "Anya Prompt": "Anya Taylor-Joy, The Queen's Gambit",
        "Billie": "billie eilish 004.jpeg",
        "Billie Prompt": "Billie Eilish, blonde hair",
        "Lizzo": "Lizzo 001.jpeg",
        "Lizzo Prompt": "Lizzo,",
        "Donkey": "Donkey.jpg",
        "Donkey Prompt": "Donkey, from Shrek",
    },
    "NFT's": {
        "SohoJoe": "SohoJoeEth.jpeg",
        "SohoJoe Prompt": "SohoJoe.Eth",
        "Mirai": "Mirai.jpg",
        "Mirai Prompt": "Mirai from White Rabbit, @shibuyaxyz",
        "OnChainMonkey": "OnChainMonkey-2278.jpg",
        "OCM Prompt": "On Chain Monkey",
        "Wassie": "Wassie 4498.jpeg",
        "Wassie Prompt": "Wassie by Wassies",
    },
    "Pups": {
        "Pup1": "pup1.jpg",
        "Prompt": "Teacup Yorkies",
        "Pup2": "pup2.jpg",
        "Pup3": "pup3.jpg",
        "Pup4": "pup4.jpeg",
        "Pup5": "pup5.jpg",
    },
}
image_examples_tile_size = 50

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown(
                """
# Soho-Clip Embeddings Explorer

A tool for exploring CLIP embedding space.

Try uploading a few images and/or adding some text prompts, then click "Search embedding space".
""")
        with gr.Column(scale=2, min_width=(tile_size + 20) * 3):
            with gr.Row():
                with gr.Column(scale=1, min_width=tile_size):
                    gr.Markdown("## Input 1")
                with gr.Column(scale=1, min_width=tile_size):
                    gr.Markdown("## Input 2")
                with gr.Column(scale=1, min_width=tile_size):
                    gr.Markdown("## Generates:")
            for example_row in examples:
                with gr.Row():
                    for example_file in example_row:
                        with gr.Column(scale=1, min_width=tile_size):
                            local_path = os.path.join(image_folder, example_file)
                            gr.Image(
                                value=local_path,
                                shape=(tile_size, tile_size),
                                show_label=False,
                                interactive=False,
                            ).style(height=tile_size, width=tile_size)

    with gr.Row():
        for i in range(max_tabs):
            with gr.Tab(f"Input {i+1}"):
                with gr.Row():
                    with gr.Column(scale=1, min_width=240):
                        input_images[i] = gr.Image(label="Image Prompt", show_label=True)
                    with gr.Column(scale=3, min_width=600):
                        embedding_plots[i] = gr.LinePlot(show_label=False).style(container=False)
                with gr.Row():
                    with gr.Column(scale=2, min_width=240):
                        input_prompts[i] = gr.Textbox(label="Text Prompt", show_label=True)
                    with gr.Column(scale=3, min_width=600):
                        with gr.Row():
                            embedding_powers[i] = gr.Slider(minimum=-3, maximum=3, value=1, label="Power", show_label=True, interactive=True)
                        with gr.Row():
                            with gr.Accordion("Embeddings (base64)", open=False):
                                embedding_base64s[i] = gr.Textbox(show_label=False)
                # Example images/prompts that load into this tab's own inputs.
                for tab_title, tab_examples in tabbed_examples.items():
                    with gr.Tab(tab_title):
                        with gr.Row():
                            for title, example in tab_examples.items():
                                if example.endswith((".jpg", ".jpeg")):
                                    local_path = os.path.join(image_folder, example)
                                    with gr.Column(scale=1, min_width=image_examples_tile_size):
                                        gr.Examples(
                                            examples=[local_path],
                                            inputs=input_images[i],
                                            label=title,
                                        )
                                else:
                                    with gr.Column(scale=1, min_width=image_examples_tile_size * 2):
                                        gr.Examples(
                                            examples=[example],
                                            inputs=input_prompts[i],
                                            label=title,
                                        )

    with gr.Row():
        average_embedding_plot = gr.LinePlot(show_label=True, label="Average Embeddings (base64)").style(container=False)
    with gr.Row():
        with gr.Accordion("Average embeddings in base 64", open=False):
            average_embedding_base64 = gr.Textbox(show_label=False)
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            n_samples = gr.Slider(1, 16, value=4, step=1, label="Number images")
        with gr.Column(scale=3, min_width=200):
            submit = gr.Button("Search embedding space")
    with gr.Row():
        output = gr.Gallery(label="Closest images in Laion 5b using kNN", show_label=True)

    # Per-session copies of every tab's embedding and power, used to build the combined embedding.
    embedding_base64s_state = gr.State(value=[None] * max_tabs)
    embedding_power_state = gr.State(value=[1.] * max_tabs)

    for i in range(max_tabs):
        input_images[i].change(on_image_load_update_embeddings, input_images[i], [embedding_base64s[i]])
        input_prompts[i].change(on_prompt_change_update_embeddings, input_prompts[i], [embedding_base64s[i]])
        embedding_base64s[i].change(on_embeddings_changed_update_plot, embedding_base64s[i], [embedding_plots[i]])
        # Handlers receive the tab index through a State input.
        idx_state = gr.State(value=i)
        embedding_base64s[i].change(
            on_embeddings_changed_update_average_embeddings,
            [embedding_base64s_state, embedding_power_state, embedding_base64s[i], idx_state],
            average_embedding_base64)
        embedding_powers[i].change(
            on_power_change_update_average_embeddings,
            [embedding_base64s_state, embedding_power_state, embedding_powers[i], idx_state],
            average_embedding_base64)

    average_embedding_base64.change(on_embeddings_changed_update_plot, average_embedding_base64, average_embedding_plot)

    submit.click(main, inputs=[average_embedding_base64, n_samples], outputs=output)
    output.style(grid=[4], height="auto")

    with gr.Row():
        gr.Markdown(
            """
My interest is to use CLIP for image/video understanding (see [CLIP_visual-spatial-reasoning](https://github.com/Sohojoe/CLIP_visual-spatial-reasoning)).

### Initial Features

- Combine up to 10 images and/or text inputs to create an average embedding.
- Search the LAION-5B images via a kNN search.

### Known limitations

- ...

### Acknowledgements

- I heavily build on [clip-retrieval](https://rom1504.github.io/clip-retrieval/) and use their API. Please [cite](https://github.com/rom1504/clip-retrieval#citation) the authors if you use this work.
- [CLIP](https://openai.com/blog/clip/)
- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
""")


if __name__ == "__main__":
    demo.launch()