|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
from torchvision import transforms |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import math |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
import os |
|
|
|
from clip_retrieval.clip_client import ClipClient, Modality |
|
|
|
|
|
|
|
|
|
# Remote kNN search service and the CLIP model / index pair used with it.
clip_retrieval_service_url = "https://knn.laion.ai/knn-service"
clip_model = "ViT-L/14"
clip_model_id = "laion5B-L-14"

# Number of input tabs, and one slot per tab for each kind of UI component.
# The slots hold placeholders here and are overwritten with Gradio components
# when the interface is built.
max_tabs = 10
input_images = [None] * max_tabs
input_prompts = [None] * max_tabs
embedding_plots = [None] * max_tabs
embedding_powers = [1.] * max_tabs
embedding_base64s = [None] * max_tabs
|
|
|
|
|
|
|
def image_to_embedding(input_im):
    """Encode an image array as a unit-length CLIP image embedding.

    ``input_im`` is converted with ``Image.fromarray``, passed through the
    module-level ``preprocess`` transform and encoded by the module-level
    ``model`` on ``device``. The result is divided by its norm along the last
    axis and returned as a float32 NumPy array on the CPU.
    """
    pil_image = Image.fromarray(input_im)
    batch = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
        features /= features.norm(dim=-1, keepdim=True)
    return features.cpu().to(torch.float32).detach().numpy()
|
|
|
def prompt_to_embedding(prompt):
    """Encode a text prompt as a unit-length CLIP text embedding.

    The prompt is tokenised with the module-level ``tokenizer``, encoded by
    the module-level ``model`` on ``device``, divided by its norm along the
    last axis and returned as a float32 NumPy array on the CPU.
    """
    tokens = tokenizer([prompt]).to(device)
    with torch.no_grad():
        features = model.encode_text(tokens)
        features /= features.norm(dim=-1, keepdim=True)
    return features.cpu().to(torch.float32).detach().numpy()
|
|
|
def embedding_to_image(embeddings):
    """Render an embedding as a square 8-bit greyscale image.

    The values are flattened (so both ``(N,)`` and ``(1, N)`` inputs work),
    min-max scaled to the 0..255 range, zero-padded up to the next perfect
    square and reshaped to ``side x side``.

    Args:
        embeddings: array-like of numeric values.

    Returns:
        A PIL image in mode ``"L"``; at least 1x1 even for empty input.
    """
    values = np.asarray(embeddings, dtype=np.float32).ravel()
    side = max(1, math.ceil(math.sqrt(values.size)))

    # Pixels past the end of the embedding stay 0 (the padding).
    pixels = np.zeros(side * side, dtype=np.uint8)
    if values.size:
        low = values.min()
        span = values.max() - low
        if span > 0:
            scaled = (values - low) / span * 255.0
            pixels[:values.size] = np.rint(scaled).astype(np.uint8)

    # Convert to uint8 *before* building the image: ``Image.fromarray`` does
    # not convert data to the requested mode, it only reinterprets the buffer,
    # so passing float data with mode "L" yields garbage pixels.
    return Image.fromarray(pixels.reshape(side, side))
|
|
|
def embedding_to_base64(embeddings):
    """Serialise an embedding to a URL-safe base64 string of float32 bytes.

    Args:
        embeddings: NumPy array or array-like of numbers (any dtype/layout).

    Returns:
        ``str`` holding the base64 encoding of the values as C-ordered
        float32 bytes.
    """
    import base64

    # Encode explicit bytes rather than the array object itself: the buffer
    # protocol path rejects non-C-contiguous arrays and plain sequences.
    raw = np.ascontiguousarray(embeddings, dtype=np.float32).tobytes()
    return base64.urlsafe_b64encode(raw).decode()
|
|
|
def base64_to_embedding(embeddings_b64):
    """Decode a URL-safe base64 string into a flat float32 NumPy array.

    The decoded bytes are viewed with ``np.frombuffer``, so the returned
    array is one-dimensional.
    """
    from base64 import urlsafe_b64decode

    raw = urlsafe_b64decode(embeddings_b64)
    return np.frombuffer(raw, dtype=np.float32)
|
|
|
def safe_url(url):
    """Percent-encode a result URL and cut it off after the first ``.jpg``.

    Characters other than ``:``, ``/`` and ``%`` that need escaping are
    percent-encoded. ``%`` is left alone so that URLs which are already
    percent-encoded are not encoded a second time (``%20`` -> ``%2520``).
    When the encoded URL contains ``.jpg``, everything after its first
    occurrence is dropped.

    Args:
        url: the raw URL string.

    Returns:
        The encoded (and possibly truncated) URL string.
    """
    import urllib.parse

    url = urllib.parse.quote(url, safe=':/%')

    head, marker, _ = url.partition('.jpg')
    if marker:
        url = head + marker
    return url
|
|
|
def main(
    embeddings,
    n_samples=4,
    timeout=10,
):
    """Query the kNN service with an embedding and download the best matches.

    Args:
        embeddings: base64 string as produced by ``embedding_to_base64``.
            ``None`` or an empty string means "nothing to search for".
        n_samples: maximum number of images to return.
        timeout: seconds to wait for each image download before giving up.

    Returns:
        A list of ``(PIL image, title)`` tuples, where the title is the
        similarity rounded to four decimals followed by the caption. Results
        that cannot be downloaded or decoded are skipped.
    """
    # Nothing was embedded yet: do not send an empty vector to the service.
    if not embeddings:
        return []

    import requests
    from io import BytesIO

    query_vector = base64_to_embedding(embeddings).tolist()
    results = clip_retrieval_client.query(embedding_input=query_vector)

    images = []
    for result in results:
        if len(images) >= n_samples:
            break
        try:
            raw_url = result.get("url")
            if not raw_url:
                continue
            url = safe_url(raw_url)
            similarity = float("{:.4f}".format(result["similarity"]))
            title = str(similarity) + ' ' + (result.get("caption") or '')

            # A timeout keeps one unresponsive host from stalling the search.
            response = requests.get(url, timeout=timeout)
            if not response.ok:
                continue
            image = Image.open(BytesIO(response.content))
            if image.mode != 'RGB':
                image = image.convert('RGB')
            images.append((image, title))
        except Exception as e:
            # Best effort: one bad result must not abort the whole search.
            print(e)
    return images
|
|
|
def on_image_load_update_embeddings(image_data):
    """Return a textbox update holding the base64 embedding of an image.

    When ``image_data`` is ``None`` the textbox is set to an empty string;
    otherwise the image is embedded with ``image_to_embedding`` and encoded
    with ``embedding_to_base64``.
    """
    if image_data is not None:
        encoded = embedding_to_base64(image_to_embedding(image_data))
    else:
        encoded = ''
    return gr.Text.update(encoded)
|
|
|
def on_prompt_change_update_embeddings(prompt):
    """Return a textbox update holding the base64 embedding of a prompt.

    A ``None`` or empty prompt is embedded as the empty string, so the
    textbox always receives an encoded embedding.
    """
    # Both cases follow the same path; only the text that gets embedded differs.
    text = '' if prompt is None or prompt == "" else prompt
    return gr.Text.update(embedding_to_base64(prompt_to_embedding(text)))
|
|
|
def update_average_embeddings(embedding_base64s_state, embedding_powers):
    """Combine the stored embeddings into one normalised base64 embedding.

    Each non-empty entry of ``embedding_base64s_state`` is decoded, multiplied
    by the power at the same position in ``embedding_powers`` and summed; the
    sum is divided by its norm and re-encoded.

    Args:
        embedding_base64s_state: sequence of base64 strings (``None`` or ``""``
            for unused slots).
        embedding_powers: sequence of multipliers, parallel to the first.

    Returns:
        The base64 string of the combined embedding, or ``''`` when there is
        nothing to combine or the weighted sum has no direction (zero or
        non-finite norm). Always a ``str``, so callers can wrap it in a
        component update themselves.
    """
    total = None
    for encoded, power in zip(embedding_base64s_state, embedding_powers):
        if encoded is None or encoded == "":
            continue
        weighted = base64_to_embedding(encoded) * power
        total = weighted if total is None else total + weighted

    if total is None:
        return ''

    norm = np.linalg.norm(total)
    # E.g. a single input with power 0: dividing would fill the result with NaN.
    if norm == 0 or not np.isfinite(norm):
        return ''

    return embedding_to_base64(total / norm)
|
|
|
def on_power_change_update_average_embeddings(embedding_base64s_state, embedding_power_state, power, idx):
    """Store a slider's new power and return the refreshed average embedding.

    ``embedding_power_state[idx]`` is overwritten in place with ``power``; the
    result of ``update_average_embeddings`` is wrapped in a textbox update.
    """
    embedding_power_state[idx] = power
    return gr.Text.update(
        update_average_embeddings(embedding_base64s_state, embedding_power_state)
    )
|
|
|
def on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_power_state, embedding_base64, idx):
    """Store a tab's new embedding and return the refreshed average embedding.

    ``embedding_base64s_state[idx]`` is overwritten in place; an empty string
    is stored as ``None``. The result of ``update_average_embeddings`` is
    wrapped in a textbox update.
    """
    stored = None if embedding_base64 == '' else embedding_base64
    embedding_base64s_state[idx] = stored
    return gr.Text.update(
        update_average_embeddings(embedding_base64s_state, embedding_power_state)
    )
|
|
|
def on_embeddings_changed_update_plot(embeddings_b64):
    """Return a line-plot update showing the values of a base64 embedding.

    A ``None`` or empty string produces an empty data frame and a plot width
    of 0. Otherwise the decoded values are plotted against their position and
    the plot width is set to the number of values.
    """
    if embeddings_b64 is None or embeddings_b64 == "":
        values, plot_width = [], 0
    else:
        values = base64_to_embedding(embeddings_b64)
        plot_width = values.shape[0]

    frame = pd.DataFrame({
        'embedding': values,
        'index': list(range(len(values))),
    })
    return gr.LinePlot.update(
        frame,
        x="index",
        y="embedding",
        title="Embeddings",
        tooltip=['index', 'embedding'],
        width=plot_width,
    )
|
|
|
def on_example_image_click_set_image(input_image, image_url):
    """Assign ``image_url`` to the ``value`` attribute of ``input_image``."""
    setattr(input_image, "value", image_url)
|
|
|
|
|
# Use the first CUDA device when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

from clip_retrieval.load_clip import load_clip, get_tokenizer

# Local CLIP model, its image preprocessing transform and its tokenizer.
model, preprocess = load_clip(clip_model, use_jit=True, device=device)
tokenizer = get_tokenizer(clip_model)

# Client for the remote kNN service; both optional filters are switched off.
clip_retrieval_client = ClipClient(
    url=clip_retrieval_service_url,
    indice_name=clip_model_id,
    use_safety_model=False,
    use_violence_detector=False,
)
|
|
|
|
|
|
|
# Rows of image file names shown in the page header; each row is rendered as
# one strip of tiles under the "Input 1" / "Input 2" / "Generates:" headings.
examples = [
    [
        "SohoJoeEth.jpeg",
        "Ray-Liotta-Goodfellas.jpg",
        "SohoJoeEth + Ray.jpeg",
    ],
]

# Edge length used for the header example tiles.
tile_size = 100

# Directory that the example file names are joined onto.
image_folder = "images"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Example inputs offered inside every input tab, grouped by sub-tab title.
# Within a group, a value ending in ".jpg" / ".jpeg" is treated as an image
# file name; any other value is treated as a text prompt. Insertion order is
# the display order.
tabbed_examples = {
    "CoCo": {
        "452650": "452650.jpeg",
        "Prompt 1": "a college dorm with a desk and bunk beds",
        "371739": "371739.jpeg",
        "Prompt 2": "a large banana is placed before a stuffed monkey.",
        "557922": "557922.jpeg",
        "Prompt 3": "a person sitting on a bench using a cell phone",
        "540554": "540554.jpeg",
        "Prompt 4": "two trains are coming down the tracks, a steam engine and a modern train.",
    },
    "Transforms": {
        "ColorWheel001": "ColorWheel001.jpg",
        "ColorWheel001 BW": "ColorWheel001 BW.jpg",
        "ColorWheel002": "ColorWheel002.jpg",
        "ColorWheel002 BW": "ColorWheel002 BW.jpg",
    },
    "Portraits": {
        "Snoop": "Snoop Dogg.jpg",
        "Snoop Prompt": "Snoop Dogg",
        "Ray": "Ray-Liotta-Goodfellas.jpg",
        "Ray Prompt": "Ray Liotta, Goodfellas",
        "Anya": "Anya Taylor-Joy 003.jpg",
        "Anya Prompt": "Anya Taylor-Joy, The Queen's Gambit",
        "Billie": "billie eilish 004.jpeg",
        "Billie Prompt": "Billie Eilish, blonde hair",
        "Lizzo": "Lizzo 001.jpeg",
        "Lizzo Prompt": "Lizzo,",
        "Donkey": "Donkey.jpg",
        "Donkey Prompt": "Donkey, from Shrek",
    },
    "NFT's": {
        "SohoJoe": "SohoJoeEth.jpeg",
        "SohoJoe Prompt": "SohoJoe.Eth",
        "Mirai": "Mirai.jpg",
        "Mirai Prompt": "Mirai from White Rabbit, @shibuyaxyz",
        "OnChainMonkey": "OnChainMonkey-2278.jpg",
        "OCM Prompt": "On Chain Monkey",
        "Wassie": "Wassie 4498.jpeg",
        "Wassie Prompt": "Wassie by Wassies",
    },
    "Pups": {
        "Pup1": "pup1.jpg",
        "Prompt": "Teacup Yorkies",
        "Pup2": "pup2.jpg",
        "Pup3": "pup3.jpg",
        "Pup4": "pup4.jpeg",
        "Pup5": "pup5.jpg",
    },
}

# Minimum column width for one example inside the input tabs.
image_examples_tile_size = 50
|
|
|
# NOTE(review): the nesting below was reconstructed from the statement order
# (the source had lost its indentation) - confirm against the running layout.
with gr.Blocks() as demo:
    # Header: title and intro on the left, a strip of example tiles on the right.
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown(
"""
# Soho-Clip Embeddings Explorer

A tool for exploring CLIP embedding space.

Try uploading a few images and/or add some text prompts and click generate images.
""")
        with gr.Column(scale=2, min_width=(tile_size+20)*3):
            with gr.Row():
                with gr.Column(scale=1, min_width=tile_size):
                    gr.Markdown("## Input 1")
                with gr.Column(scale=1, min_width=tile_size):
                    gr.Markdown("## Input 2")
                with gr.Column(scale=1, min_width=tile_size):
                    gr.Markdown("## Generates:")
            # Distinct names for the row and its files, so the inner loop does
            # not rebind the variable the outer loop is iterating with.
            for example_row in examples:
                with gr.Row():
                    for example_file in example_row:
                        with gr.Column(scale=1, min_width=tile_size):
                            local_path = os.path.join(image_folder, example_file)
                            gr.Image(
                                value=local_path, shape=(tile_size, tile_size),
                                show_label=False, interactive=False,
                            ).style(height=tile_size, width=tile_size)

    # One tab per input slot: image, embedding plot, prompt, power, raw base64.
    with gr.Row():
        for i in range(max_tabs):
            with gr.Tab(f"Input {i+1}"):
                with gr.Row():
                    with gr.Column(scale=1, min_width=240):
                        input_images[i] = gr.Image(label="Image Prompt", show_label=True)
                    with gr.Column(scale=3, min_width=600):
                        embedding_plots[i] = gr.LinePlot(show_label=False).style(container=False)

                with gr.Row():
                    with gr.Column(scale=2, min_width=240):
                        input_prompts[i] = gr.Textbox(label="Text Prompt", show_label=True)
                    with gr.Column(scale=3, min_width=600):
                        with gr.Row():
                            embedding_powers[i] = gr.Slider(minimum=-3, maximum=3, value=1, label="Power", show_label=True, interactive=True)
                        with gr.Row():
                            with gr.Accordion("Embeddings (base64)", open=False):
                                embedding_base64s[i] = gr.Textbox(show_label=False)

                # Example pickers wired to this tab's image / prompt inputs.
                # `tab_examples` avoids overwriting the module-level `examples`.
                for tab_title, tab_examples in tabbed_examples.items():
                    with gr.Tab(tab_title):
                        with gr.Row():
                            for title, example in tab_examples.items():
                                if example.endswith((".jpg", ".jpeg")):
                                    # Image example: resolve it inside the image folder.
                                    local_path = os.path.join(image_folder, example)
                                    with gr.Column(scale=1, min_width=image_examples_tile_size):
                                        gr.Examples(
                                            examples=[local_path],
                                            inputs=input_images[i],
                                            label=title,
                                        )
                                else:
                                    # Text example: used as a prompt as-is.
                                    with gr.Column(scale=1, min_width=image_examples_tile_size*2):
                                        gr.Examples(
                                            examples=[example],
                                            inputs=input_prompts[i],
                                            label=title,
                                        )

    # Combined embedding: plot, raw base64, and the search controls.
    with gr.Row():
        average_embedding_plot = gr.LinePlot(show_label=True, label="Average Embeddings (base64)").style(container=False)
    with gr.Row():
        with gr.Accordion("Average embeddings in base 64", open=False):
            average_embedding_base64 = gr.Textbox(show_label=False)
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            n_samples = gr.Slider(1, 16, value=4, step=1, label="Number images")
        with gr.Column(scale=3, min_width=200):
            submit = gr.Button("Search embedding space")
    with gr.Row():
        output = gr.Gallery(label="Closest images in Laion 5b using kNN", show_label=True)

    # Per-session copies of every tab's embedding and power.
    embedding_base64s_state = gr.State(value=[None for i in range(max_tabs)])
    embedding_power_state = gr.State(value=[1. for i in range(max_tabs)])
    for i in range(max_tabs):
        # Image or prompt change -> this tab's base64 box -> this tab's plot.
        input_images[i].change(on_image_load_update_embeddings, input_images[i], [embedding_base64s[i]])
        input_prompts[i].change(on_prompt_change_update_embeddings, input_prompts[i], [embedding_base64s[i]])
        embedding_base64s[i].change(on_embeddings_changed_update_plot, embedding_base64s[i], [embedding_plots[i]])
        # The tab index is passed to the handlers through a State component.
        idx_state = gr.State(value=i)
        embedding_base64s[i].change(on_embeddings_changed_update_average_embeddings, [embedding_base64s_state, embedding_power_state, embedding_base64s[i], idx_state], average_embedding_base64)
        embedding_powers[i].change(on_power_change_update_average_embeddings, [embedding_base64s_state, embedding_power_state, embedding_powers[i], idx_state], average_embedding_base64)

    average_embedding_base64.change(on_embeddings_changed_update_plot, average_embedding_base64, average_embedding_plot)

    submit.click(main, inputs=[average_embedding_base64, n_samples], outputs=output)
    output.style(grid=[4], height="auto")

    with gr.Row():
        gr.Markdown(
"""
My interest is to use CLIP for image/video understanding (see [CLIP_visual-spatial-reasoning](https://github.com/Sohojoe/CLIP_visual-spatial-reasoning).)


### Initial Features

- Combine up to 10 Images and/or text inputs to create an average embedding space.
- Search the laion 5b images via a kNN search

### Known limitations

- ...

### Acknowledgements

- I heavily build on [clip-retrieval](https://rom1504.github.io/clip-retrieval/) and use their API. Please [cite](https://github.com/rom1504/clip-retrieval#citation) the authors if you use this work.
- [CLIP](https://openai.com/blog/clip/)
- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)

""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()