Spaces:
Runtime error
Runtime error
File size: 2,264 Bytes
7bd3343 589ceac 7bd3343 589ceac 815e124 589ceac 815e124 589ceac 7bd3343 589ceac 7bd3343 589ceac 7bd3343 589ceac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import os
import pickle
import types

import gradio as gr
import numpy as np
import torch
from huggingface_hub import cached_download, hf_hub_download, hf_hub_url
#TOKEN = os.environ['TOKEN']
# Fetch the trained StyleGAN2-ADA network pickle from the Hugging Face Hub
# (hf_hub_download reuses the local cache when the file is already present)
# and keep its 'G_ema' entry, presumably the EMA-averaged generator.
# hf_hub_download replaces cached_download(hf_hub_url(...)); cached_download is
# deprecated and has been removed in recent huggingface_hub releases.
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file. This is only acceptable because the repo below is a fixed, trusted
# source; never point this at user-supplied paths or repos.
_weights_path = hf_hub_download(
    repo_id='CorvaeOboro/gen_ability_icon',
    filename='gen_ability_icon_stylegan2ada_20220801.pkl',
)
with open(_weights_path, 'rb') as f:
    G = pickle.load(f)['G_ema']  # torch.nn.Module
def _patch_force_fp32(module):
    """Re-bind ``module.forward`` so every call runs with ``force_fp32=True``.

    The module's current bound ``forward`` is captured in a closure. The
    replacement forwards the caller's positional and keyword arguments to it
    unchanged, except that ``force_fp32`` is always overwritten with True.
    """
    wrapped = module.forward

    def _forward_fp32(self, *args, **kwargs):
        kwargs["force_fp32"] = True
        return wrapped(*args, **kwargs)

    module.forward = types.MethodType(_forward_fp32, module)


# Use the GPU when one is available and move the generator onto it.
# Otherwise stay on the CPU and force every forward pass of G, and of its
# synthesis sub-network, to run in fp32. This is presumably because the
# reduced-precision path is not usable on CPU (TODO confirm).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    G = G.to(device)
else:
    _patch_force_fp32(G)
    _patch_force_fp32(G.synthesis)
def generate(num_images, interpolate):
    """Sample ``num_images`` images from the generator ``G``.

    Args:
        num_images: number of images to produce (int, expected >= 1).
        interpolate: if truthy, the latent codes are evenly spaced points on
            the straight line from one random latent code to another, so
            consecutive images blend between the two. Otherwise every latent
            code is drawn independently.

    Returns:
        A numpy ``uint8`` array with values in [0, 255] and one image per
        entry along the first axis. The layout is channels-last, assuming
        ``G`` returns NCHW tensors (TODO confirm).
    """
    if interpolate:
        z_start = torch.randn([1, G.z_dim])  # latent code at t = 0
        z_end = torch.randn([1, G.z_dim])    # latent code at t = 1
        # linspace(0, 1, 1) == [0.], so a single image is simply z_start.
        # The old `i / (num_images - 1)` form divided by zero here and
        # produced NaN latents.
        t = torch.linspace(0.0, 1.0, num_images).unsqueeze(1)  # (num_images, 1)
        zs = z_start + (z_end - z_start) * t
    else:
        zs = torch.randn([num_images, G.z_dim])  # independent latent codes
    with torch.no_grad():
        zs = zs.to(device)
        img = G(zs, None, force_fp32=True, noise_mode='const')
        # Map the generator output (assumed to lie roughly in [-1, 1]) to
        # [0, 255] uint8, and move channels last (NCHW -> NHWC).
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()
demo = gr.Blocks()


def infer(num_images, interpolate):
    """Click handler for the Generate button; returns the images as a list.

    ``num_images`` is passed through ``round`` first, presumably because the
    slider may deliver a non-int number. The batch returned by ``generate``
    is split along its first axis into one list entry per image, which is
    the form the Gallery output receives.
    """
    return list(generate(round(num_images), interpolate))
# Build the UI: a description, the two controls, a Generate button, and a
# gallery that shows the images returned by `infer`.
with demo:
    gr.Markdown(
"""
# gen_ability_icon
creates circular magic ability icons from stylegan2ada model trained on synthetic dataset .
more information here : https://github.com/CorvaeOboro/gen_ability_icon.
""")
    # gr.Slider / gr.Checkbox with `value=` replace the deprecated
    # gr.inputs.* components and their `default=` argument. gr.inputs was
    # removed in Gradio 4, and the top-level components also work on Gradio 3.
    images_num = gr.Slider(value=1, label="Num Images", minimum=1, maximum=16, step=1)
    interpolate = gr.Checkbox(value=False, label="Interpolate")
    submit = gr.Button("Generate")
    out = gr.Gallery()
    submit.click(fn=infer,
                 inputs=[images_num, interpolate],
                 outputs=out)

demo.launch()