# interfacegan_pp / app.py
# Author: ybelkada. Commit 09ef29c ("update app"), 2.92 kB.
import os
import torch
import PIL.Image
import numpy as np
import gradio as gr
from yarg import get
from models.stylegan_generator import StyleGANGenerator
from models.stylegan2_generator import StyleGAN2Generator
# Attribute names offered in the UI; each one selects the file
# boundaries/<model_name>/boundary_<name>.npy loaded by inference().
VALID_CHOICES = ["Bald", "Young", "Mustache", "Eyeglasses", "Hat", "Smiling"]

# When True, get_generator() and inference() call .cuda().
ENABLE_GPU = False

# Generator identifiers accepted by get_generator().
MODEL_NAMES = ['stylegan_ffhq', 'stylegan2_ffhq']

# Default number of images sampled per request (also sizes the output grid).
NB_IMG = 4
# NB_IMG "Generated Image" slots followed by NB_IMG "Modified Image" slots,
# matching the order of the list returned by inference().
OUTPUT_LIST = [
    gr.outputs.Image(type="pil", label=caption)
    for caption in ("Generated Image", "Modified Image")
    for _ in range(NB_IMG)
]
def tensor_to_pil(input_object):
    """Convert a batch of image arrays into a list of PIL images.

    Args:
        input_object: Either a dict holding the batch under the ``'image'``
            key, or the batch itself. The batch is iterated element by element
            and each element is passed to ``PIL.Image.fromarray``; presumably
            these are uint8 HxWxC arrays (TODO confirm against the generator).

    Returns:
        A list with one ``PIL.Image.Image`` per batch element, in order. An
        empty batch yields an empty list.
    """
    # Unwrap first so both input forms share a single conversion path and the
    # result list always exists.
    images = input_object['image'] if isinstance(input_object, dict) else input_object
    return [PIL.Image.fromarray(image) for image in images]
def get_generator(model_name):
    """Build the pretrained generator registered under ``model_name``.

    ``'stylegan_ffhq'`` maps to StyleGANGenerator and ``'stylegan2_ffhq'`` to
    StyleGAN2Generator; the name itself is forwarded to the constructor. When
    ENABLE_GPU is set, the value returned by the generator's ``.cuda()`` is
    handed back instead of the freshly built object.

    Raises:
        ValueError: If ``model_name`` is neither of the two names above.
    """
    if model_name == 'stylegan2_ffhq':
        model = StyleGAN2Generator(model_name)
    elif model_name == 'stylegan_ffhq':
        model = StyleGANGenerator(model_name)
    else:
        raise ValueError('Model name not recognized')
    return model.cuda() if ENABLE_GPU else model
@torch.no_grad()
def inference(seed, choice, model_name, coef, nb_images=NB_IMG):
    """Generate faces plus the same faces edited along an attribute boundary.

    Args:
        seed: Seed for NumPy's global RNG, so a given seed reproduces the same
            sampled latent codes. Coerced to ``int`` because UI sliders may
            deliver floats.
        choice: Attribute name; must be one of VALID_CHOICES. Selects the file
            ``boundaries/<model_name>/boundary_<choice>.npy``.
        model_name: Generator identifier; must be one of MODEL_NAMES.
        coef: Scalar multiplier applied to the boundary vector before it is
            added to every latent code (negative values move the other way).
        nb_images: Number of latent codes to sample.

    Returns:
        A list of the unedited images followed by the edited images, in the
        same order (2 * nb_images entries).

    Raises:
        ValueError: If ``choice`` or ``model_name`` is not recognised.
    """
    # Validate request values before they are joined into a filesystem path.
    if choice not in VALID_CHOICES:
        raise ValueError('Attribute not recognized: %r' % (choice,))
    if model_name not in MODEL_NAMES:
        raise ValueError('Model name not recognized')

    np.random.seed(int(seed))
    boundary_path = os.path.join('boundaries', model_name, 'boundary_%s.npy' % choice)
    # Passing a path lets np.load open and close the file itself.
    boundary = np.squeeze(np.load(boundary_path))

    # get_generator already applies .cuda() to the generator when ENABLE_GPU
    # is set. The latent codes are left as returned by easy_sample: they are
    # duplicated with .copy() below, a method NumPy arrays have and torch
    # tensors lack, so moving them with .cuda() here could only fail.
    # NOTE(review): the generator is rebuilt on every request; caching it per
    # model_name would avoid reloading it each call.
    generator = get_generator(model_name)
    latent_codes = generator.easy_sample(nb_images)

    generated_images = tensor_to_pil(generator.easy_synthesize(latent_codes))

    # Shift every latent code by the same offset. The in-place add on a copy
    # broadcasts the boundary over all rows and keeps latent_codes' dtype.
    edited_codes = latent_codes.copy()
    edited_codes += boundary * coef
    modified_images = tensor_to_pil(generator.easy_synthesize(edited_codes))

    return generated_images + modified_images
# Gradio UI: the four inputs map positionally onto inference(seed, choice,
# model_name, coef); the outputs are the 2 * NB_IMG image slots in OUTPUT_LIST.
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Slider(
            minimum=0,
            maximum=1000,
            step=1,
            default=264,
            label="Random seed to use for the generation"
        ),
        # Both dropdowns pre-select their first entry so submitting without
        # touching them does not send None to inference().
        gr.inputs.Dropdown(
            choices=VALID_CHOICES,
            type="value",
            default=VALID_CHOICES[0],
            label="Attribute to modify",
        ),
        gr.inputs.Dropdown(
            choices=MODEL_NAMES,
            type="value",
            default=MODEL_NAMES[0],
            label="Model to use",
        ),
        gr.inputs.Slider(
            minimum=-3,
            maximum=3,
            step=0.1,
            default=0,
            label="Modification coefficient",
        ),
    ],
    outputs=OUTPUT_LIST,
    layout="horizontal",
    theme="peach"
)
iface.launch()