Spaces:
Runtime error
Runtime error
File size: 5,314 Bytes
d7b8e7c 933f49e d7b8e7c 5b36a9d d7b8e7c 933f49e d7b8e7c 933f49e d7b8e7c 933f49e d7b8e7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
#!/usr/bin/env python3
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, AutoModel, AutoTokenizer
import torch
import faiss
import glob
import numpy as np
# Hugging Face checkpoint shared by the model, processor and tokenizer. Keeping
# it in one place guarantees all three are loaded from the same checkpoint.
model_id = "google/siglip-base-patch16-256-multilingual"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(model_id).to(device)
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Embedding width read from the vision config; it must equal the size of the
# vectors stored in the FAISS index.
num_dimensions = model.vision_model.config.hidden_size  # 768

# Number of nearest neighbours requested from the index per search.
num_k = 30

# Example queries offered under the text box. The last two (German, Turkish)
# exercise the multilingual text tower.
text_examples = [
    "Frog waiting on a rock",
    "Bird with open mouth",
    "Bridge and a ship",
    "Bike for two people",
    "Biene auf der Blume",
    "Hesap makinesi",
]
def preprocess_images(pathname="images/*", index_file="index.faiss",
                      image_filenames_file="image_filenames.txt"):
    """Embed every image matching *pathname* and save a FAISS index of them.

    Each image is converted to RGB, embedded with the SigLIP image tower and
    L2-normalised, so inner-product search on the index equals cosine
    similarity. The index is written to *index_file*. The image paths are
    written to *image_filenames_file*, one per line, in index row order.

    Args:
        pathname: glob pattern selecting the images to index.
        index_file: output path for the serialized FAISS index.
        image_filenames_file: output path for the list of indexed image paths.

    Returns:
        ``(index, image_filenames)``, where ``image_filenames[i]`` is the image
        stored at row ``i`` of ``index``.

    Raises:
        ValueError: if *pathname* matches no files.
        SystemExit: with status 1 if any image fails to load or embed.
    """
    print("Preprocessing images...")
    # Sort so that rebuilding over the same files yields the same row order.
    candidates = sorted(glob.glob(pathname))
    if not candidates:
        raise ValueError(f"No images match {pathname!r}")

    image_filenames = []
    image_features = []
    for image_filename in candidates:
        try:
            with Image.open(image_filename) as image_raw:
                image_rgb = image_raw.convert("RGB")
            inputs = processor(images=image_rgb, return_tensors="pt").to(device)
            with torch.no_grad():
                image_embedding = model.get_image_features(**inputs).to("cpu")
        except Exception as e:
            print(f"Error processing {image_filename}")
            print(e)
            # Fail fast: a partially built index would silently miss images.
            raise SystemExit(1) from e
        image_embedding_n = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)
        image_features.append(image_embedding_n.numpy())
        # Record the path only once its embedding exists, keeping both lists aligned.
        image_filenames.append(image_filename)

    print("Indexing images...")
    # FAISS expects a C-contiguous float32 matrix with one row per image.
    features = np.ascontiguousarray(np.concatenate(image_features, axis=0), dtype=np.float32)
    # Inner product over unit-length vectors == cosine similarity.
    index = faiss.IndexFlatIP(features.shape[1])
    index.add(features)

    print("Saving index...")
    faiss.write_index(index, index_file)
    with open(image_filenames_file, "w") as f:
        f.writelines(name + "\n" for name in image_filenames)
    print("Preprocessing complete.")
    return index, image_filenames
def load_processed_images(index_file="index.faiss", image_filenames_file="image_filenames.txt"):
    """Load the FAISS index and image-path list saved by ``preprocess_images``.

    Args:
        index_file: path of the serialized FAISS index.
        image_filenames_file: text file with one image path per line. Line ``i``
            names the image whose embedding is row ``i`` of the index.

    Returns:
        ``(index, image_filenames)``.

    Raises:
        ValueError: if the file list and the index disagree on the number of
            images, which would make search results point at the wrong files.
    """
    print("Loading index...")
    index = faiss.read_index(index_file)
    with open(image_filenames_file) as f:
        # Strip only the line terminator so paths with edge whitespace survive.
        image_filenames = [line.rstrip("\n") for line in f]
    if len(image_filenames) != index.ntotal:
        raise ValueError(
            f"{image_filenames_file} lists {len(image_filenames)} images but "
            f"{index_file} holds {index.ntotal} vectors; rebuild with preprocess_images()"
        )
    return index, image_filenames
@torch.no_grad()
def search_using_text(text):
    """Find the indexed images that best match a free-text query.

    The query is embedded with the SigLIP text tower, L2-normalised, and looked
    up in the global FAISS ``index`` (inner product of unit vectors, i.e. cosine
    similarity). Each similarity is mapped to SigLIP's pairwise match
    probability ``sigmoid(sim * exp(logit_scale) + logit_bias)``.

    Args:
        text: the query string from the text box.

    Returns:
        A list of ``(PIL image, "NN.NN%")`` pairs for the gallery, at most
        ``num_k`` long, in FAISS rank order. Empty for a blank query.
    """
    if not text or not text.strip():
        return []

    # SigLIP's text tower expects max_length padding. Truncation keeps long
    # queries within the tokenizer's maximum length instead of erroring out.
    inputs = tokenizer(text, padding="max_length", truncation=True, return_tensors="pt").to(device)
    text_features = model.get_text_features(**inputs).to("cpu")
    text_features_n = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
    distances, indices = index.search(text_features_n.numpy(), num_k)

    # FAISS pads results with label -1 when the index holds fewer than num_k
    # vectors. image_filenames[-1] would otherwise silently show the last image.
    valid = indices[0] >= 0
    hits = indices[0][valid]

    # detach(): logit_bias is a trainable Parameter. On CPU, .cpu() returns the
    # same tensor, and .numpy() refuses tensors that require grad.
    scale = model.logit_scale.exp().detach().cpu().numpy()
    bias = model.logit_bias.detach().cpu().numpy()
    logits = distances[0][valid] * scale + bias
    probabilities = torch.sigmoid(torch.from_numpy(logits)).tolist()

    result = []
    for idx, probability in zip(hits, probabilities):
        found_image = Image.open(image_filenames[idx])
        found_image.load()  # decode pixels now rather than lazily
        result.append((found_image, "{:.2f}%".format(probability * 100)))
    return result
@torch.no_grad()
def search_using_image(image):
    """Find the indexed images most similar to an uploaded image.

    The query array is converted to an RGB PIL image, embedded with the SigLIP
    image tower, L2-normalised, and looked up in the global FAISS ``index``.

    Args:
        image: the array from the ``gr.Image`` input, or ``None`` if empty.

    Returns:
        A list of PIL images for the gallery, at most ``num_k`` long, in FAISS
        rank order. Empty when no image was provided.
    """
    if image is None:
        # An empty image input arrives as None; Image.fromarray(None) would raise.
        return []

    query_rgb = Image.fromarray(image).convert("RGB")
    inputs = processor(images=query_rgb, return_tensors="pt").to(device)
    image_embedding = model.get_image_features(**inputs).to("cpu")
    image_embedding_n = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)
    _, indices = index.search(image_embedding_n.numpy(), num_k)

    result = []
    for idx in indices[0]:
        if idx < 0:
            # FAISS pads with -1 when the index holds fewer than num_k vectors.
            continue
        found_image = Image.open(image_filenames[idx])
        found_image.load()  # decode pixels now rather than lazily
        result.append(found_image)
    return result
if __name__ == "__main__":
    import os

    # Build the index from images/* on first run; afterwards reuse the saved
    # files. Delete index.faiss and image_filenames.txt to force a rebuild.
    if os.path.exists("index.faiss") and os.path.exists("image_filenames.txt"):
        index, image_filenames = load_processed_images()
    else:
        index, image_filenames = preprocess_images()

    with gr.Blocks() as demo:
        gr.Markdown("# Image Search Engine Demo")
        with gr.Row(equal_height=False):
            with gr.Column():
                gr.Markdown("This app is powered by [SigLIP](https://huggingface.co/google/siglip-base-patch16-256-multilingual) with multilingual support and [GPR1200 Dataset](https://www.kaggle.com/datasets/mathurinache/gpr1200-dataset) image contents. Enter your query in the text box or upload an image to search for similar images.")
                with gr.Tab("Text-Image Search"):
                    text_input = gr.Textbox(label="Type a word or a sentence")
                    search_using_text_btn = gr.Button("Search with text", scale=0)
                    gr.Examples(
                        examples=text_examples,
                        inputs=[text_input],
                    )
                with gr.Tab("Image-Image Search"):
                    image_input = gr.Image()
                    search_using_image_btn = gr.Button("Search with image", scale=0)
            # Results gallery sits beside the query column and is shared by both tabs.
            gallery = gr.Gallery(label="Generated images", show_label=False,
                                 elem_id="gallery", columns=3,
                                 object_fit="contain", interactive=False, scale=2.75)
        search_using_text_btn.click(search_using_text, inputs=text_input, outputs=gallery)
        search_using_image_btn.click(search_using_image, inputs=image_input, outputs=gallery)
    demo.launch(share=False)