#!/usr/bin/env python3
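# Image search demo: embed images and text queries with multilingual SigLIP,
# index the image embeddings with FAISS, and serve results through a Gradio UI.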
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, AutoModel, AutoTokenizer
import torch
import faiss
import glob
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model     = AutoModel.from_pretrained("google/siglip-base-patch16-256-multilingual").to(device)
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-256-multilingual")
tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-256-multilingual")

num_dimensions = model.vision_model.config.hidden_size # 768, shared by the image and text embeddings
num_k = 30 # number of nearest neighbours to retrieve per query

text_examples = [
    "Frog waiting on a rock",
    "Bird with open mouth",
    "Bridge and a ship",
    "Bike for two people",
    "Biene auf der Blume",
    "Hesap makinesi"
]

def preprocess_images(pathname="images/*", index_file="index.faiss"):
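  """Embed every image matching `pathname` with SigLIP, add the L2-normalised
  embeddings to a FAISS inner-product index, and persist the index plus the
  list of image filenames to disk."""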
  print("Preprocessing images...")
  index = faiss.IndexFlatIP(num_dimensions) # Build the index using Inner Product (IP) similarity.
  image_filenames = []
  image_features = []
  for image_filename in glob.glob(pathname):
    try:
      image_raw = Image.open(image_filename)
      image_rgb = image_raw.convert('RGB')
      image_filenames.append(image_filename)
      inputs = processor(images=image_rgb, return_tensors="pt").to(device)
      with torch.no_grad():
        image_embedding = model.get_image_features(**inputs).to("cpu")
        image_embedding_n = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)
        image_embedding_n = image_embedding_n.numpy()
      image_features.append(image_embedding_n)
    except Exception as e:
      print(f"Error processing {image_filename}".format(image_filename))
      print(e)
      exit(1)

  print("Indexing images...")
  image_features = np.concatenate(image_features, axis=0)
  index.add(image_features)

  print("Saving index...")
  faiss.write_index(index, index_file)
  with open("image_filenames.txt", "w") as f:
    for image_filename in image_filenames:
      f.write(image_filename + "\n")

  print("Preprocessing complete.")
  return index, image_filenames

def load_processed_images(index_file="index.faiss", image_filenames_file="image_filenames.txt"):
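  """Load a previously built FAISS index and the matching list of image filenames."""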
  print("Loading index...")
  index = faiss.read_index(index_file)
  with open(image_filenames_file) as f:
    image_filenames = f.readlines()
  image_filenames = [x.strip() for x in image_filenames]
  return index, image_filenames

@torch.no_grad()
def search_using_text(text):
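  """Embed the text query with SigLIP, retrieve the top-k nearest images from the
  FAISS index, and return (image, probability) pairs for the Gradio gallery."""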
  # SigLIP's text tower was trained with padding="max_length", so the same padding is used at inference.
  inputs = tokenizer(text, padding="max_length", return_tensors="pt").to(device)
  text_features = model.get_text_features(**inputs).to("cpu")
  text_features_n = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
  text_features_n = text_features_n.numpy()

  D, I = index.search(text_features_n, num_k)

  # SigLIP scores each (text, image) pair as sigmoid(logit_scale * cosine_similarity + logit_bias);
  # both embeddings are L2-normalised, so the FAISS inner-product distance is the cosine similarity.
  # logit_scale is stored in log space, hence the .exp().
  scale = model.logit_scale.exp().cpu().numpy()
  bias = model.logit_bias.cpu().numpy()
  result = []
  for dist, idx in zip(D[0], I[0]):
    score_logit = dist * scale + bias
    score_probability = torch.sigmoid(torch.tensor(score_logit)).item()
    found_image = Image.open(image_filenames[idx])
    found_image.load()
    result.append((found_image, "{:.2f}%".format(score_probability*100)))

  return result

@torch.no_grad()
def search_using_image(image):
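  """Embed the query image with SigLIP and return the top-k most similar indexed images."""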
  image = Image.fromarray(image)
  image_rgb = image.convert('RGB')
  inputs = processor(images=image_rgb, return_tensors="pt").to(device)

  image_embedding = model.get_image_features(**inputs).to("cpu")
  image_embedding_n = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)
  image_embedding_n = image_embedding_n.numpy()

  D, I = index.search(image_embedding_n, num_k)

  result = []
  for dist, idx in zip(D[0], I[0]):
    found_image = Image.open(image_filenames[idx])
    found_image.load()
    result.append(found_image)

  return result

if __name__ == "__main__":
  #index, image_filenames = preprocess_images() # uncomment this line to build index.faiss and image_filenames.txt
  index, image_filenames = load_processed_images()

  with gr.Blocks() as demo:
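    # Left column: query inputs (text tab and image tab); right side: result gallery.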
    gr.Markdown("# Image Search Engine Demo")
    with gr.Row(equal_height=False):
      with gr.Column():
        gr.Markdown("This app is powered by [SigLIP](https://huggingface.co/google/siglip-base-patch16-256-multilingual) with multilingual support and [GPR1200 Dataset](https://www.kaggle.com/datasets/mathurinache/gpr1200-dataset) image contents. Enter your query in the text box or upload an image to search for similar images.")
        with gr.Tab("Text-Image Search"):
          text_input = gr.Textbox(label="Type a word or a sentence")
          search_using_text_btn = gr.Button("Search with text", scale=0)
          gr.Examples(
            examples = text_examples,
            inputs = [text_input]
          )

        with gr.Tab("Image-Image Search"):
          image_input = gr.Image()
          search_using_image_btn = gr.Button("Search with image", scale=0)

      gallery = gr.Gallery(label="Generated images", show_label=False,
                          elem_id="gallery", columns=3,
                          object_fit="contain", interactive=False, scale=2.75)

    search_using_text_btn.click(search_using_text, inputs=text_input, outputs=gallery)
    search_using_image_btn.click(search_using_image, inputs=image_input, outputs=gallery)
    demo.launch(share=False)