nqtruong commited on
Commit
36da01d
·
verified ·
1 Parent(s): e5ecc25

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +127 -0
main.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from facenet_pytorch import InceptionResnetV1
3
+ import torch.nn as nn
4
+ import torchvision.transforms as tf
5
+ import numpy as np
6
+ import torch
7
+ import faiss
8
+ import h5py
9
+ import tqdm
10
+ import os
11
+ import random
12
+ from PIL import Image
13
+ import matplotlib.cm as cm
14
+ import matplotlib as mpl
15
+
16
+ img_names = []
17
+ with open('list_eval_partition.txt', 'r') as f:
18
+ for line in f:
19
+ img_name, dtype = line.rstrip().split(' ')
20
+ img_names.append(img_name)
21
+
22
+
23
+ # For a model pretrained on VGGFace2
24
+ print('Loading model weights ........')
25
+
26
+ class SiameseModel(nn.Module):
27
+ def __init__(self):
28
+ super().__init__()
29
+ self.backbone = InceptionResnetV1(pretrained='vggface2')
30
+ def forward(self, x):
31
+ x = self.backbone(x)
32
+ x = torch.nn.functional.normalize(x, dim=1)
33
+ return x
34
+
35
+ model = SiameseModel()
36
+ model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
37
+ model.eval()
38
+
39
+
40
+ # Make FAISS index
41
+ print('Make index .............')
42
+ index = faiss.IndexFlatL2(512)
43
+
44
+ hf = h5py.File('face_vecs_full.h5', 'r')
45
+ for key in tqdm.tqdm(hf.keys()):
46
+ vec = np.array(hf.get(key))
47
+ index.add(vec)
48
+
49
+ hf.close()
50
+
51
+ print("Finished indexing")
52
+
53
+ # Function to search image
54
+ def image_search(image, k=5):
55
+
56
+ transform = tf.Compose([
57
+ tf.Resize((160, 160)),
58
+ tf.ToTensor()
59
+ ])
60
+
61
+ query_img = transform(image)
62
+ query_img = torch.unsqueeze(query_img, 0)
63
+
64
+ model.eval()
65
+ query_vec = model(query_img).detach().numpy()
66
+
67
+ D, I = index.search(query_vec, k=k)
68
+
69
+ retrieval_imgs = []
70
+
71
+ FOLDER = 'img_align_celeba'
72
+ for idx in I[0]:
73
+ img_file_name = img_names[idx]
74
+ path = os.path.join(FOLDER, img_file_name)
75
+
76
+ image = Image.open(path)
77
+ retrieval_imgs.append((image, ''))
78
+
79
+ return retrieval_imgs
80
+
81
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
82
+ gr.Markdown('''
83
+
84
+
85
+ # Face Image Retrieval with Content-based Image Retrieval (CBIR)
86
+ --------
87
+
88
+
89
+ ''')
90
+
91
+ with gr.Row():
92
+ with gr.Column():
93
+ image = gr.Image(type='pil', scale=1)
94
+ slider = gr.Slider(1, 10, value=5, step=1, label='Number of retrieval image')
95
+ with gr.Row():
96
+ btn = gr.Button('Search')
97
+ clear_btn = gr.ClearButton()
98
+
99
+ gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)
100
+
101
+ img_dir = './img_align_celeba'
102
+ examples = random.choices(img_names, k=5)
103
+ examples = [os.path.join(img_dir, ex) for ex in examples]
104
+ examples = [Image.open(img) for img in examples]
105
+
106
+ with gr.Row():
107
+ gr.Examples(
108
+ examples = examples,
109
+ inputs = image
110
+ )
111
+
112
+
113
+ btn.click(image_search,
114
+ inputs= [image, slider],
115
+ outputs= [gallery])
116
+
117
+ def clear_image():
118
+ return None
119
+
120
+ clear_btn.click(
121
+ fn = clear_image,
122
+ inputs = [],
123
+ outputs = [image]
124
+ )
125
+
126
+ if __name__ == "__main__":
127
+ demo.launch(server_name = "0.0.0.0", server_port = 7860)