File size: 2,319 Bytes
7bd3343
589ceac
7bd3343
 
589ceac
 
 
 
 
815e124
589ceac
ae4445c
 
589ceac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bd3343
589ceac
 
 
 
7bd3343
589ceac
 
 
 
 
62802d5
589ceac
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import os
import numpy as np
import torch
import pickle
import types

from huggingface_hub import hf_hub_url, cached_download

#TOKEN = os.environ['TOKEN']

# Resolve the checkpoint's URL on the Hugging Face Hub and fetch it to a
# local file path.
_checkpoint_url = hf_hub_url(
    'CorvaeOboro/gen_ability_icon',
    'gen_ability_icon_stylegan2ada_20220819.pkl',
)
_checkpoint_path = cached_download(_checkpoint_url)

# NOTE(security): pickle.load runs code embedded in the file, so this is only
# acceptable while the repository above is trusted.
with open(_checkpoint_path, 'rb') as checkpoint_file:
    # Only the 'G_ema' entry of the checkpoint is kept.
    G = pickle.load(checkpoint_file)['G_ema']  # torch.nn.Module

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _force_fp32(module):
    """Wrap ``module.forward`` so every call is made with ``force_fp32=True``."""
    original_forward = module.forward

    def _patched_forward(self, *args, **kwargs):
        # Overrides whatever value the caller supplied for this flag.
        kwargs["force_fp32"] = True
        return original_forward(*args, **kwargs)

    module.forward = types.MethodType(_patched_forward, module)


if device.type == "cuda":
    G = G.to(device)
else:
    # On CPU the model stays where it was loaded; instead, both the generator
    # and its synthesis sub-network are patched to always run with
    # force_fp32=True (presumably because the reduced-precision path is not
    # usable on CPU -- confirm against the model code).
    _force_fp32(G)
    _force_fp32(G.synthesis)


def generate(num_images, interpolate):
    """Sample latent codes and render them with the generator ``G``.

    Args:
        num_images: Number of images to produce (an int).
        interpolate: If true, draw two random latents and place
            ``num_images`` codes evenly on the straight line between them
            (both endpoints included). Otherwise draw ``num_images``
            independent random latents.

    Returns:
        A ``uint8`` NumPy array with the batch on axis 0 and the generator's
        second axis moved to the last position.
    """
    if interpolate:
        z1 = torch.randn([1, G.z_dim])  # latent at the start of the path
        z2 = torch.randn([1, G.z_dim])  # latent at the end of the path
        # linspace returns [0.] for a single step, so one image yields z1
        # instead of the NaNs a division by (num_images - 1) == 0 would give.
        weights = torch.linspace(0.0, 1.0, steps=num_images).unsqueeze(1)
        zs = z1 + (z2 - z1) * weights
    else:
        zs = torch.randn([num_images, G.z_dim])  # independent latent codes
    with torch.no_grad():
        zs = zs.to(device)
        img = G(zs, None, force_fp32=True, noise_mode='const')
        # Assumes G returns channel-first images in roughly [-1, 1] -- TODO
        # confirm; the result is scaled to 0..255 and stored channel-last.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()

# Top-level Gradio Blocks container; the UI components are declared inside
# its `with demo:` context further down, and it is launched at the end.
demo = gr.Blocks()

def infer(num_images, interpolate):
    """Click handler: generate images and return them as a list.

    ``num_images`` is rounded to the nearest integer before being passed to
    ``generate``; the returned batch is split along its first axis into one
    list entry per image.
    """
    batch = generate(round(num_images), interpolate)
    return [single_image for single_image in batch]

# Declare the page layout and wire the button to the inference callback.
with demo:
    gr.Markdown(
    """
    # gen_ability_icon
    creates circular magic ability icons from stylegan2ada model trained on synthetic dataset .
    more information here :  [https://github.com/CorvaeOboro/gen_ability_icon](https://github.com/CorvaeOboro/gen_ability_icon).
    """)
    # Use the top-level components with `value=`: the legacy `gr.inputs.*`
    # namespace (and its `default=` keyword) is deprecated and no longer
    # exists in newer Gradio releases.
    images_num = gr.Slider(minimum=1, maximum=16, value=1, step=1, label="Num Images")
    interpolate = gr.Checkbox(value=False, label="Interpolate")
    submit = gr.Button("Generate")

    # Gallery that displays the list of images returned by `infer`.
    out = gr.Gallery()

    submit.click(fn=infer,
               inputs=[images_num, interpolate],
               outputs=out)

# Start the web server for the app.
demo.launch()