miittnnss commited on
Commit
c6b29b7
1 Parent(s): 2690b0a

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +55 -0
pipeline.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+
5
+ class Generator(nn.Module):
6
+ def __init__(self, nz=128, ngf=64, nc=3):
7
+ super(Generator, self).__init__()
8
+ self.main = nn.Sequential(
9
+ nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
10
+ nn.BatchNorm2d(ngf * 8),
11
+ nn.LeakyReLU(0.2, inplace=True),
12
+ nn.Dropout(0.2),
13
+
14
+ nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
15
+ nn.BatchNorm2d(ngf * 4),
16
+ nn.LeakyReLU(0.2, inplace=True),
17
+
18
+ nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
19
+ nn.BatchNorm2d(ngf * 2),
20
+ nn.LeakyReLU(0.2, inplace=True),
21
+
22
+ nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
23
+ nn.BatchNorm2d(ngf),
24
+ nn.LeakyReLU(0.2, inplace=True),
25
+
26
+ nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
27
+ nn.Tanh()
28
+ )
29
+
30
+ def forward(self, input):
31
+ output = self.main(input)
32
+ return output
33
+
34
+ class PreTrainedPipeline():
35
+ def __init__(self, path=""):
36
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
37
+ self.model = Generator().to(self.device)
38
+ self.model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
39
+
40
+ def __call__(self, inputs: str):
41
+ """
42
+ Args:
43
+ inputs (:obj:`str`):
44
+ a string containing some text
45
+ Return:
46
+ A :obj:`PIL.Image` with the raw image representation as PIL.
47
+ """
48
+ noise = torch.randn(1, 128, 1, 1, device=self.device)
49
+ with torch.no_grad():
50
+ output = self.model(noise).cpu()
51
+
52
+ img = output[0]
53
+ img = (img + 1) / 2
54
+ img = transforms.ToPILImage()(img)
55
+ return img