Spaces:
Runtime error
Runtime error
import streamlit as st | |
from sentence_transformers import SentenceTransformer, util | |
from pathlib import Path | |
import pickle | |
import requests | |
from PIL import Image | |
from io import BytesIO | |
import pandas as pd | |
from loguru import logger | |
import torch | |
# Labels offered by the sidebar mode selector; main() compares the user's choice against them.
T2I = "Text 2 Image"  # text query -> most similar Unsplash images
I2I = "Image 2 Image"  # uploaded image query -> most similar Unsplash images
def get_match(model, query, img_embs):
    """Score ``query`` against every precomputed image embedding.

    The query (callers in this file pass either a text string or a PIL image)
    is wrapped in a one-element batch, encoded by ``model`` into a tensor, and
    compared with ``img_embs`` via cosine similarity.

    Returns:
        The similarity matrix from ``util.pytorch_cos_sim``: one row for the
        single query, one column per entry of ``img_embs``.
    """
    encoded_query = model.encode([query], convert_to_tensor=True)
    return util.pytorch_cos_sim(encoded_query, img_embs)
def _top_k_indices(cosine_sim, n_top_k_images):
    """Return the indices of the highest-scoring images for a single query.

    Args:
        cosine_sim: Similarity matrix from ``get_match``; only its first row is
            used (a single query yields a single row).
        n_top_k_images: How many indices to return. Clamped to the number of
            available images so ``torch.topk`` cannot fail on a small dataset.

    Returns:
        A plain ``list[int]`` ordered from most to least similar. It is always
        a list, even when only one index is requested.
    """
    scores = cosine_sim[0]
    k = min(int(n_top_k_images), scores.shape[0])
    return torch.topk(scores, k).indices.tolist()


def _fetch_thumbnail(url, width=320, timeout=10):
    """Download ``url`` at ``width`` pixels wide and decode it as a PIL image.

    The width is sent as the ``w`` query parameter; ``requests`` appends it
    with ``?`` or ``&`` as appropriate. The CDN presumably treats ``w`` as a
    resize hint (the app advertises "thumbnail size") -- confirm against the
    Unsplash image URL documentation.

    Raises:
        requests.RequestException: on connection errors, timeouts, or 4xx/5xx
            responses.
        OSError: if the body cannot be decoded as an image (Pillow's
            ``UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    response = requests.get(url, params={"w": width}, timeout=timeout)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    # Image.open is lazy; force the decode here so failures surface inside the
    # caller's try block rather than later inside st.image.
    image.load()
    return image


def _show_matches(img_names, img_urls, top_k_indices):
    """Render one thumbnail column per matched embedding index.

    Args:
        img_names: Mapping from embedding index to Unsplash ``photo_id``.
        img_urls: DataFrame with at least ``photo_id`` and
            ``photo_image_url`` columns.
        top_k_indices: Embedding indices, most similar first.
    """
    if not top_k_indices:
        st.error("No image found")
        return
    n_shown = len(top_k_indices)
    cols = st.columns(n_shown)
    for i, img_index in enumerate(top_k_indices):
        image_found = img_names[img_index]
        logger.success(f"Image match found: {image_found}")
        img_url_best_match = img_urls.loc[img_urls["photo_id"] == image_found]
        with cols[i]:
            if img_url_best_match.empty:
                st.error("No image found")
                continue
            image_url = img_url_best_match.iloc[0]["photo_image_url"]
            logger.info(image_url)
            try:
                image = _fetch_thumbnail(image_url)
            except (requests.RequestException, OSError) as err:
                # One broken thumbnail must not take down the whole results row.
                logger.error(f"Could not load {image_url}: {err}")
                st.error("Could not load image")
                continue
            st.image(image, caption=f"{i+1}/{n_shown} most similar")


def text_2_image(model, img_emb, img_names, img_urls, n_top_k_images):
    """Text-query mode: show the images whose embeddings best match a text.

    Args:
        model: Encoder passed through to ``get_match``.
        img_emb: Precomputed image embeddings.
        img_names: Mapping from embedding index to Unsplash ``photo_id``.
        img_urls: Photo metadata DataFrame (``photo_id``, ``photo_image_url``).
        n_top_k_images: Number of matches to display.
    """
    st.title("Text to Image")
    st.write("This is the text to image mode. Enter a text to be converted to an image")
    text = st.text_input("Enter the text to be converted to an image")
    # The button is only rendered once some text has been entered.
    if not text or not st.button("Convert"):
        return
    st.write("The image with the most similar embedding is:")
    cosine_sim = get_match(model, text, img_emb)
    _show_matches(img_names, img_urls, _top_k_indices(cosine_sim, n_top_k_images))


def image_2_image(model, img_emb, img_names, img_urls, n_top_k_images):
    """Image-query mode: show the images whose embeddings best match an upload.

    Args:
        model: Encoder passed through to ``get_match``.
        img_emb: Precomputed image embeddings.
        img_names: Mapping from embedding index to Unsplash ``photo_id``.
        img_urls: Photo metadata DataFrame (``photo_id``, ``photo_image_url``).
        n_top_k_images: Number of matches to display.
    """
    st.title("Image to Image")
    st.write("This is the image to image mode. Enter an image to be converted to an image")
    uploaded = st.file_uploader("Upload an image to be converted to an image", type=["jpg", "png", "jpeg"])
    if uploaded is None:
        return
    try:
        query_image = Image.open(BytesIO(uploaded.getvalue()))
        query_image.load()  # decode now so a corrupt upload is reported, not raised
    except OSError as err:
        logger.error(f"Could not read uploaded image: {err}")
        st.error("Could not read the uploaded image")
        return
    st.image(query_image, caption="Uploaded image")
    if not st.button("Convert"):
        return
    st.write("The image with the most similar embedding is:")
    # The encoder is fed RGB so that palette/alpha images (e.g. PNG) are accepted.
    cosine_sim = get_match(model, query_image.convert("RGB"), img_emb)
    _show_matches(img_names, img_urls, _top_k_indices(cosine_sim, n_top_k_images))
def load_model(name):
    """Instantiate the SentenceTransformer identified by ``name`` and return it.

    ``main`` calls this with ``"clip-ViT-B-32"``, whose encoder is used for both
    the text and the image queries in this app.
    """
    return SentenceTransformer(name)
def load_embeddings(filename):
    """Unpickle the precomputed image embeddings, reporting progress in the sidebar.

    The pickle at ``filename`` must hold a two-item sequence:
    ``(image file names, image embeddings)``.

    Returns:
        The ``(names, embeddings)`` pair, in that order.
    """
    sidebar = st.sidebar
    sidebar.info("Loading Unsplash-Lite image embeddings")
    # pickle can execute arbitrary code while loading: only point this at a
    # trusted, locally shipped file.
    with open(filename, "rb") as handle:
        names, embeddings = pickle.load(handle)
    sidebar.success("Images embeddings loaded")
    return names, embeddings
def load_image_url_list(filename):
    """Read the tab-separated photo metadata table into a DataFrame.

    The first line of ``filename`` is used as the column header; callers look
    up ``photo_id`` and ``photo_image_url`` columns in the result.
    """
    return pd.read_csv(filename, sep="\t", header=0)
def main():
    """Build the CLIP image-search page and dispatch to the selected mode.

    Loads the CLIP model, the photo metadata table (``photos.tsv000``) and the
    pickled image embeddings, renders the sidebar settings, then hands control
    to ``text_2_image`` or ``image_2_image``.
    """
    # NOTE(review): Streamlit re-executes this script on every widget
    # interaction, so the model, TSV and embeddings below are reloaded on each
    # rerun. Consider st.cache_resource / st.cache_data once the pinned
    # Streamlit version supports them (>= 1.18).
    st.title("CLIP Image Search")
    clip_model = load_model("clip-ViT-B-32")
    st.write("Select the mode to search for a match in Unsplash (thumbnail size) dataset. text2image mode needs a text as input and outputs the image with the most similar embedding (following cosine similarity). The Image to image mode is similar, but an input image is used instead of a text query")
    embeddings_path = Path("unsplash-25k-photos-embeddings.pkl")
    metadata_path = "photos.tsv000"
    photo_table = load_image_url_list(metadata_path)
    raw_names, embeddings = load_embeddings(embeddings_path)
    # Embedding row index -> image ID (the file name up to its first '.').
    id_by_index = {idx: fname.partition(".")[0] for idx, fname in enumerate(raw_names)}

    sidebar = st.sidebar
    sidebar.title("Settings")
    app_mode = sidebar.selectbox("Choose the app mode", [T2I, I2I])
    n_images_to_search = sidebar.number_input("Select the number of images to search", min_value=1, max_value=6)

    # Mode label -> (sidebar banner, page renderer).
    modes = {
        T2I: ("Text to image mode", text_2_image),
        I2I: ("Image to image mode", image_2_image),
    }
    selected = modes.get(app_mode)
    if selected is None:
        return
    banner, render_page = selected
    sidebar.info(banner)
    render_page(clip_model, embeddings, id_by_index, photo_table, n_images_to_search)


if __name__ == "__main__":
    main()