nzkashan commited on
Commit
567b273
1 Parent(s): f093a61

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +40 -0
pipeline.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import nltk
3
+ import io
4
+ import base64
5
+ import shutil
6
+ from torchvision import transforms
7
+
8
+ from pytorch_pretrained_biggan import BigGAN, one_hot_from_names, truncated_noise_sample
9
+
10
+ class PreTrainedPipeline():
11
+ def __init__(self, path=""):
12
+ """
13
+ Initialize model
14
+ """
15
+ nltk.download('wordnet')
16
+ self.model = BigGAN.from_pretrained(path)
17
+ self.truncation = 0.1
18
+
19
+ def __call__(self, inputs: str):
20
+ """
21
+ Args:
22
+ inputs (:obj:`str`):
23
+ a string containing some text
24
+ Return:
25
+ A :obj:`PIL.Image` with the raw image representation as PIL.
26
+ """
27
+ class_vector = one_hot_from_names([inputs], batch_size=1)
28
+ if type(class_vector) == type(None):
29
+ raise ValueError("Input is not in ImageNet")
30
+ noise_vector = truncated_noise_sample(truncation=self.truncation, batch_size=1)
31
+ noise_vector = torch.from_numpy(noise_vector)
32
+ class_vector = torch.from_numpy(class_vector)
33
+ with torch.no_grad():
34
+ output = self.model(noise_vector, class_vector, self.truncation)
35
+
36
+ # Scale image
37
+ img = output[0]
38
+ img = (img + 1) / 2.0
39
+ img = transforms.ToPILImage()(img)
40
+ return img