BigGAN / app.py
Jezia's picture
Update app.py
312c3a7
raw
history blame
2.79 kB
import torch
import gradio as gr
import numpy as np
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')
from PIL import Image
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
save_as_images, display_in_terminal)
# Class name used as the default subject for generation.
initial_class = 'dog'

# Architecture loaded at start-up. Choices listed for this setting:
# 'biggan-deep-128', 'biggan-deep-256', 'biggan-deep-512'.
initial_archi = 'biggan-deep-128'

# Load the pretrained generator once at module level so that
# generate_images can reuse it on every call.
gan_model = BigGAN.from_pretrained(initial_archi)
# Generators already loaded, keyed by architecture name. Filled lazily so the
# model loaded at start-up is reused instead of being loaded a second time.
_gan_models = {}


def _get_gan_model(archi):
    """Return the BigGAN generator for *archi*, loading it on first use."""
    if not _gan_models:
        # Seed the cache with the generator created at module level.
        _gan_models[initial_archi] = gan_model
    if not archi:
        # Nothing selected: fall back to the start-up model.
        return gan_model
    if archi not in _gan_models:
        # NOTE: every architecture requested stays in memory afterwards.
        _gan_models[archi] = BigGAN.from_pretrained(archi)
    return _gan_models[archi]


def generate_images(initial_archi, initial_class, batch_size, truncation=0.4):
    """Generate a batch of images of one class with a pretrained BigGAN.

    Params:
        initial_archi: name of the BigGAN architecture to use
            (e.g. 'biggan-deep-128'); a falsy value selects the model loaded
            at start-up.
        initial_class: class name (or list of names) understood by
            ``one_hot_from_names``.
        batch_size: number of images to generate.
        truncation: truncation value used both for sampling the noise and for
            the generator call; defaults to 0.4.
    Output:
        CPU tensor produced by the generator, one entry per image.
    Raises:
        ValueError: if ``initial_class`` is empty or cannot be mapped to a
            class vector.
    """
    if not initial_class:
        raise ValueError("Please provide a class name to generate images for.")

    model = _get_gan_model(initial_archi)

    class_vector = one_hot_from_names(initial_class, batch_size=batch_size)
    if class_vector is None:
        # one_hot_from_names signals an unknown name by returning None; fail
        # with a readable message instead of crashing in torch.from_numpy.
        raise ValueError(
            "Could not find a BigGAN class matching {!r}.".format(initial_class)
        )
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=batch_size)

    # The generator expects tensors, the helpers above return numpy arrays.
    noise_vector = torch.from_numpy(noise_vector)
    class_vector = torch.from_numpy(class_vector)

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = model(noise_vector, class_vector, truncation)

    # No-op on CPU; keeps the result usable if the model is ever moved to a GPU.
    output = output.to('cpu')
    # Kept from the original behaviour: also saves the batch via the library
    # helper (presumably as image files in the working directory -- confirm).
    save_as_images(output)
    return output
def convert_to_images(obj):
    """Convert an output tensor from BigGAN into a list of images.

    Params:
        obj: tensor or numpy array of shape (batch_size, channels, height, width).
            Values are mapped from [-1, 1] to pixel intensities; anything
            outside that range is clipped.
    Output:
        list of Pillow Images, one per batch entry (Pillow reports their
        ``size`` as (width, height)).
    Raises:
        ImportError: if Pillow is not installed.
    """
    try:
        # Import the submodule explicitly: a bare ``import PIL`` does not
        # guarantee that ``PIL.Image`` has been loaded.
        from PIL import Image as pil_image
    except ImportError as err:
        raise ImportError("Please install Pillow to use images: pip install Pillow") from err

    if not isinstance(obj, np.ndarray):
        # .cpu() makes this work for tensors living on an accelerator as well.
        obj = obj.detach().cpu().numpy()

    # (batch, channels, height, width) -> (batch, height, width, channels)
    obj = obj.transpose((0, 2, 3, 1))
    # Scale [-1, 1] to [0, 256], then clip so that 1.0 maps to 255.
    pixels = np.clip(((obj + 1) / 2.0) * 256, 0, 255).astype(np.uint8)
    return [pil_image.fromarray(frame) for frame in pixels]
def inference(initial_archi, initial_class):
    """Generate a single image and return it as a Pillow image.

    Both arguments are forwarded unchanged to ``generate_images`` with a batch
    size of 1; the resulting tensor is converted with ``convert_to_images``
    and the first (only) image of the list is returned.
    """
    single_batch = generate_images(initial_archi, initial_class, 1)
    return convert_to_images(single_batch)[0]
# Text shown on the demo page.
title = "BigGAN"
description = "BigGAN using various architecture models to generate images."
article = "Coming soon"

# Architectures offered in the drop-down; each one also gets a "dog" example row.
architectures = ["biggan-deep-128", "biggan-deep-256", "biggan-deep-512"]
examples = [[archi, "dog"] for archi in architectures]

# Input and output components of the demo.
archi_input = gr.inputs.Dropdown(architectures)
image_output = gr.outputs.Image(type="pil", label="output")

demo = gr.Interface(
    inference,
    inputs=[archi_input, "text"],
    outputs=[image_output],
    examples=examples,
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=True)