hysts HF staff commited on
Commit
6e8417e
·
1 Parent(s): 9abd5b7
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.jpg filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "projected_gan"]
2
+ path = projected_gan
3
+ url = https://github.com/autonomousvision/projected_gan
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pickle
9
+ import sys
10
+
11
+ sys.path.insert(0, 'projected_gan')
12
+
13
+ import gradio as gr
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ from huggingface_hub import hf_hub_download
18
+
19
+ ORIGINAL_REPO_URL = 'https://github.com/autonomousvision/projected_gan'
20
+ TITLE = 'autonomousvision/projected_gan'
21
+ DESCRIPTION = f'This is a demo for {ORIGINAL_REPO_URL}.'
22
+ SAMPLE_IMAGE_DIR = 'https://huggingface.co/spaces/hysts/projected_gan/resolve/main/samples'
23
+ ARTICLE = f'''## Generated images
24
+ - truncation: 0.7
25
+ - size: 256x256
26
+ - seed: 0-99
27
+ ### Art painting
28
+ ![Art painting samples]({SAMPLE_IMAGE_DIR}/art_painting.jpg)
29
+ ### Bedroom
30
+ ![Bedroom samples]({SAMPLE_IMAGE_DIR}/bedroom.jpg)
31
+ ### Church
32
+ ![Church samples]({SAMPLE_IMAGE_DIR}/church.jpg)
33
+ ### Cityscapes
34
+ ![Cityscapes samples]({SAMPLE_IMAGE_DIR}/cityscapes.jpg)
35
+ ### CLEVR
36
+ ![CLEVR samples]({SAMPLE_IMAGE_DIR}/clevr.jpg)
37
+ ### FFHQ
38
+ ![FFHQ samples]({SAMPLE_IMAGE_DIR}/ffhq.jpg)
39
+ ### Flowers
40
+ ![Flowers samples]({SAMPLE_IMAGE_DIR}/flowers.jpg)
41
+ ### Landscape
42
+ ![Landscape samples]({SAMPLE_IMAGE_DIR}/landscape.jpg)
43
+ ### Pokemon
44
+ ![Pokemon samples]({SAMPLE_IMAGE_DIR}/pokemon.jpg)
45
+ '''
46
+
47
+ TOKEN = os.environ['TOKEN']
48
+
49
+
50
+ def parse_args() -> argparse.Namespace:
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument('--device', type=str, default='cpu')
53
+ parser.add_argument('--theme', type=str)
54
+ parser.add_argument('--live', action='store_true')
55
+ parser.add_argument('--share', action='store_true')
56
+ parser.add_argument('--port', type=int)
57
+ parser.add_argument('--disable-queue',
58
+ dest='enable_queue',
59
+ action='store_false')
60
+ parser.add_argument('--allow-flagging', type=str, default='never')
61
+ parser.add_argument('--allow-screenshot', action='store_true')
62
+ return parser.parse_args()
63
+
64
+
65
+ def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
66
+ return torch.from_numpy(
67
+ np.random.RandomState(seed).randn(1,
68
+ z_dim).astype(np.float32)).to(device)
69
+
70
+
71
+ @torch.inference_mode()
72
+ def generate_image(model_name: str, seed: int, truncation_psi: float,
73
+ model_dict: dict[str, nn.Module],
74
+ device: torch.device) -> np.ndarray:
75
+ model = model_dict[model_name]
76
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
77
+
78
+ z = generate_z(model.z_dim, seed, device)
79
+ label = torch.zeros([1, model.c_dim], device=device)
80
+
81
+ out = model(z, label, truncation_psi=truncation_psi)
82
+ out = (out.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
83
+ return out[0].cpu().numpy()
84
+
85
+
86
+ def load_model(model_name: str, device: torch.device) -> nn.Module:
87
+ path = hf_hub_download('hysts/projected_gan',
88
+ f'models/{model_name}.pkl',
89
+ use_auth_token=TOKEN)
90
+ with open(path, 'rb') as f:
91
+ model = pickle.load(f)['G_ema']
92
+ model.eval()
93
+ model.to(device)
94
+ with torch.inference_mode():
95
+ z = torch.zeros((1, model.z_dim)).to(device)
96
+ label = torch.zeros([1, model.c_dim], device=device)
97
+ model(z, label)
98
+ return model
99
+
100
+
101
+ def main():
102
+ gr.close_all()
103
+
104
+ args = parse_args()
105
+ device = torch.device(args.device)
106
+
107
+ model_names = [
108
+ 'art_painting',
109
+ 'church',
110
+ 'bedroom',
111
+ 'cityscapes',
112
+ 'clevr',
113
+ 'ffhq',
114
+ 'flowers',
115
+ 'landscape',
116
+ 'pokemon',
117
+ ]
118
+
119
+ model_dict = {name: load_model(name, device) for name in model_names}
120
+
121
+ func = functools.partial(generate_image,
122
+ model_dict=model_dict,
123
+ device=device)
124
+ func = functools.update_wrapper(func, generate_image)
125
+
126
+ gr.Interface(
127
+ func,
128
+ [
129
+ gr.inputs.Radio(
130
+ model_names, type='value', default='pokemon', label='Model'),
131
+ gr.inputs.Number(default=0, label='Seed'),
132
+ gr.inputs.Slider(
133
+ 0, 2, step=0.05, default=0.7, label='Truncation psi'),
134
+ ],
135
+ gr.outputs.Image(type='numpy', label='Output'),
136
+ title=TITLE,
137
+ description=DESCRIPTION,
138
+ article=ARTICLE,
139
+ theme=args.theme,
140
+ allow_screenshot=args.allow_screenshot,
141
+ allow_flagging=args.allow_flagging,
142
+ live=args.live,
143
+ ).launch(
144
+ enable_queue=args.enable_queue,
145
+ server_port=args.port,
146
+ share=args.share,
147
+ )
148
+
149
+
150
+ if __name__ == '__main__':
151
+ main()
projected_gan ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit e1c246b8bdce4fac3c2bfcb69df309fc27df9b86
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy==1.22.3
2
+ Pillow==9.0.1
3
+ scipy==1.8.0
4
+ torch==1.10.2
5
+ torchvision==0.11.3
samples/art_painting.jpg ADDED

Git LFS Details

  • SHA256: a7190a35437aec0e79e2b22ba476fe7d8efb110feced366f114456b19d91911b
  • Pointer size: 132 Bytes
  • Size of remote file: 3.64 MB
samples/bedroom.jpg ADDED

Git LFS Details

  • SHA256: 6d554d0fa414dfb5e599e226f69e4160a2a35691275abf8ea55c0ba167797dbc
  • Pointer size: 132 Bytes
  • Size of remote file: 2.44 MB
samples/church.jpg ADDED

Git LFS Details

  • SHA256: 9c00362027917a9a7facf7cee9691edeb81f76bddcfb77a3df58fb64f15f8104
  • Pointer size: 132 Bytes
  • Size of remote file: 2.95 MB
samples/cityscapes.jpg ADDED

Git LFS Details

  • SHA256: 205ecdd425ea02344478ae9a84d48cb574c1da8eb136301bc39f93ed30c8521c
  • Pointer size: 132 Bytes
  • Size of remote file: 2.45 MB
samples/clevr.jpg ADDED

Git LFS Details

  • SHA256: 8e98bccfe738243a3a16109151ca1c53d6f24f5c1dd46ab340c348ff081680c0
  • Pointer size: 131 Bytes
  • Size of remote file: 936 kB
samples/ffhq.jpg ADDED

Git LFS Details

  • SHA256: 091714c1ddfcc3d20d03fa8380723c78fb9370cdbd8a8ecc20d143f4405d2074
  • Pointer size: 132 Bytes
  • Size of remote file: 2.53 MB
samples/flowers.jpg ADDED

Git LFS Details

  • SHA256: 26586c30b2ccafb314f121cb5afd19579012ed6d10e4f0e4a8f482aa7b57394e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.08 MB
samples/landscape.jpg ADDED

Git LFS Details

  • SHA256: 47fb12a15ce0716a6144f58f5608d0bef749b38c3875998d6ab4b1508b1fdb04
  • Pointer size: 132 Bytes
  • Size of remote file: 2.93 MB
samples/pokemon.jpg ADDED

Git LFS Details

  • SHA256: 7e3211f6867236522b8450ac1157ada3caf07a690d7e661c0e5478c1db9fe781
  • Pointer size: 132 Bytes
  • Size of remote file: 2.18 MB