|
import torch |
|
import gradio as gr |
|
import numpy as np |
|
import nltk |
|
# Fetch the NLTK WordNet data at start-up. Presumably this is what
# pytorch_pretrained_biggan's one_hot_from_names relies on to resolve class
# names — TODO confirm against the library version in use.
for corpus in ('wordnet', 'omw-1.4'):
    nltk.download(corpus)
|
from PIL import Image |
|
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample, |
|
save_as_images, display_in_terminal) |
|
# Start-up defaults. Several functions below take parameters with these same
# names, which shadow these globals inside those functions.
initial_archi, initial_class = 'biggan-deep-128', 'dog'

# Load the default generator once, at import time.
gan_model = BigGAN.from_pretrained(initial_archi)
|
|
|
# Architectures offered by the UI dropdown; anything else is rejected rather than
# being handed to BigGAN.from_pretrained (which also accepts filesystem paths).
_KNOWN_ARCHIS = ('biggan-deep-128', 'biggan-deep-256', 'biggan-deep-512')

# Generators loaded so far, keyed by architecture name (at most len(_KNOWN_ARCHIS)).
_gan_cache = {}


def _get_gan(archi):
    """Return the BigGAN generator for ``archi``, loading it on first use."""
    if archi not in _KNOWN_ARCHIS:
        raise ValueError(
            f"Unknown BigGAN architecture {archi!r}; expected one of {_KNOWN_ARCHIS}")
    if archi not in _gan_cache:
        # Reuse the module-level model for the start-up architecture instead of
        # loading the same weights a second time.
        _gan_cache[archi] = (gan_model if archi == initial_archi
                             else BigGAN.from_pretrained(archi))
    return _gan_cache[archi]


def generate_images(initial_archi, initial_class, batch_size, truncation=0.4):
    """Generate a batch of BigGAN samples for one class.

    Params:
        initial_archi: pretrained architecture name, one of ``_KNOWN_ARCHIS``.
        initial_class: class name (or list of names) passed to
            ``one_hot_from_names``; surrounding whitespace is ignored.
        batch_size: number of samples to generate.
        truncation: truncation value used both for the noise sample and by the
            generator (default 0.4, the value previously hard-coded).
    Output:
        generator output tensor, moved to the CPU.
    Raises:
        ValueError: if the architecture is not known or the class name could
            not be mapped to a class vector.
    """
    model = _get_gan(initial_archi)

    if isinstance(initial_class, str):
        initial_class = initial_class.strip()

    class_vector = one_hot_from_names(initial_class, batch_size=batch_size)
    # NOTE(review): one_hot_from_names appears to return None when the name has
    # no usable WordNet match; without this guard torch.from_numpy(None) would
    # raise an unhelpful TypeError.
    if class_vector is None:
        raise ValueError(f"Could not match {initial_class!r} to an ImageNet class.")
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=batch_size)

    noise_vector = torch.from_numpy(noise_vector)
    class_vector = torch.from_numpy(class_vector)

    with torch.no_grad():
        output = model(noise_vector, class_vector, truncation)

    # The image is returned to the caller; nothing is written to disk (the old
    # save_as_images call dumped PNGs into the working directory per request).
    return output.to('cpu')
|
|
|
def convert_to_images(obj):
    """ Convert an output tensor from BigGAN into a list of images.

    Params:
        obj: tensor or numpy array of shape (batch_size, channels, height, width),
             with values expected in [-1, 1] (anything outside is clipped).
    Output:
        list of Pillow Images, each height x width pixels
        (note Pillow's ``.size`` reports this as (width, height)).
    """
    try:
        # Import the submodule explicitly: a bare ``import PIL`` does not
        # guarantee that ``PIL.Image`` has been loaded.
        from PIL import Image as PILImage
    except ImportError as err:
        raise ImportError("Please install Pillow to use images: pip install Pillow") from err

    if not isinstance(obj, np.ndarray):
        # detach() drops autograd history; cpu() is needed because .numpy()
        # only works on CPU tensors.
        obj = obj.detach().cpu().numpy()

    # (N, C, H, W) -> (N, H, W, C): Pillow expects channels last.
    obj = obj.transpose((0, 2, 3, 1))
    # Map [-1, 1] onto 256 equal-width bins; the clip folds 1.0 (-> 256) into 255.
    obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255)

    return [PILImage.fromarray(out.astype(np.uint8)) for out in obj]
|
|
|
def inference(initial_archi, initial_class):
    """Gradio callback: render one image for the chosen architecture and class.

    Generates a batch of size one and returns its only Pillow image.
    """
    images = convert_to_images(
        generate_images(initial_archi, initial_class, batch_size=1)
    )
    return images[0]
|
|
|
|
|
|
|
title = "BigGAN"
description = "BigGAN using various architecture models to generate images."
article = "Coming soon"

# Architectures offered in the dropdown; each example pairs one with the class "dog".
ARCHITECTURES = ["biggan-deep-128", "biggan-deep-256", "biggan-deep-512"]
examples = [[archi, "dog"] for archi in ARCHITECTURES]

demo = gr.Interface(
    inference,
    inputs=[gr.inputs.Dropdown(ARCHITECTURES), "text"],
    outputs=[gr.outputs.Image(type="pil", label="output")],
    examples=examples,
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=True)