hysts HF staff committed on
Commit
ab445b8
1 Parent(s): fed7f36

Support LPIPS distance

Browse files
Files changed (3) hide show
  1. app.py +14 -0
  2. model.py +103 -14
  3. requirements.txt +1 -0
app.py CHANGED
@@ -53,6 +53,10 @@ def get_cluster_center_image_markdown(model_name: str) -> str:
53
  return f'![cluster center images]({url})'
54
 
55
 
 
 
 
 
56
  def main():
57
  args = parse_args()
58
 
@@ -83,6 +87,12 @@ def main():
83
  label='Truncation psi')
84
  multimodal_truncation = gr.Checkbox(
85
  label='Multi-modal Truncation', value=True)
 
 
 
 
 
 
86
  run_button = gr.Button('Run')
87
  with gr.Column():
88
  result = gr.Image(label='Result', elem_id='result')
@@ -106,12 +116,16 @@ def main():
106
  gr.Markdown(FOOTER)
107
 
108
  model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
 
 
 
109
  run_button.click(fn=model.set_model_and_generate_image,
110
  inputs=[
111
  model_name,
112
  seed,
113
  psi,
114
  multimodal_truncation,
 
115
  ],
116
  outputs=result)
117
  model_name2.change(fn=get_sample_image_markdown,
 
53
  return f'![cluster center images]({url})'
54
 
55
 
56
+ def update_distance_type(multimodal_truncation: bool) -> dict:
57
+ return gr.Dropdown.update(visible=multimodal_truncation)
58
+
59
+
60
  def main():
61
  args = parse_args()
62
 
 
87
  label='Truncation psi')
88
  multimodal_truncation = gr.Checkbox(
89
  label='Multi-modal Truncation', value=True)
90
+ distance_type = gr.Dropdown([
91
+ 'lpips',
92
+ 'l2',
93
+ ],
94
+ value='lpips',
95
+ label='Distance Type')
96
  run_button = gr.Button('Run')
97
  with gr.Column():
98
  result = gr.Image(label='Result', elem_id='result')
 
116
  gr.Markdown(FOOTER)
117
 
118
  model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
119
+ multimodal_truncation.change(fn=update_distance_type,
120
+ inputs=multimodal_truncation,
121
+ outputs=distance_type)
122
  run_button.click(fn=model.set_model_and_generate_image,
123
  inputs=[
124
  model_name,
125
  seed,
126
  psi,
127
  multimodal_truncation,
128
+ distance_type,
129
  ],
130
  outputs=result)
131
  model_name2.change(fn=get_sample_image_markdown,
model.py CHANGED
@@ -5,6 +5,7 @@ import pathlib
5
  import pickle
6
  import sys
7
 
 
8
  import numpy as np
9
  import torch
10
  import torch.nn as nn
@@ -17,6 +18,31 @@ sys.path.insert(0, submodule_dir.as_posix())
17
  HF_TOKEN = os.environ['HF_TOKEN']
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class Model:
21
 
22
  MODEL_NAMES = [
@@ -33,10 +59,17 @@ class Model:
33
  self.device = torch.device(device)
34
  self._download_all_models()
35
  self._download_all_cluster_centers()
 
36
 
37
  self.model_name = self.MODEL_NAMES[0]
38
  self.model = self._load_model(self.model_name)
39
  self.cluster_centers = self._load_cluster_centers(self.model_name)
 
 
 
 
 
 
40
 
41
  def _load_model(self, model_name: str) -> nn.Module:
42
  path = hf_hub_download('hysts/Self-Distilled-StyleGAN',
@@ -56,12 +89,20 @@ class Model:
56
  centers = torch.from_numpy(centers).float().to(self.device)
57
  return centers
58
 
 
 
 
 
 
 
59
  def set_model(self, model_name: str) -> None:
60
  if model_name == self.model_name:
61
  return
62
  self.model_name = model_name
63
  self.model = self._load_model(model_name)
64
  self.cluster_centers = self._load_cluster_centers(model_name)
 
 
65
 
66
  def _download_all_models(self):
67
  for name in self.MODEL_NAMES:
@@ -71,6 +112,10 @@ class Model:
71
  for name in self.MODEL_NAMES:
72
  self._load_cluster_centers(name)
73
 
 
 
 
 
74
  def generate_z(self, seed: int) -> torch.Tensor:
75
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
76
  return torch.from_numpy(
@@ -82,11 +127,6 @@ class Model:
82
  w = self.model.mapping(z, label)
83
  return w
84
 
85
- def find_nearest_cluster_center(self, w: torch.Tensor) -> int:
86
- # Here, Euclidean distance is used instead of LPIPS distance
87
- dist2 = ((self.cluster_centers - w)**2).sum(dim=1)
88
- return torch.argmin(dist2).item()
89
-
90
  @staticmethod
91
  def truncate_w(w_center: torch.Tensor, w: torch.Tensor,
92
  psi: float) -> torch.Tensor:
@@ -103,22 +143,71 @@ class Model:
103
  torch.uint8)
104
  return tensor.cpu().numpy()
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def generate_image(self, seed: int, truncation_psi: float,
107
- multimodal_truncation: bool) -> np.ndarray:
 
108
  z = self.generate_z(seed)
109
- w = self.compute_w(z)
110
  if multimodal_truncation:
111
- cluster_index = self.find_nearest_cluster_center(w[:, 0])
112
  w0 = self.cluster_centers[cluster_index]
113
  else:
114
  w0 = self.model.mapping.w_avg
115
- new_w = self.truncate_w(w0, w, truncation_psi)
116
- out = self.synthesize(new_w)
117
  out = self.postprocess(out)
118
  return out[0]
119
 
120
- def set_model_and_generate_image(
121
- self, model_name: str, seed: int, truncation_psi: float,
122
- multimodal_truncation: bool) -> np.ndarray:
 
123
  self.set_model(model_name)
124
- return self.generate_image(seed, truncation_psi, multimodal_truncation)
 
 
5
  import pickle
6
  import sys
7
 
8
+ import lpips
9
  import numpy as np
10
  import torch
11
  import torch.nn as nn
 
18
  HF_TOKEN = os.environ['HF_TOKEN']
19
 
20
 
21
+ class LPIPS(lpips.LPIPS):
22
+ @staticmethod
23
+ def preprocess(image: np.ndarray) -> torch.Tensor:
24
+ data = torch.from_numpy(image).float() / 255
25
+ data = data * 2 - 1
26
+ return data.permute(2, 0, 1).unsqueeze(0)
27
+
28
+ @torch.inference_mode()
29
+ def compute_features(self, data: torch.Tensor) -> list[torch.Tensor]:
30
+ data = self.scaling_layer(data)
31
+ data = self.net(data)
32
+ return [lpips.normalize_tensor(x) for x in data]
33
+
34
+ @torch.inference_mode()
35
+ def compute_distance(self, features0: list[torch.Tensor],
36
+ features1: list[torch.Tensor]) -> float:
37
+ res = 0
38
+ for lin, x0, x1 in zip(self.lins, features0, features1):
39
+ d = (x0 - x1)**2
40
+ y = lin(d)
41
+ y = lpips.lpips.spatial_average(y)
42
+ res += y.item()
43
+ return res
44
+
45
+
46
  class Model:
47
 
48
  MODEL_NAMES = [
 
59
  self.device = torch.device(device)
60
  self._download_all_models()
61
  self._download_all_cluster_centers()
62
+ self._download_all_cluster_center_images()
63
 
64
  self.model_name = self.MODEL_NAMES[0]
65
  self.model = self._load_model(self.model_name)
66
  self.cluster_centers = self._load_cluster_centers(self.model_name)
67
+ self.cluster_center_images = self._load_cluster_center_images(
68
+ self.model_name)
69
+
70
+ self.lpips = LPIPS()
71
+ self.cluster_center_lpips_feature_dict = self._compute_cluster_center_lpips_features(
72
+ )
73
 
74
  def _load_model(self, model_name: str) -> nn.Module:
75
  path = hf_hub_download('hysts/Self-Distilled-StyleGAN',
 
89
  centers = torch.from_numpy(centers).float().to(self.device)
90
  return centers
91
 
92
+ def _load_cluster_center_images(self, model_name: str) -> np.ndarray:
93
+ path = hf_hub_download('hysts/Self-Distilled-StyleGAN',
94
+ f'cluster_center_images/{model_name}.npy',
95
+ use_auth_token=HF_TOKEN)
96
+ return np.load(path)
97
+
98
  def set_model(self, model_name: str) -> None:
99
  if model_name == self.model_name:
100
  return
101
  self.model_name = model_name
102
  self.model = self._load_model(model_name)
103
  self.cluster_centers = self._load_cluster_centers(model_name)
104
+ self.cluster_center_images = self._load_cluster_center_images(
105
+ model_name)
106
 
107
  def _download_all_models(self):
108
  for name in self.MODEL_NAMES:
 
112
  for name in self.MODEL_NAMES:
113
  self._load_cluster_centers(name)
114
 
115
+ def _download_all_cluster_center_images(self):
116
+ for name in self.MODEL_NAMES:
117
+ self._load_cluster_center_images(name)
118
+
119
  def generate_z(self, seed: int) -> torch.Tensor:
120
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
121
  return torch.from_numpy(
 
127
  w = self.model.mapping(z, label)
128
  return w
129
 
 
 
 
 
 
130
  @staticmethod
131
  def truncate_w(w_center: torch.Tensor, w: torch.Tensor,
132
  psi: float) -> torch.Tensor:
 
143
  torch.uint8)
144
  return tensor.cpu().numpy()
145
 
146
+ def compute_lpips_features(self, image: np.ndarray) -> list[torch.Tensor]:
147
+ data = self.lpips.preprocess(image)
148
+ return self.lpips.compute_features(data)
149
+
150
+ def _compute_cluster_center_lpips_features(
151
+ self) -> dict[str, list[list[torch.Tensor]]]:
152
+ res = dict()
153
+ for name in self.MODEL_NAMES:
154
+ images = self._load_cluster_center_images(name)
155
+ res[name] = [
156
+ self.compute_lpips_features(image) for image in images
157
+ ]
158
+ return res
159
+
160
+ def compute_distance_to_cluster_centers(
161
+ self, ws: torch.Tensor, distance_type: str) -> list[torch.Tensor]:
162
+ if distance_type == 'l2':
163
+ return self._compute_l2_distance_to_cluster_centers(ws)
164
+ elif distance_type == 'lpips':
165
+ return self._compute_lpips_distance_to_cluster_centers(ws)
166
+ else:
167
+ raise ValueError
168
+
169
+ def _compute_l2_distance_to_cluster_centers(
170
+ self, ws: torch.Tensor) -> np.ndarray:
171
+ dist2 = ((self.cluster_centers - ws[0, 0])**2).sum(dim=1)
172
+ return dist2.cpu().numpy()
173
+
174
+ def _compute_lpips_distance_to_cluster_centers(
175
+ self, ws: torch.Tensor) -> np.ndarray:
176
+ x = self.synthesize(ws)
177
+ x = self.postprocess(x)[0]
178
+ feat0 = self.compute_lpips_features(x)
179
+ cluster_center_features = self.cluster_center_lpips_feature_dict[
180
+ self.model_name]
181
+ distances = [
182
+ self.lpips.compute_distance(feat0, feat1)
183
+ for feat1 in cluster_center_features
184
+ ]
185
+ return np.asarray(distances)
186
+
187
+ def find_nearest_cluster_center(self, ws: torch.Tensor,
188
+ distance_type: str) -> int:
189
+ distances = self.compute_distance_to_cluster_centers(ws, distance_type)
190
+ return int(np.argmin(distances))
191
+
192
  def generate_image(self, seed: int, truncation_psi: float,
193
+ multimodal_truncation: bool,
194
+ distance_type: str) -> np.ndarray:
195
  z = self.generate_z(seed)
196
+ ws = self.compute_w(z)
197
  if multimodal_truncation:
198
+ cluster_index = self.find_nearest_cluster_center(ws, distance_type)
199
  w0 = self.cluster_centers[cluster_index]
200
  else:
201
  w0 = self.model.mapping.w_avg
202
+ new_ws = self.truncate_w(w0, ws, truncation_psi)
203
+ out = self.synthesize(new_ws)
204
  out = self.postprocess(out)
205
  return out[0]
206
 
207
+ def set_model_and_generate_image(self, model_name: str, seed: int,
208
+ truncation_psi: float,
209
+ multimodal_truncation: bool,
210
+ distance_type: str) -> np.ndarray:
211
  self.set_model(model_name)
212
+ return self.generate_image(seed, truncation_psi, multimodal_truncation,
213
+ distance_type)
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  numpy==1.22.3
2
  Pillow==9.0.1
3
  scipy==1.8.0
 
1
+ lpips==0.1.4
2
  numpy==1.22.3
3
  Pillow==9.0.1
4
  scipy==1.8.0