Spaces:
Runtime error
Runtime error
import torch | |
import nltk | |
import io | |
import base64 | |
import shutil | |
from torchvision import transforms | |
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample | |
class PreTrainedPipeline:
    """Pipeline that generates an image with BigGAN from a class name."""

    def __init__(self, path=""):
        """
        Initialize model

        Args:
            path (:obj:`str`):
                Value forwarded unchanged to ``BigGAN.from_pretrained``.
        """
        # NOTE(review): presumably required by ``one_hot_from_names``, which
        # looks class names up through NLTK WordNet - confirm against the
        # pytorch-pretrained-BigGAN documentation.
        nltk.download('wordnet')
        self.model = BigGAN.from_pretrained(path)
        # Single truncation value shared by the noise sampling and the
        # generator forward pass in ``__call__``.
        self.truncation = 0.1

    def __call__(self, inputs: str):
        """
        Args:
            inputs (:obj:`str`):
                a string containing some text
        Return:
            A :obj:`PIL.Image` with the raw image representation as PIL.
        Raises:
            ValueError: if ``inputs`` is blank, or if no class vector could
                be built for it.
        """
        # Blank text cannot name a class: reject it up front with the same
        # error as an unmatched name, before any lookup or model work.
        # Non-string inputs are deliberately left to the original code path.
        if isinstance(inputs, str) and not inputs.strip():
            raise ValueError("Input is not in ImageNet")

        class_vector = one_hot_from_names([inputs], batch_size=1)
        # ``None`` is the "no matching class" signal handled here.
        if class_vector is None:
            raise ValueError("Input is not in ImageNet")

        noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
        noise_vector = torch.from_numpy(noise_vector)
        class_vector = torch.from_numpy(class_vector)

        # Pure inference: no gradients are needed.
        with torch.no_grad():
            output = self.model(noise_vector, class_vector, self.truncation)

        # Scale image: take the only sample of the batch and shift/halve it
        # (assumes generator output lies in [-1, 1], giving [0, 1] - TODO confirm).
        img = output[0]
        img = (img + 1) / 2.0
        img = transforms.ToPILImage()(img)
        return img