hysts HF staff commited on
Commit
7df68a5
1 Parent(s): 7843da8
Files changed (2) hide show
  1. app.py +2 -78
  2. model.py +81 -0
app.py CHANGED
@@ -3,20 +3,11 @@
3
  from __future__ import annotations
4
 
5
  import argparse
6
- import os
7
- import pickle
8
- import sys
9
 
10
  import gradio as gr
11
- import numpy as np
12
- import torch
13
- import torch.nn as nn
14
  from huggingface_hub import hf_hub_download
15
 
16
- sys.path.insert(0, 'StyleGAN-Human')
17
-
18
- TOKEN = os.environ['TOKEN']
19
-
20
 
21
  def parse_args() -> argparse.Namespace:
22
  parser = argparse.ArgumentParser()
@@ -30,76 +21,9 @@ def parse_args() -> argparse.Namespace:
30
  return parser.parse_args()
31
 
32
 
33
- class App:
34
-
35
- def __init__(self, device: torch.device):
36
- self.device = device
37
- self.model = self.load_model('stylegan_human_v2_1024.pkl')
38
-
39
- def load_model(self, file_name: str) -> nn.Module:
40
- path = hf_hub_download('hysts/StyleGAN-Human',
41
- f'models/{file_name}',
42
- use_auth_token=TOKEN)
43
- with open(path, 'rb') as f:
44
- model = pickle.load(f)['G_ema']
45
- model.eval()
46
- model.to(self.device)
47
- with torch.inference_mode():
48
- z = torch.zeros((1, model.z_dim)).to(self.device)
49
- label = torch.zeros([1, model.c_dim], device=self.device)
50
- model(z, label, force_fp32=True)
51
- return model
52
-
53
- def generate_z(self, z_dim: int, seed: int) -> torch.Tensor:
54
- return torch.from_numpy(np.random.RandomState(seed).randn(
55
- 1, z_dim)).to(self.device).float()
56
-
57
- @torch.inference_mode()
58
- def generate_single_image(self, seed: int,
59
- truncation_psi: float) -> np.ndarray:
60
- seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
61
-
62
- z = self.generate_z(self.model.z_dim, seed)
63
- label = torch.zeros([1, self.model.c_dim], device=self.device)
64
-
65
- out = self.model(z,
66
- label,
67
- truncation_psi=truncation_psi,
68
- force_fp32=True)
69
- out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
70
- torch.uint8)
71
- return out[0].cpu().numpy()
72
-
73
- @torch.inference_mode()
74
- def generate_interpolated_images(
75
- self, seed0: int, psi0: float, seed1: int, psi1: float,
76
- num_intermediate: int) -> list[np.ndarray]:
77
- seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
78
- seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
79
-
80
- z0 = self.generate_z(self.model.z_dim, seed0)
81
- z1 = self.generate_z(self.model.z_dim, seed1)
82
- vec = z1 - z0
83
- dvec = vec / (num_intermediate + 1)
84
- zs = [z0 + dvec * i for i in range(num_intermediate + 2)]
85
- dpsi = (psi1 - psi0) / (num_intermediate + 1)
86
- psis = [psi0 + dpsi * i for i in range(num_intermediate + 2)]
87
-
88
- label = torch.zeros([1, self.model.c_dim], device=self.device)
89
-
90
- res = []
91
- for z, psi in zip(zs, psis):
92
- out = self.model(z, label, truncation_psi=psi, force_fp32=True)
93
- out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
94
- torch.uint8)
95
- out = out[0].cpu().numpy()
96
- res.append(out)
97
- return res
98
-
99
-
100
  def main():
101
  args = parse_args()
102
- app = App(device=torch.device(args.device))
103
 
104
  with gr.Blocks(theme=args.theme) as demo:
105
  gr.Markdown('''<center><h1>StyleGAN-Human</h1></center>
 
3
  from __future__ import annotations
4
 
5
  import argparse
 
 
 
6
 
7
  import gradio as gr
 
 
 
8
  from huggingface_hub import hf_hub_download
9
 
10
+ from model import Model
 
 
 
11
 
12
  def parse_args() -> argparse.Namespace:
13
  parser = argparse.ArgumentParser()
 
21
  return parser.parse_args()
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def main():
25
  args = parse_args()
26
+ app = Model(device=args.device)
27
 
28
  with gr.Blocks(theme=args.theme) as demo:
29
  gr.Markdown('''<center><h1>StyleGAN-Human</h1></center>
model.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pickle
5
+ import sys
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ sys.path.insert(0, 'StyleGAN-Human')
13
+
14
+ HF_TOKEN = os.environ['HF_TOKEN']
15
+
16
+
17
+ class Model:
18
+
19
+ def __init__(self, device: str | torch.device):
20
+ self.device = torch.device(device)
21
+ self.model = self.load_model('stylegan_human_v2_1024.pkl')
22
+
23
+ def load_model(self, file_name: str) -> nn.Module:
24
+ path = hf_hub_download('hysts/StyleGAN-Human',
25
+ f'models/{file_name}',
26
+ use_auth_token=HF_TOKEN)
27
+ with open(path, 'rb') as f:
28
+ model = pickle.load(f)['G_ema']
29
+ model.eval()
30
+ model.to(self.device)
31
+ with torch.inference_mode():
32
+ z = torch.zeros((1, model.z_dim)).to(self.device)
33
+ label = torch.zeros([1, model.c_dim], device=self.device)
34
+ model(z, label, force_fp32=True)
35
+ return model
36
+
37
+ def generate_z(self, z_dim: int, seed: int) -> torch.Tensor:
38
+ return torch.from_numpy(np.random.RandomState(seed).randn(
39
+ 1, z_dim)).to(self.device).float()
40
+
41
+ @torch.inference_mode()
42
+ def generate_single_image(self, seed: int,
43
+ truncation_psi: float) -> np.ndarray:
44
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
45
+
46
+ z = self.generate_z(self.model.z_dim, seed)
47
+ label = torch.zeros([1, self.model.c_dim], device=self.device)
48
+
49
+ out = self.model(z,
50
+ label,
51
+ truncation_psi=truncation_psi,
52
+ force_fp32=True)
53
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
54
+ torch.uint8)
55
+ return out[0].cpu().numpy()
56
+
57
+ @torch.inference_mode()
58
+ def generate_interpolated_images(
59
+ self, seed0: int, psi0: float, seed1: int, psi1: float,
60
+ num_intermediate: int) -> list[np.ndarray]:
61
+ seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
62
+ seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
63
+
64
+ z0 = self.generate_z(self.model.z_dim, seed0)
65
+ z1 = self.generate_z(self.model.z_dim, seed1)
66
+ vec = z1 - z0
67
+ dvec = vec / (num_intermediate + 1)
68
+ zs = [z0 + dvec * i for i in range(num_intermediate + 2)]
69
+ dpsi = (psi1 - psi0) / (num_intermediate + 1)
70
+ psis = [psi0 + dpsi * i for i in range(num_intermediate + 2)]
71
+
72
+ label = torch.zeros([1, self.model.c_dim], device=self.device)
73
+
74
+ res = []
75
+ for z, psi in zip(zs, psis):
76
+ out = self.model(z, label, truncation_psi=psi, force_fp32=True)
77
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(
78
+ torch.uint8)
79
+ out = out[0].cpu().numpy()
80
+ res.append(out)
81
+ return res