sohojoe's picture
Update app.py
fca5b2f
raw
history blame
30.8 kB
import gradio as gr
import torch
from PIL import Image
# from torchvision import transforms
# from diffusers import StableDiffusionPipeline, StableDiffusionImageVariationPipeline, DiffusionPipeline
import numpy as np
import pandas as pd
import math
# from transformers import CLIPTextModel, CLIPTokenizer
import os
# --- Retrieval backend ------------------------------------------------------
# Earlier alternatives, kept for reference:
#   clip_model_id = "openai/clip-vit-large-patch14-336"
#   clip_retrieval_indice_name, clip_model_id = "laion5B-L-14", "/laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
clip_retrieval_service_url = "https://knn.laion.ai/knn-service"

# CLIP model name. Available models:
#   'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
#   'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'
# (previously "ViT-B/32")
clip_model = "ViT-L/14"
# Index name for the retrieval client (see the commented-out ClipClient setup).
clip_model_id = "laion5B-L-14"

# --- Per-input-tab slots ----------------------------------------------------
# One entry per "Input N" tab. The UI construction code overwrites these
# placeholders with the Gradio components it creates for each tab.
max_tabs = 10
input_images = [None] * max_tabs
input_prompts = [None] * max_tabs
embedding_plots = [None] * max_tabs
embedding_powers = [1.0] * max_tabs
# (an earlier version held this in gr.State(value=[None] * max_tabs))
embedding_base64s = [None] * max_tabs

# Set to True to enable debug_print() tracing.
debug_print_on = False
def debug_print(*args, **kwargs):
    """Forward all arguments to print(), but only while debug_print_on is set."""
    if not debug_print_on:
        return
    print(*args, **kwargs)
def image_to_embedding(input_im):
    """Encode an image with the CLIP image encoder and L2-normalize it.

    Args:
        input_im: array accepted by ``PIL.Image.fromarray`` (presumably the
            uint8 image array a ``gr.Image`` component delivers — confirm).

    Returns:
        float32 numpy array for a batch of one image (presumably shape (1, D)),
        unit length along the last axis.

    Uses the module-level ``model``, ``preprocess`` and ``device``.
    """
    pil_image = Image.fromarray(input_im)
    # preprocess() yields a single image tensor; add the batch dimension.
    batch = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
        features /= features.norm(dim=-1, keepdim=True)
        result = features.cpu().to(torch.float32).detach().numpy()
    return result
def prompt_to_embedding(prompt):
    """Encode a text prompt with the CLIP text encoder and L2-normalize it.

    Args:
        prompt: the text to encode.

    Returns:
        float32 numpy array for a batch of one prompt (presumably shape (1, D)),
        unit length along the last axis.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.
    """
    tokens = tokenizer([prompt]).to(device)  # batch of one prompt
    with torch.no_grad():
        features = model.encode_text(tokens)
        features /= features.norm(dim=-1, keepdim=True)
        result = features.cpu().to(torch.float32).detach().numpy()
    return result
def embedding_to_image(embeddings):
    """Render an embedding as a square greyscale PIL image.

    The values are flattened, min-max scaled to 0..255 and laid out row-major in
    the smallest square that holds them; leftover pixels are black (0). A
    constant vector also renders black.

    Args:
        embeddings: array-like of numbers of any shape (e.g. the (1, D) output of
            image_to_embedding or the 1-D output of base64_to_embedding).

    Returns:
        A mode "L" ``PIL.Image`` of size ceil(sqrt(n)) x ceil(sqrt(n)).
    """
    values = np.asarray(embeddings, dtype=np.float32).ravel()
    count = values.shape[0]
    size = math.ceil(math.sqrt(count))
    scaled = np.zeros(count, dtype=np.float32)
    if count:
        low, high = float(values.min()), float(values.max())
        if high > low:
            # Components are signed floats; stretch them to the 8-bit range
            # instead of handing raw float bytes to an 8-bit image mode.
            scaled = (values - low) / (high - low) * 255.0
    pixels = np.zeros(size * size, dtype=np.uint8)
    pixels[:count] = np.rint(scaled).astype(np.uint8)
    # A 2-D uint8 array maps to mode "L" automatically (the mode= argument is
    # deprecated in newer Pillow releases).
    return Image.fromarray(pixels.reshape(size, size))
def embedding_to_base64(embeddings):
    """Serialize an embedding as URL-safe base64 of its float32 bytes.

    Args:
        embeddings: numpy array or any array-like of numbers; it is converted to
            float32 and flattened in C (row-major) order.

    Returns:
        ASCII str, the inverse of base64_to_embedding (native byte order).
    """
    import base64
    # tobytes() always emits C-order bytes, so Fortran-ordered or otherwise
    # non-contiguous arrays (which the base64 encoder rejects) also work.
    raw = np.asarray(embeddings, dtype=np.float32).tobytes()
    return base64.urlsafe_b64encode(raw).decode()
def base64_to_embedding(embeddings_b64):
    """Decode URL-safe base64 text back into a 1-D float32 numpy array.

    The returned array is a read-only view over the decoded bytes; it is the
    inverse of embedding_to_base64.
    """
    import base64
    raw_bytes = base64.urlsafe_b64decode(embeddings_b64)
    return np.frombuffer(raw_bytes, dtype=np.float32)
def is_prompt_embeddings(prompt):
    """Return True if *prompt* looks like a pasted base64 float32 embedding.

    The text (after trimming surrounding whitespace) must consist solely of the
    URL-safe base64 alphabet with valid padding, and decode to a non-empty
    multiple of 4 bytes (whole float32 values). Strict validation matters: a
    lenient decoder silently drops spaces and punctuation, so ordinary prompts
    such as "a b c d e f g h i j k l m n o p" would otherwise be mistaken for
    embeddings.

    Args:
        prompt: text box contents; None or "" yield False.

    Returns:
        bool.
    """
    import base64
    if not prompt:
        return False
    candidate = prompt.strip()
    if not candidate:
        return False
    try:
        # altchars maps the URL-safe '-'/'_' to '+'/'/'; validate=True rejects any
        # other non-alphabet character instead of discarding it.
        raw = base64.b64decode(candidate, altchars=b"-_", validate=True)
    except ValueError:  # binascii.Error (bad padding/chars) subclasses ValueError
        return False
    return len(raw) > 0 and len(raw) % 4 == 0
def safe_url(url):
    """Percent-encode unsafe characters in *url* and trim it after the first ``.jpg``.

    Reserved URL characters (``? = & # ...``) and existing ``%`` escapes are left
    intact, so query strings keep working and ``%20`` is not double-encoded;
    characters such as spaces are escaped. If the URL contains a ``.jpg``
    extension, everything after the first one is dropped (handles results like
    ``a.jpg/b.jpg`` or ``a.jpg?w=300``). ``.jpeg`` is not treated as ``.jpg``.

    Args:
        url: URL string from a retrieval result.

    Returns:
        The cleaned URL string.
    """
    import re
    import urllib.parse
    url = urllib.parse.quote(url, safe="!#$%&'()*+,/:;=?@[]~")
    # Only a whole '.jpg' extension counts: not followed by a letter or digit.
    match = re.search(r"\.jpg(?![0-9A-Za-z])", url)
    if match:
        url = url[:match.end()]
    return url
def _download_rgb_image(url, timeout=10):
    """Fetch *url* and return it as a decoded RGB PIL image, or None on failure.

    Best effort: HTTP errors return None; network/decode exceptions are printed
    and return None. *timeout* (seconds) stops a dead host from hanging the UI.
    """
    import requests
    from io import BytesIO
    try:
        response = requests.get(url, timeout=timeout)
        if not response.ok:
            return None
        image = Image.open(BytesIO(response.content))
        if image.mode != 'RGB':
            image = image.convert('RGB')
        else:
            # Image.open is lazy; decode now so a corrupt file is skipped here
            # rather than failing later when the gallery serializes it.
            image.load()
        return image
    except Exception as e:
        print(e)
        return None


def main(
    # input_im,
    embeddings,
    n_samples=4,
):
    """Find the images in the LAION index nearest to a base64 embedding.

    Args:
        embeddings: base64 float32 embedding (as produced by embedding_to_base64).
        n_samples: maximum number of images to return.

    Returns:
        List of (RGB PIL image, title) pairs for the gallery, where the title is
        the similarity followed by the caption. Results whose download or decode
        fails are skipped.

    While ``clip_retrieval_client`` is None (retrieval not configured), the
    fixed ``test_images_urls`` are returned instead, each titled "title".
    """
    debug_print("main")
    if clip_retrieval_client is None:
        images = []
        for url in test_images_urls:
            image = _download_rgb_image(url)
            if image is not None:
                images.append((image, "title"))
        return images
    if not embeddings:
        return []
    # The client expects a plain python list, not a numpy array.
    query_embedding = base64_to_embedding(embeddings).tolist()
    results = clip_retrieval_client.query(embedding_input=query_embedding)
    images = []
    for result in results:
        if len(images) >= n_samples:
            break
        url = safe_url(result["url"])
        similarity = float("{:.4f}".format(result["similarity"]))
        title = str(similarity) + ' ' + (result.get("caption") or "")
        # Download here instead of handing the url to the gallery: if the url
        # returns an error there, the page crashes.
        image = _download_rgb_image(url)
        if image is not None:
            images.append((image, title))
    return images
def on_image_load_update_embeddings(image_data):
    """Return the base64 CLIP embedding of an uploaded image.

    A cleared image (None) returns '', meaning "no embedding for this tab".
    """
    debug_print("on_image_load_update_embeddings")
    if image_data is not None:
        return embedding_to_base64(image_to_embedding(image_data))
    return ''
def on_prompt_change_update_embeddings(prompt):
    """Return the base64 CLIP text embedding of *prompt*.

    A None/empty prompt is embedded as the empty string. The result is always a
    plain str because callers decode it with base64_to_embedding; returning a
    gr.Text.update(...) dict for the empty prompt made those callers fail.
    """
    debug_print("on_prompt_change_update_embeddings")
    text = prompt if prompt else ''
    embeddings = prompt_to_embedding(text)
    return embedding_to_base64(embeddings)
def update_average_embeddings(embedding_base64s_state, embedding_powers):
    """Combine the per-tab embeddings into one unit-length base64 embedding.

    Each non-empty entry of *embedding_base64s_state* is decoded and multiplied
    by the entry of *embedding_powers* at the same index (a negative power
    subtracts it); the weighted vectors are summed and L2-normalized.

    Args:
        embedding_base64s_state: list of base64 embeddings; None or "" marks an
            unused slot.
        embedding_powers: per-slot weights, indexed like the embeddings.

    Returns:
        The combined embedding as base64, or '' when no slot is set or the
        weighted sum is the zero vector (there is no direction to normalize).
    """
    debug_print("update_average_embeddings")
    final_embedding = None
    for i, embedding_base64 in enumerate(embedding_base64s_state):
        if not embedding_base64:
            continue
        weighted = base64_to_embedding(embedding_base64) * embedding_powers[i]
        final_embedding = weighted if final_embedding is None else final_embedding + weighted
    if final_embedding is None:
        return ''
    # Averaging instead of summing would only divide by a positive count, which
    # the normalization below cancels out, so the plain sum is enough.
    norm = np.linalg.norm(final_embedding)
    if norm == 0 or not np.isfinite(norm):
        # e.g. one input at power 0, or A and -A cancelling out: dividing would
        # produce NaNs that break the plot and the search.
        return ''
    final_embedding /= norm
    return embedding_to_base64(final_embedding)
def on_power_change_update_average_embeddings(embedding_base64s_state, embedding_power_state, power, idx):
    """Store *power* for tab *idx* and return the recombined base64 embedding.

    Mutates ``embedding_power_state`` in place.
    """
    debug_print("on_power_change_update_average_embeddings")
    embedding_power_state[idx] = power
    return update_average_embeddings(embedding_base64s_state, embedding_power_state)
def on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_power_state, embedding_base64, idx):
    """Store tab *idx*'s new embedding and return the recombined base64 embedding.

    An empty string is stored as None (slot unused). Mutates
    ``embedding_base64s_state`` in place.
    """
    debug_print("on_embeddings_changed_update_average_embeddings")
    embedding_base64s_state[idx] = None if embedding_base64 == '' else embedding_base64
    return update_average_embeddings(embedding_base64s_state, embedding_power_state)
def on_embeddings_changed_update_plot(embeddings_b64):
    """Build a gr.LinePlot update plotting each embedding component by index.

    None or "" produces an empty plot with width 0; otherwise the plot width is
    the number of components.
    """
    debug_print("on_embeddings_changed_update_plot")
    if embeddings_b64 is None or embeddings_b64 == "":
        values = []
        plot_width = 0
    else:
        values = base64_to_embedding(embeddings_b64)
        plot_width = values.shape[0]
    frame = pd.DataFrame({
        'embedding': values,
        'index': list(range(len(values))),
    })
    # Other LinePlot options tried earlier: color, stroke_dash, x_lim,
    # stroke_dash_legend_title, height.
    return gr.LinePlot.update(
        frame,
        x="index",
        y="embedding",
        title="Embeddings",
        tooltip=['index', 'embedding'],
        width=plot_width,
    )
def on_example_image_click_set_image(input_image, image_url):
    """Assign *image_url* to the ``value`` attribute of *input_image*.

    NOTE(review): not wired to any event in the code visible here, and setting
    ``.value`` on a component after the app is built may not refresh the UI —
    confirm before relying on it.
    """
    debug_print("on_example_image_click_set_image")
    setattr(input_image, "value", image_url)
# Use the first CUDA GPU when available, otherwise the CPU.
# (An MPS-aware variant was tried:
#  torch.device("mps" if torch.backends.mps.is_available() else "cuda:0" if torch.cuda.is_available() else "cpu"))
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Loading of the CLIP model and the LAION retrieval client is currently
# disabled; the original loading code, for restoring it:
#   from clip_retrieval.load_clip import load_clip, get_tokenizer
#   from clip_retrieval.clip_client import ClipClient, Modality
#   model, preprocess = load_clip(clip_model, use_jit=True, device=device)
#   tokenizer = get_tokenizer(clip_model)
#   clip_retrieval_client = ClipClient(
#       url=clip_retrieval_service_url,
#       indice_name=clip_model_id,
#       use_safety_model=False,
#       use_violence_detector=False,
#       # modality=Modality.TEXT,
#   )
# While these are None, image_to_embedding / prompt_to_embedding cannot run.
model = None
preprocess = None
tokenizer = None
clip_retrieval_client = None
# Header tile grid: each inner list is one row of image file names (relative
# to image_folder); "" leaves that cell without an image.
examples = [
    # ["SohoJoeEth.jpeg", "Ray-Liotta-Goodfellas.jpg", "SohoJoeEth + Ray.jpeg"],
    # ["SohoJoeEth.jpeg", "Donkey.jpg", "SohoJoeEth + Donkey.jpeg"],
    # ["SohoJoeEth.jpeg", "Snoop Dogg.jpg", "SohoJoeEth + Snoop Dogg.jpeg"],
    ["pup1.jpg", "", "Pup no teacup.jpg"],
]

# Folder holding the local example images (was os.path.join("file", "images")).
# An earlier flat `image_examples` dict now lives in tabbed_examples["CoCo"].
image_folder = "images"

# Sample image URLs that main() downloads as placeholder results.
test_images_urls = [
    "https://www.mdig.com.br/imagens/bichos/caes_miniatura_chicara_07.jpg",
    "https://i.pinimg.com/236x/04/ac/0e/04ac0e05964b75c9db59de94c571339e.jpg",
    "https://i.pinimg.com/236x/a1/55/9e/a1559e56ae5fb6e4d19c43c6396b5940--teacup-yorkie-yorkie-puppy.jpg",
    "https://i.pinimg.com/236x/11/c4/51/11c4518febb0869bd2b2d391d0753f44.jpg",
]
# Example inputs offered inside every "Input N" tab, grouped into sub-tabs.
#   outer key   -> sub-tab title
#   inner key   -> example label
#   inner value -> an image file name under image_folder (ends in .jpg/.jpeg),
#                  a text prompt, or (the "Embeddings" group) a base64 float32
#                  embedding that the prompt handler uses directly.
tabbed_examples = {
    "Pups": {
        "Pup1": "pup1.jpg",
        "Prompt": "Teacup Yorkies",
        "Pup2": "pup2.jpg",
        "Pup3": "pup3.jpg",
        "Pup4": "pup4.jpeg",
        "Pup5": "pup5.jpg",
    },
    "Embeddings": {
        # Long payloads are split into adjacent literals that Python joins.
        "Black & White": (
            "F0kxPHAqE7t3DoY79djWOwA6Cb2hjK88EkuIvXdEgTzS2yY93WXvOsKffL08qjU9oGVJvZtXD7wiQ-u7QTLhvGRqozpSFqo8fCMaOy42NDyyXCC9ls69Olk_A7zJ6Ik97AwLOyNjCryYr4W8kREmPfIOPb0xrde7137Fu3Jr5bwGKGU90T-lvI1pMT1ftz-9qy3vPMTnRDzx97C8fRWjPGbQU71d6f26ASZyPdg3Qrx-saS9FaAbu83DK732Ry-9WQ7HPPPiwTzY4gS97gc1PGXmRrzsZUS9kQwmPKDZvzw4F4a9zElPPQjmdj2Lqak9SHFXPJmPvDwRLPU7YvqHPV7OYDx7K-q8wfdIveWEXT1kYE289sfOOwH6YbuID129kMivu1uvCb35jGi9shisPNsXz7zb0xk99u_ivE68QL2ibjo8jCAmvPIz5Lyv9kU8rUIKPbCciD23RQG9P48tPfEpsLwbZkE8wtjHvHqvB72k_Au818IRO2pLHz2U2yq9M1_hvTG2FD05si275m4rvL85ojtDSCo9xWclvTOEgLrMt449xggSvFQzT72wtOw7mtgLvRe-N732MlW8bGIWvMi_m7p1XSA91Sq2PHDDxrxcMC29RPoJPVeIzTyCC448PVequz8rLz2Rsnk6yayDvVjAJjuJxqe9ivRpvVEKjTwwxoM9OKOtPUjlhT20uYO8ynh2PYLVZjwREFS8wWWIPLZGrz2oNxy7GR4gvaENaDzBxJu8ZdYGtw7B7rxG2sC7mP0QPYkKXbzbh4E7HrsqvJ64P73AL_Y8I6L-vB59jjyVE8O7jshZOw19JLwQD1A9NG3xu00JHr08iBC9WANYvBe-N73tpoS6dOZaPT1Btzy9kcI8MzNRvNH7hDuMnwK8o88MvZWU5rwDUhk9sBEPPGi_DL0BY8u8-cKPvGqIjbw90_c8y19CvVYT8bxLLxU920igvX_DgTyAzh-8q2lZOy-UmbxXnlU8or9cPJruuj3HJHy9a5-ZPUJ717w9frq88QG2vIwxQz2dIam8ET5Uu2ftYzsxRxI9nDwcPDQz5jxtrf079M5rPAVIrryU_U87fslzPSTkQrxpCK08VY1NvdsXTzuwsz48rvvavIsTY705S2G936llvTsk0bwQhEG6zoYovD0TXjoseXq9bt45PfQJqDzYp_I8z3WLvGTflLwWLui78YjDvAx4Trw2-yO9oXunvADWCj1wlQQ8NmF-PG1uszzCusW8jujMO3W637yShqy7O4loPHI4ybw4yr271thYPB9hAj0PwUQ8L0fRPLkzlD091oY88PvBOjv4wDyBMh69eUsePfOb6DyI4Zq8TxH-u0uCi702EZc8kqoQvQQkAbwz9AY841FFvebosTs2z708FpwnPUqOPb0zlQi9KNeWvBbeALwy70U91KD_PHPxSTwi8wO8aroMPYpCdb1QLDI9VgwAvL-JKz1I1Za8EshhO-EUbD2cpQW9UEeQvA8vBD0KvYq6ZJ6DvWZsVb1UdT07gHVPPUM-VLxbYZM8SoXXPFP6iDwC-AU9_5p4vIuPKD0d7x26nrBSvWaCxjn2aCY9iaLDvD8yYbzQPzo9BihlPaxUar2IeMa7zJniOsRbiD0zigS9P-TqvJMO7bsuR3u8KS_WvEL1XT3vWoG7OmGqvBsJdb3_vlE5js_2vP-wF70sbJq8XqWJPOOOM71wkJk9WJRqvA7XNzqVV-O8OjBuPW5_O71u-QK92baeOovezDw-NPS8G-dkvRx_UD1xzYe94dxoPM0_-TzqrSE8DUWhPLz6qzwr4TW9yOPIO5FmYzyM_oC9fj8gvWtdljuotQg8lwlDPPPRubw4Bxc8f6y5u_x2trxXH2Q81m2TPadef7yfNwc9AlIuvKNWXr2A2c28s6e1uspCsrzBHtk66N90PeUS-TqID109vydpu8vDVb1ZknC97JqRu18mLT1hIZg8CNmBt43v1DsezN"
            "w73krSPH_eo7yAPlA99-k0PNBamLw16jA90kSlu-pTIz0DPLs8sdQLvMP3Mz3_zCS6UOMmvX4mFr3OhtK7tjA8PT6wJL2DTy69bE17vCFNVr1LJVS8zIuouVlRzjy34jA9CV-QPD5j3DwrYu68CD4uvZAP57zhOQs8krApvOvEV707mAC7giNdPYjSbjxfdLg8gZacu3ritLz6NPI7KDTBvEEhxDzihwE8pX0vPWS4GTtZfk28YPBwPQMaFr3Jqxs8t5bsuw3moj0q2q07AuRuPDSLnbw1-zi961XVuq_qKL1ofAW8CwGBvVScoz3uDny9px98PA35cTxwaYk6jwUJPQUyZbx2mVM9SMqnPdPMejxuREA9dK0UPe4VCLrn6vg8LBTOO8c2v7yL1wW9PUWJPIqrybzN3h67_-sSvVaMHz2DKno8uDCSPN3ZiL2B5JK9-CsOPNdBrbzEIPY7gCjdvEAJCjv3LX-8NjupvTIxijw55-I7CgjIvOx1Hj33TUg81is6vLs8Gj1K7VA8nyetPXnXAbgJ1US92c7tPNAYP7zzeVi9eaVGvaT8izv9_pi92S3sPCywZD3ihPK8pJ2NPZe0BT3EYSG9ahv8uic4k7wNc-M8bX9QvTUsirzUYaC8LCVrPS5Y7jyqsnk4UYbaPLF8P70_x7C80qB_u5jBjDzQqCO9vqn8vCUcMb3Z6TY9ZLRcPYLJHz3ZPJg7OMW9O23YBr3YeZu89uKCPTdrqjuwB-O8ZL82vJ7ClbvHOq-76d4xvITNT7z2x6Q8NQAkvatTUb2PhlY8NCe0OzOLR71EliA8-W6APBeomjstTFG99YfouysDNT2R3b28y7K4vLt2-bvKf6C9bf6XvH8EBrzZMak8BNpuPLLJhz1G6gU9UouwPE-QRT1lKKA8bjaGO9JQGDwpjtQ7yi_BvLv_ATyE1Ju99_8nPRh1gr1Vyyk6Q_WzPc60fz2-XoC9nsTxvJfFDT2zS9m8kx-2u3mlxjyfqf68Tje9PcK_G7v_F467sdXOvIzZdr3LONw8ngu2vF8Rgr35KP87TzCEvQNZS7xGrkU976MhPTOLsr2woYg7OGAmvRPIzLxLtaM7ofWDPXREq7wSwsi7T7dVPEiO5zwaOtu8OudNvHa5nLxYLpC9h445OydP3rsAwMG7rV9ZvWnBfb0Gwoo8ZwPXvPa72ztudDc9l4gfPJTxsjxCioM90JIwPaFs5rxs46q8yGKQux4injyRLZ281PZivejwfLwwZGE8NTinvDfsdzwSkMk8hMtmPSTkrbyigau8L4kVvJntIrrwqTo8dNSPvctfQruCT0M8lm1Wu6mpPD2Mmpc971YEOY4sGT2Be1M9rewmPfCeS7yGn9Y8KXYaPVbF0DvtdQm9r2bhvPwGGz1jNOe8KQ0xPbEPOz38nIO8PEQaPKFsZj30PNW9Be4FvRVfzj2HUvq7aJiRPB-lNz1GIze87PZrPKefPz3g_ni8VfOPvYRFbb1j-7W8_Z0xvVLvrrz8kZQ9_2dLPbAC4zxthKa8fzwJvff_p7waPzG9FV_jvLKlDr3ED249QDy3PD1SabxxX0i9M0hAvVL1Rz0cGCG9coYuvY--bj1sKhO9d_jmPHowQL0Bm_g863DIPUM-VD0OVhQ8I5H2PLwD_rxb5zY7MJqdOy6wOrxhIS09yOU7PT74ljwRi_M9WYDLvAtKIb01Yf68WfwQPf2YsTv0-vu8bp_qPHKDrL1z1KS8Xf_bPDCQXD1oIZE9lGFjuzaN5LzWIXk8nJsvPSgdtTyqEEq9286EOyZm6rtF8DM95ugxugOlDz115YI86aumO388CT1SPRC8MtM5vB6liDwqxT06Q5agvJ-TAz0CikY9MxbBvP2_QTwxhsc8NLeDOzql9LxJxmq8MSWCPfVtkTzuo3W78-gSvenIvrxOofe8zr5APQBnCL3z4us7cKEhvQ4lWL0FTwK9cMMcPHpNZT3G1hK9XOKhvZF2nz1mRd"
            "q8dgeou_fTQTz43968VndaPNdZ_DsoxLo8n0gPvTcp0Tq17Js85EnDPWiyZ733Fsk8AFwuPGuUv7sYHcs8GGtBvZK7dju76bi9KHsavWsqEzwk0xA8"
        ),
        "Fire": (
            "VsgiPPjckjuCP-q8fs25PC1DT7z-1li7_bAkvPkCrDx4W9E7V9sYvf3_D72Oyau8vXGFPVCSFD00qhO9-mSEPGQ8kT1_pN67qvXxvDwZDb1nLnO8TbvSO8Gn0j3OdUc9dNaqPP2wJL170AQ9CX-0PJu1jTycyBW9JAVvvAtWPjyF1pe8o4awO12vBrvkJ0O8EtsBPUszvbzduJO8zghmPeIBezyBaKE8STaIPBUnkDq7D1q8swJCPUR1kLy9qmU7itcaPZ2flr2fOlg7gcpKvQh__DzNAEM8xqSJPVNTb7s7aMA7cJ3sPK8ujDx4W1G8-RVhvcra-jxRuD-9RTkAPdr0dDx9a7Q7Y1q9ux2WLT0QBJO8DWkHPUE8pby9Ioi8kFydPJX0Xbzwpme8YeVuvL0iiD0q2Ti9oE1OPGXPlL0qswS805g7vWb1pD1ivPi8oOskvAm7BT323P685LIju0N1WL0WYLq80sEWPShRx7zTSVC83qLkvF-_1bwZ1RG9833EvIRlVr0jo2m8xxmpvGXirjyLTDo9hRa-PKAnGj2tVwK9aGogvay5xjxvZB69MoSnPHSHmzxcc1k79hhQvfIufbyH7Zo9E2NNvRq_0Dz0o_i8Tx2iPJb0lTypgPY8g90cPV7VljuyyZc7gXvfPLjW97x5W6097CGdvJhT9Lk7yum8AJsbPP90izw3uj49qUeEvQE277xqQaE70UybPIw2-Tq1O2y8yD9dPD0_wbxZY6485v7DvN3LP7vsg8Y8OZE_PehzGzwNj1-86kouPW2Ni70FM9w7wDIzvHUibzwqO_Q7c4djvc2eh73TIwq9icSkvA7x0rzb9Ky9Lpp4vaSs0rwUdkM6WLK9PFhQOL2vZ348TUYPveAqVrw3HOi87JbyPPGmXj06foG6E1CPPMgD-rx7gT29kgpMPau5fLxaTe271DaSvUjUpr0v6a28hLRBvS-tbjznpHy8rbm0uic-UTwCwZm8UOEavCZnmrz9Ek69cU65vcUGabznc0g8oZwDPWVX8ry3YVi9MPwRPZqilzu_H-E8m7WNOs-INLy-5pI85ZzivLewQ7wlVLY9WnaAPOMBMz1j-JM8MF47Pfe2gjuxjWq5LWEqPRohejwumni84ow3vBoOvLzHyqu8Z7mvvB1t9jzE83I83LglvQRv4zxAx448KRUuvI3UGT0Ojxe9IJM-O-fCzjwKVv29DMtUPX6R6LwDIwQ9gwNRPaOG1LzCSsA8f20AvbUoLj1SyzU9o3MNPJd8cz0mBQO8A0kvPVhQFDx-axC9Ts7IPEtkxDyxjeq6O7fPPI1Jbz3VDS49UQcrOsPg_LwpUSM8VwFxve7P5rwOQCw61uTbu1WMdbxV22A8IM-PPKsxHz1U8SG9cmGvvD1Kgjx3cRI9WMVXvPigwbxR9JA9zoj8vJJGlDzWlXC9N7q-vFGlAT0Ywps6XoYrvD946zy3_668ZZNnu4OhSz3WqHi9dl4cuE1ZKb3fj4I82s4KvILKJr3hjFu8w35TPZDker20d2E9DS22PRBmmLxmaug8D_EKPRUnEDsbcB2911n7u0PEwzyl-yu9lWl9vFFWljzP_Xc7i5slvU5_yzzJtGo80RBKvcKUFLxQkgs8jXImvDUyRLxjWr07RxCcPfPfybw_eGu8HZaJu9VvPDzCIwK8XjecPC9LRT56bkc9tWSRvFD0Rr2QgtE77x7SvPbUkzyikR48vVv6PPSQujuS96i9G10Du50wKj0UdsO5NIFcvGxnjbtLRnu9ZvVIvDI1vLyDjo28URppPCegVj1oLk-9u5oEPWzcrLspApQ9W4mavJEzHrufOti7RHWQPF9warw1vZI8vaplugiojzxjHuw8LSXrPH0v9bwwmoM9Dt6CPPIufTwHu_G8XK8qve6psrrvHtI60l8Rvdwt6byZj7M8VVMDPI3nxboqd7M84gFXu4-rI71DE688wEXxPIxfsLycKq07rWrJvAo"
            "wJb1mG_28M-YRvYUWtTxq1Da95v7DvJHRKj0VJxC7QU_sPEdyxbyw8hY9xUIWPaYhPDtVtay7faepvBg3Tb1Ai2E8Kuz2vNOrebygnF09X13QvEN1xjy4dM67XE0cvE67Lr0yhCe9rbmiPZvuf7yT2YW8_E5DvR8TXr1QVt46iLGuPJqXTb3Z7DY9LunjPNkw2Lw3uj499ZCEPbLJl7vefAw9TWznvJ8nPr0OLZI8Ff7Yva0IoLsEmIg9x__juwaoe7vmmgY97vgdPUIAuTwOQCw79-_0O_K5ubxN96M8GkqNPY2YbL0bcJ08y3gtPK0IoDucAfY8uLAfPbDceb0c0iK9ysdgOQgKFbtkCwq-NqfIvGD7Aj3cuEk8z8QFO7CjKzwNyzA9aslsPdRv8ryQq4i8e7LEPLaKD7v7KI-9fJSzPO6WmLy27Dg8TNGTvOSyIz0arLa9zRNdPW7G_TxUUzm958JOPbrWCz3L2kS9tU50PZ3ut7yenwS9rQggvH_1-LxCTxK8OeAqPSvsUr3Ijja9Tn9LPYixrjuD3Qq9HzyVPGgbNbzyV7S8UQerOj5l9bymITw8H4BIu3g1Hb3KePU8F9VQvZvIp7yN50W8poNlvB5alL3M7cy8IUSLO0HaH70NLba8PI70vGq2QLwRjMy8YF0aPZmPIbzVvqe9MHF5PcjdobuWuM08rQggvaMRiL1Tokg97-JcvMgsjb0-8A09bNysPbrpJTxQQyC9e-NUvF8OZTz_Xve85CdDPEXB-DyPb9s8XdVePVRAjbx3SFu8lM4pvHTWhjwgk746a2dDvF2vBjw00Gs9akEhPE3kCT2PmJK8D2bgPIS0Qb2xRY-8-jdbPQuloLvyzFO9uRKKPHaEdD12Xhy8Je0Uve-TcT1lMRo9PlI3PafSCLhy1nK8qLzrPA0E_zzkJ0M8eUiTPf1hubxkqYS9IhvUPPeNyzz7O006wmvdOxX-2L0EDcy8zcRxvTp-gT0f4nE7u3HfPB3lBj2Pq6w8FHZDvZlANj0OQKw7z3WavIVlMj1B2qi7jP0GvMBF8btDOWM8w-B8PX4cpb2tkP08YqnwPHLW8jzPJgs9-mSEvbf_rrzN7QS845-tvP3_j72M_Ya7ZkQ0vczajryGKTS9RMSNvS-Hnz3aQ_I8IIAkPbV3vTx8Mgq9z4i0vP_pTr1UZmU8lDDTOwl_ND1zdKU8G3AdPUlJIr0aIfo77x7Su1Zmi7wmK7e90oVFvSMuSj0WwuO85WvJPKa_ErwarD88mFN0OkJPpLrB0Am8hviavMhl_7xIIxK98sxTPKZwJzz8YW89lGwAvFIaIbx2QLi8sY3qvMTzcj37O006BPoxvVrFjzwPU6K81tFBPGzcrLqmg-W87ZY8vObYj7wEvjy9cZ0ku48Nsjtmpl09Z71gvcOnCrzPxAW9e24RPWUeAL3WIAm9YHDYPCTflrrcuEk8Q2KaPS-tyrze3hG936JAvGhqoDy_gWa9v1uyOw6PFz2QglG86DfuPEX9yTzazhw9sFSuPCEI3ru4_wq9QrHNPDwZDTz5UTK9DS02vYLw2jy_qh09v1syPNP6Ur0HMwu9HVpcvVgU57wdNJY8uvy_PO6AezxBPK68vHG7vEHahLvnr7S9DAcCvdWCerz9Es67lB25vO-T8bsUFJq84oyTvV8O5TsCIzG7gcrKPH5rkLvQrsQ8KaCOvCXJ1bsmBQO9XP4nPesle71QMJi8qm0Uu8YZTb3nwjw9C2lzvOSyI7xXPcI8sHp0PZZDk7xvUQS9-cZRPZCC0TvoERY9mrVVPOvS-Twez_u72jA0vSo74rwOQCy7nrIMvYcA4juuzKq8C0MJPXRLbj2yjU88FMUuvSv_bD3_1pC9yfAXva_yOj2tCKC769JePRF5DjzIixi9gBl-Pe-8KD0G5Ci9Cc7DPKxEJz1grJc9La4KvaERWb1nppU7rd_oPDdFn7yatag8HSnDvIxfDD0VTcQ83_EHPeB"
            "mJ709LCe99qOMvApDP72cAfa8YCHbPGpBITpXsuG8qfiGvKNzBD2QXAI9o5luvXGKijzaQ_K8R5h5PVJpjLug_mI9ZVfyPC_WAT0AI3m9"
        ),
    },
    "Portraits": {
        "Snoop": "Snoop Dogg.jpg",
        "Snoop Prompt": "Snoop Dogg",
        "Ray": "Ray-Liotta-Goodfellas.jpg",
        "Ray Prompt": "Ray Liotta, Goodfellas",
        "Anya": "Anya Taylor-Joy 003.jpg",
        "Anya Prompt": "Anya Taylor-Joy, The Queen's Gambit",
        "Billie": "billie eilish 004.jpeg",
        "Billie Prompt": "Billie Eilish, blonde hair",
        "Lizzo": "Lizzo 001.jpeg",
        "Lizzo Prompt": "Lizzo,",
        "Donkey": "Donkey.jpg",
        "Donkey Prompt": "Donkey, from Shrek",
    },
    "Transforms": {
        "ColorWheel001": "ColorWheel001.jpg",
        "ColorWheel001 BW": "ColorWheel001 BW.jpg",
        "ColorWheel002": "ColorWheel002.jpg",
        "ColorWheel002 BW": "ColorWheel002 BW.jpg",
    },
    "CoCo": {
        "452650": "452650.jpeg",
        "Prompt 1": "a college dorm with a desk and bunk beds",
        "371739": "371739.jpeg",
        "Prompt 2": "a large banana is placed before a stuffed monkey.",
        "557922": "557922.jpeg",
        "Prompt 3": "a person sitting on a bench using a cell phone",
        "540554": "540554.jpeg",
        "Prompt 4": "two trains are coming down the tracks, a steam engine and a modern train.",
    },
    # Disabled group:
    # "NFT's": {
    #     "SohoJoe": "SohoJoeEth.jpeg",
    #     "SohoJoe Prompt": "SohoJoe.Eth",
    #     "Mirai": "Mirai.jpg",
    #     "Mirai Prompt": "Mirai from White Rabbit, @shibuyaxyz",
    #     "OnChainMonkey": "OnChainMonkey-2278.jpg",
    #     "OCM Prompt": "On Chain Monkey",
    #     "Wassie": "Wassie 4498.jpeg",
    #     "Wassie Prompt": "Wassie by Wassies",
    # },
}
# Pixel sizes: header example tiles, and the minimum column width used for the
# example thumbnails inside each input tab.
tile_size, image_examples_tile_size = 110, 50
# Build the Gradio UI: header with the "pup minus teacup" illustration, one
# "Input N" tab per slot (image, prompt, power, per-tab examples), the combined
# embedding plot, the search button/gallery, and the event wiring.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown(
                """
# Soho-Clip Embeddings Explorer
A tool for exploring CLIP embedding space.

**Example #1** - removing the Teacup from the image
* Add the image Pups->Pup1 on Input tab 1
* Add the text prompt "Teacup." on Input tab 2
* Make the Input 2 embeddings negative by setting the power to -1
* Click "Search embedding space" to see the results
""")
        with gr.Column(scale=2, min_width=(tile_size + 20) * 3):
            with gr.Row():
                with gr.Column(scale=1, min_width=tile_size):
                    gr.Markdown("#### Pup in cup:")
                with gr.Column(scale=1, min_width=tile_size):
                    gr.Markdown("#### - 'Teacup'")
                with gr.Column(scale=1, min_width=tile_size):
                    gr.Markdown("#### = Pup")
            for example_row in examples:
                with gr.Row():
                    for example_file in example_row:
                        with gr.Column(scale=1, min_width=tile_size):
                            # "" keeps the column but shows no image.
                            if example_file:
                                local_path = os.path.join(image_folder, example_file)
                                gr.Image(
                                    value=local_path, shape=(tile_size, tile_size),
                                    show_label=False, interactive=False,
                                ).style(height=tile_size, width=tile_size)
    with gr.Row():
        for i in range(max_tabs):
            with gr.Tab(f"Input {i+1}"):
                with gr.Row():
                    with gr.Column(scale=1, min_width=240):
                        input_images[i] = gr.Image(label="Image Prompt", show_label=True)
                    with gr.Column(scale=3, min_width=600):
                        embedding_plots[i] = gr.LinePlot(show_label=False).style(container=False)
                with gr.Row():
                    with gr.Column(scale=2, min_width=240):
                        input_prompts[i] = gr.Textbox(label="Text Prompt", show_label=True, max_lines=4)
                    with gr.Column(scale=3, min_width=600):
                        with gr.Row():
                            embedding_powers[i] = gr.Slider(minimum=-3, maximum=3, value=1, label="Power", show_label=True, interactive=True)
                        with gr.Row():
                            with gr.Accordion("Embeddings (base64)", open=False):
                                embedding_base64s[i] = gr.Textbox(show_label=False, live=True)
                # One sub-tab per example group; image examples fill this tab's
                # image input, everything else fills its text prompt.
                for tab_title, tab_examples in tabbed_examples.items():
                    with gr.Tab(tab_title):
                        with gr.Row():
                            for title, example in tab_examples.items():
                                if example.endswith((".jpg", ".jpeg")):
                                    local_path = os.path.join(image_folder, example)
                                    with gr.Column(scale=1, min_width=image_examples_tile_size):
                                        gr.Examples(
                                            examples=[local_path],
                                            inputs=input_images[i],
                                            label=title,
                                        )
                                else:
                                    with gr.Column(scale=1, min_width=image_examples_tile_size * 2):
                                        gr.Examples(
                                            examples=[example],
                                            inputs=input_prompts[i],
                                            label=title,
                                        )
    with gr.Row():
        average_embedding_plot = gr.LinePlot(show_label=True, label="Average Embeddings").style(container=False)
    with gr.Row():
        with gr.Accordion("Average embeddings in base64", open=False):
            average_embedding_base64 = gr.Textbox(show_label=False)
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            n_samples = gr.Slider(1, 16, value=4, step=1, label="Number images")
        with gr.Column(scale=3, min_width=200):
            submit = gr.Button("Search embedding space")
    output = gr.Gallery(label="Closest images in Laion 5b using kNN", show_label=True) \
        .style(grid=[4, 4], height="auto")
    submit.click(main, inputs=[average_embedding_base64, n_samples], outputs=output)

    # Session state: the current base64 embedding and power of every input tab.
    embedding_base64s_state = gr.State(value=[None] * max_tabs)
    embedding_power_state = gr.State(value=[1.0] * max_tabs)

    def on_image_load(input_image, idx_state, embedding_base64s_state, embedding_power_state):
        """Re-embed a changed image and refresh its plot and the combined embedding."""
        debug_print("on_image_load")
        embeddings_b64 = on_image_load_update_embeddings(input_image)
        new_plot = on_embeddings_changed_update_plot(embeddings_b64)
        average_embeddings_b64 = on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_power_state, embeddings_b64, idx_state)
        new_average_plot = on_embeddings_changed_update_plot(average_embeddings_b64)
        return embeddings_b64, new_plot, average_embeddings_b64, new_average_plot

    def on_prompt_change(prompt, idx_state, embedding_base64s_state, embedding_power_state):
        """Re-embed a changed prompt (a pasted base64 embedding is used as-is)."""
        debug_print("on_prompt_change")
        if is_prompt_embeddings(prompt):
            # Drop stray whitespace around a pasted embedding.
            embeddings_b64 = prompt.strip()
        else:
            embeddings_b64 = on_prompt_change_update_embeddings(prompt)
        new_plot = on_embeddings_changed_update_plot(embeddings_b64)
        average_embeddings_b64 = on_embeddings_changed_update_average_embeddings(embedding_base64s_state, embedding_power_state, embeddings_b64, idx_state)
        new_average_plot = on_embeddings_changed_update_plot(average_embeddings_b64)
        return embeddings_b64, new_plot, average_embeddings_b64, new_average_plot

    def on_power_change(power, idx_state, embedding_base64s_state, embedding_power_state):
        """Apply a new power for one tab and refresh the combined embedding."""
        debug_print("on_power_change")
        average_embeddings_b64 = on_power_change_update_average_embeddings(embedding_base64s_state, embedding_power_state, power, idx_state)
        new_average_plot = on_embeddings_changed_update_plot(average_embeddings_b64)
        return average_embeddings_b64, new_average_plot

    for i in range(max_tabs):
        # Each tab passes its own index so the handlers update the right slot.
        idx_state = gr.State(value=i)
        input_images[i].change(on_image_load,
            [input_images[i], idx_state, embedding_base64s_state, embedding_power_state],
            [embedding_base64s[i], embedding_plots[i], average_embedding_base64, average_embedding_plot])
        input_prompts[i].change(on_prompt_change,
            [input_prompts[i], idx_state, embedding_base64s_state, embedding_power_state],
            [embedding_base64s[i], embedding_plots[i], average_embedding_base64, average_embedding_plot])
        embedding_powers[i].change(on_power_change,
            [embedding_powers[i], idx_state, embedding_base64s_state, embedding_power_state],
            [average_embedding_base64, average_embedding_plot])

    with gr.Row():
        gr.Markdown(
            """
My interest is to use CLIP for image/video understanding (see [CLIP_visual-spatial-reasoning](https://github.com/Sohojoe/CLIP_visual-spatial-reasoning)).

**Example #2** - adding black & white embeddings
* Add the image Pups->Pup4 on Input tab 1
* Add Embeddings->Black & White on Input tab 2
* Set Input 2 embeddings power to 1.3
* Click "Search embedding space" to see the results
* Note: You may need to play with the power with different source images

### Initial Features
- Combine up to 10 images and/or text inputs to create an average embedding space.
- Search the LAION-5B images via a kNN search

### Known limitations
- I'm getting formatting bugs when running on Hugging Face (vs. my MacBook). This is impacting:
    - The gallery
    - The Embeddings tab

### Acknowledgements
- I heavily build on [clip-retrieval](https://rom1504.github.io/clip-retrieval/) and use their API. Please [cite](https://github.com/rom1504/clip-retrieval#citation) the authors if you use this work.
- [CLIP](https://openai.com/blog/clip/)
- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
""")

# ![Alt Text](file/pup1.jpg)
# <img src="file/pup1.jpg" width="100" height="100">
# ![Alt Text](file/pup1.jpg){height=100 width=100}
if __name__ == "__main__":
    # Script entry point: serve the `demo` app with Gradio's debug mode on.
    launch_settings = {"debug": True}
    demo.launch(**launch_settings)