hysts HF staff commited on
Commit
5691542
1 Parent(s): 6361ab3
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.jpg filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "stylegan3"]
2
+ path = stylegan3
3
+ url = https://github.com/NVlabs/stylegan3
app.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pickle
9
+ import sys
10
+
11
+ sys.path.insert(0, 'stylegan3')
12
+
13
+ import gradio as gr
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ ORIGINAL_REPO_URL = 'https://github.com/NVlabs/stylegan3'
20
+ TITLE = 'NVlabs/stylegan3'
21
+ DESCRIPTION = f'This is a demo for {ORIGINAL_REPO_URL}.'
22
+ SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/StyleGAN3/resolve/main/samples'
23
+ ARTICLE = f'''## Generated images
24
+ - truncation: 0.7
25
+ ### AFHQv2
26
+ - size: 512x512
27
+ - seed: 0-99
28
+ ![AFHQv2 samples]({SAMPLE_IMAGE_DIR}/afhqv2.jpg)
29
+ ### FFHQ
30
+ - size: 1024x1024
31
+ - seed: 0-99
32
+ ![FFHQ samples]({SAMPLE_IMAGE_DIR}/ffhq.jpg)
33
+ ### FFHQ-U
34
+ - size: 1024x1024
35
+ - seed: 0-99
36
+ ![FFHQ-U samples]({SAMPLE_IMAGE_DIR}/ffhq-u.jpg)
37
+ ### MetFaces
38
+ - size: 1024x1024
39
+ - seed: 0-99
40
+ ![MetFaces samples]({SAMPLE_IMAGE_DIR}/metfaces.jpg)
41
+ ### MetFaces-U
42
+ - size: 1024x1024
43
+ - seed: 0-99
44
+ ![MetFaces-U samples]({SAMPLE_IMAGE_DIR}/metfaces-u.jpg)
45
+ '''
46
+
47
+ TOKEN = os.environ['TOKEN']
48
+
49
+
50
+ def parse_args() -> argparse.Namespace:
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument('--device', type=str, default='cpu')
53
+ parser.add_argument('--theme', type=str)
54
+ parser.add_argument('--live', action='store_true')
55
+ parser.add_argument('--share', action='store_true')
56
+ parser.add_argument('--port', type=int)
57
+ parser.add_argument('--disable-queue',
58
+ dest='enable_queue',
59
+ action='store_false')
60
+ parser.add_argument('--allow-flagging', type=str, default='never')
61
+ parser.add_argument('--allow-screenshot', action='store_true')
62
+ return parser.parse_args()
63
+
64
+
65
+ def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
66
+ mat = np.eye(3)
67
+ sin = np.sin(angle / 360 * np.pi * 2)
68
+ cos = np.cos(angle / 360 * np.pi * 2)
69
+ mat[0][0] = cos
70
+ mat[0][1] = sin
71
+ mat[0][2] = translate[0]
72
+ mat[1][0] = -sin
73
+ mat[1][1] = cos
74
+ mat[1][2] = translate[1]
75
+ return mat
76
+
77
+
78
+ def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
79
+ return torch.from_numpy(np.random.RandomState(seed).randn(
80
+ 1, z_dim)).to(device)
81
+
82
+
83
+ @torch.inference_mode()
84
+ def generate_image(model_name: str, seed: int, truncation_psi: float,
85
+ tx: float, ty: float, angle: float,
86
+ model_dict: dict[str, nn.Module],
87
+ device: torch.device) -> np.ndarray:
88
+ model = model_dict[model_name]
89
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
90
+
91
+ z = generate_z(model.z_dim, seed, device)
92
+ label = torch.zeros([1, model.c_dim], device=device)
93
+
94
+ mat = make_transform((tx, ty), angle)
95
+ mat = np.linalg.inv(mat)
96
+ model.synthesis.input.transform.copy_(torch.from_numpy(mat))
97
+
98
+ out = model(z, label, truncation_psi=truncation_psi)
99
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
100
+ return out[0].cpu().numpy()
101
+
102
+
103
+ def load_model(file_name: str, device: torch.device) -> nn.Module:
104
+ path = hf_hub_download('hysts/StyleGAN3',
105
+ f'models/{file_name}',
106
+ use_auth_token=TOKEN)
107
+ with open(path, 'rb') as f:
108
+ model = pickle.load(f)['G_ema']
109
+ model.eval()
110
+ model.to(device)
111
+ with torch.inference_mode():
112
+ z = torch.zeros((1, model.z_dim)).to(device)
113
+ label = torch.zeros([1, model.c_dim], device=device)
114
+ model(z, label)
115
+ return model
116
+
117
+
118
+ def main():
119
+ gr.close_all()
120
+
121
+ args = parse_args()
122
+ device = torch.device(args.device)
123
+
124
+ model_names = {
125
+ 'AFHQv2-512-R': 'stylegan3-r-afhqv2-512x512.pkl',
126
+ 'FFHQ-1024-R': 'stylegan3-r-ffhq-1024x1024.pkl',
127
+ 'FFHQ-U-256-R': 'stylegan3-r-ffhqu-256x256.pkl',
128
+ 'FFHQ-U-1024-R': 'stylegan3-r-ffhqu-1024x1024.pkl',
129
+ 'MetFaces-1024-R': 'stylegan3-r-metfaces-1024x1024.pkl',
130
+ 'MetFaces-U-1024-R': 'stylegan3-r-metfacesu-1024x1024.pkl',
131
+ 'AFHQv2-512-T': 'stylegan3-t-afhqv2-512x512.pkl',
132
+ 'FFHQ-1024-T': 'stylegan3-t-ffhq-1024x1024.pkl',
133
+ 'FFHQ-U-256-T': 'stylegan3-t-ffhqu-256x256.pkl',
134
+ 'FFHQ-U-1024-T': 'stylegan3-t-ffhqu-1024x1024.pkl',
135
+ 'MetFaces-1024-T': 'stylegan3-t-metfaces-1024x1024.pkl',
136
+ 'MetFaces-U-1024-T': 'stylegan3-t-metfacesu-1024x1024.pkl',
137
+ }
138
+
139
+ model_dict = {
140
+ name: load_model(file_name, device)
141
+ for name, file_name in model_names.items()
142
+ }
143
+
144
+ func = functools.partial(generate_image,
145
+ model_dict=model_dict,
146
+ device=device)
147
+ func = functools.update_wrapper(func, generate_image)
148
+
149
+ gr.Interface(
150
+ func,
151
+ [
152
+ gr.inputs.Radio(list(model_names.keys()),
153
+ type='value',
154
+ default='FFHQ-1024-R',
155
+ label='Model'),
156
+ gr.inputs.Number(default=0, label='Seed'),
157
+ gr.inputs.Slider(
158
+ 0, 2, step=0.05, default=0.7, label='Truncation psi'),
159
+ gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate X'),
160
+ gr.inputs.Slider(-1, 1, step=0.05, default=0, label='Translate Y'),
161
+ gr.inputs.Slider(-180, 180, step=5, default=0, label='Angle'),
162
+ ],
163
+ gr.outputs.Image(type='numpy', label='Output'),
164
+ theme=args.theme,
165
+ title=TITLE,
166
+ description=DESCRIPTION,
167
+ article=ARTICLE,
168
+ allow_screenshot=args.allow_screenshot,
169
+ allow_flagging=args.allow_flagging,
170
+ live=args.live,
171
+ ).launch(
172
+ enable_queue=args.enable_queue,
173
+ server_port=args.port,
174
+ share=args.share,
175
+ )
176
+
177
+
178
+ if __name__ == '__main__':
179
+ main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy==1.22.3
2
+ Pillow==9.0.1
3
+ scipy==1.8.0
4
+ torch==1.11.0
5
+ torchvision==0.12.0
samples/afhqv2.jpg ADDED

Git LFS Details

  • SHA256: 794296f4754ffe9cb78ac15f9efce1aa831a318ac033312af260e8e8c5d25399
  • Pointer size: 133 Bytes
  • Size of remote file: 10.4 MB
samples/ffhq-u.jpg ADDED

Git LFS Details

  • SHA256: 499385e118437ce494ad5d3cb8a729c2ef8993bc6869a8e889c9f876c009ae12
  • Pointer size: 133 Bytes
  • Size of remote file: 27.4 MB
samples/ffhq.jpg ADDED

Git LFS Details

  • SHA256: 43fcb2ca2d82cda8b800913030af457b326a64d0ff6ff23304a80ad43ce3e53d
  • Pointer size: 133 Bytes
  • Size of remote file: 27.6 MB
samples/metfaces-u.jpg ADDED

Git LFS Details

  • SHA256: 7462609186728f84ef38307b7fcd02d2971a810b4a0436aaac9248e834f0bf22
  • Pointer size: 133 Bytes
  • Size of remote file: 27.9 MB
samples/metfaces.jpg ADDED

Git LFS Details

  • SHA256: 59bf6bd9579ff1cde11923ef5c03ab3fa3b6c2ab085f5a00bf4d0c932ace82a8
  • Pointer size: 133 Bytes
  • Size of remote file: 28.7 MB
stylegan3 ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit a5a69f58294509598714d1e88c9646c3d7c6ec94