nqtruong commited on
Commit
64d9fb2
·
verified ·
1 Parent(s): 2c991ee

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from facenet_pytorch import InceptionResnetV1
3
+ import torch.nn as nn
4
+ import torchvision.transforms as tf
5
+ import numpy as np
6
+ import torch
7
+ import faiss
8
+ import h5py
9
+ import os
10
+ import random
11
+ from PIL import Image
12
+ import matplotlib.cm as cm
13
+ import matplotlib as mpl
14
+
15
+
16
+
17
+
18
+ img_names = []
19
+ with open('list_eval_partition.txt', 'r') as f:
20
+ for line in f:
21
+ img_name, dtype = line.rstrip().split(' ')
22
+ img_names.append(img_name)
23
+
24
+ # if dtype == '0':
25
+ # data['train'].append(img_name)
26
+ # elif dtype == '1':
27
+ # data['val'].append(img_name)
28
+ # else:
29
+ # data['test'].append(img_name)
30
+
31
+
32
+ # For a model pretrained on VGGFace2
33
+ print('Loading model weights ........')
34
+
35
+ class SiameseModel(nn.Module):
36
+ def __init__(self):
37
+ super().__init__()
38
+ self.backbone = InceptionResnetV1(pretrained='vggface2')
39
+ def forward(self, x):
40
+ x = self.backbone(x)
41
+ x = torch.nn.functional.normalize(x, dim=1)
42
+ return x
43
+
44
+ model = SiameseModel()
45
+ model.load_state_dict(torch.load('model_best_weights_1000.pt', map_location=torch.device('cpu')))
46
+ model.eval()
47
+
48
+
49
+ # Make FAISS index
50
+ print('Make index .............')
51
+ index = faiss.IndexFlatL2(512)
52
+
53
+ hf = h5py.File('face_vecs_full.h5', 'r')
54
+ for key in hf.keys():
55
+ vec = np.array(hf.get(key))
56
+ index.add(vec)
57
+
58
+ hf.close()
59
+
60
+ # Function to search image
61
+ def image_search(image, k=5):
62
+
63
+ transform = tf.Compose([
64
+ tf.Resize((160, 160)),
65
+ tf.ToTensor()
66
+ ])
67
+
68
+ query_img = transform(image)
69
+ query_img = torch.unsqueeze(query_img, 0)
70
+
71
+ model.eval()
72
+ query_vec = model(query_img).detach().numpy()
73
+
74
+ D, I = index.search(query_vec, k=k)
75
+
76
+ retrieval_imgs = []
77
+
78
+ FOLDER = 'img_align_celeba'
79
+ for idx in I[0]:
80
+ img_file_name = img_names[idx]
81
+ path = os.path.join(FOLDER, img_file_name)
82
+
83
+ image = Image.open(path)
84
+ retrieval_imgs.append((image, ''))
85
+
86
+ return retrieval_imgs
87
+
88
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
89
+ gr.Markdown('''
90
+
91
+
92
+ # Face Image Retrieval with Content-based Retrieval Image (CBIR) & Saliency Map
93
+ --------
94
+
95
+
96
+ ''')
97
+
98
+ with gr.Row():
99
+ with gr.Column():
100
+ image = gr.Image(type='pil', scale=1)
101
+ slider = gr.Slider(1, 10, value=5, step=1, label='Number of retrieval image')
102
+ with gr.Row():
103
+ btn = gr.Button('Search')
104
+ clear_btn = gr.ClearButton()
105
+
106
+ gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)
107
+
108
+ img_dir = './img_align_celeba'
109
+ examples = random.choices(img_names, k=6)
110
+ examples = [os.path.join(img_dir, ex) for ex in examples]
111
+ examples = [Image.open(img) for img in examples]
112
+
113
+ with gr.Row():
114
+ gr.Examples(
115
+ examples = examples,
116
+ inputs = image
117
+ )
118
+
119
+
120
+ btn.click(image_search,
121
+ inputs= [image, slider],
122
+ outputs= [gallery])
123
+
124
+ def clear_image():
125
+ return None
126
+
127
+ clear_btn.click(
128
+ fn = clear_image,
129
+ inputs = [],
130
+ outputs = [image]
131
+ )
132
+
133
+
134
+ if __name__ == "__main__":
135
+ demo.launch()
136
+