"""Gradio demo that samples circular ability icons from a StyleGAN2-ADA generator.

The generator pickle is downloaded from the Hugging Face Hub repository
``CorvaeOboro/gen_ability_icon``. The UI lets the user choose how many icons to
generate and whether to interpolate between two random latent codes.
"""
import os
import pickle
import types

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download

REPO_ID = "CorvaeOboro/gen_ability_icon"
MODEL_FILENAME = "gen_ability_icon_stylegan2ada_20220819.pkl"

# hf_hub_download replaces the old hf_hub_url + cached_download pair;
# cached_download no longer exists in recent huggingface_hub releases.
# SECURITY NOTE: pickle.load runs arbitrary code embedded in the file. This is
# only acceptable because the checkpoint comes from a fixed Hub repository;
# never load user-supplied files this way.
with open(hf_hub_download(REPO_ID, MODEL_FILENAME), "rb") as f:
    G = pickle.load(f)["G_ema"]  # torch.nn.Module


def _force_fp32_forward(module):
    """Wrap ``module.forward`` so every call is made with ``force_fp32=True``."""
    original_forward = module.forward  # bound method captured before patching

    def forward(self, *args, **kwargs):
        kwargs["force_fp32"] = True
        return original_forward(*args, **kwargs)

    module.forward = types.MethodType(forward, module)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    G = G.to(device)
else:
    # NOTE(review): presumably the checkpoint uses reduced-precision layers that
    # the CPU path cannot run, so full precision is forced both for the
    # top-level generator call and for direct calls into the synthesis network.
    _force_fp32_forward(G)
    _force_fp32_forward(G.synthesis)


def generate(num_images, interpolate):
    """Sample ``num_images`` icons from the generator.

    Args:
        num_images: Number of images to produce (expected to be >= 1).
        interpolate: If truthy, the latent codes are evenly spaced along the
            straight line from one random latent to another (endpoints
            included); otherwise every latent code is drawn independently.

    Returns:
        A ``numpy`` ``uint8`` array with one image per latent, channels moved
        to the last axis.
    """
    if interpolate:
        z1 = torch.randn([1, G.z_dim])  # latent codes
        z2 = torch.randn([1, G.z_dim])  # latent codes
        # Blend weights i / (num_images - 1) for i in 0..num_images-1. With a
        # single image linspace yields [0.], i.e. just z1, instead of the
        # division by zero the explicit formula would hit.
        weights = torch.linspace(0.0, 1.0, num_images).unsqueeze(1)
        zs = z1 + (z2 - z1) * weights
    else:
        zs = torch.randn([num_images, G.z_dim])  # latent codes

    with torch.no_grad():
        zs = zs.to(device)
        img = G(zs, None, force_fp32=True, noise_mode="const")
        # Move channels last (presumably NCHW -> NHWC) and map the output,
        # assumed to be roughly in [-1, 1], to 0..255 bytes.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return img.cpu().numpy()


def infer(num_images, interpolate):
    """Gradio callback: generate icons and return them as a list for the Gallery."""
    # round() turns a possibly non-integer slider value into an image count.
    return list(generate(round(num_images), interpolate))


# gr.Slider / gr.Checkbox(value=...) replace the deprecated gr.inputs.*
# components (removed in Gradio 4) and are available throughout Gradio 3.x.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # gen_ability_icon
        creates circular magic ability icons from stylegan2ada model trained on synthetic dataset .
        more information here : [https://github.com/CorvaeOboro/gen_ability_icon](https://github.com/CorvaeOboro/gen_ability_icon).
        """
    )
    images_num = gr.Slider(value=8, label="Num Images", minimum=1, maximum=16, step=1)
    interpolate = gr.Checkbox(value=False, label="Interpolate")
    submit = gr.Button("Generate")
    out = gr.Gallery()
    submit.click(fn=infer, inputs=[images_num, interpolate], outputs=out)

demo.launch()