miittnnss commited on
Commit
f9ef14f
1 Parent(s): 04873a5

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +22 -0
pipeline.py CHANGED
@@ -1,3 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  class PreTrainedPipeline():
2
  def __init__(self, path=""):
3
  # IMPLEMENT_THIS
 
1
+
2
+ class LSTMTextGenerator(nn.Module):
3
+ def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.5):
4
+ super(LSTMTextGenerator, self).__init__()
5
+ self.embedding = nn.Embedding(input_size, hidden_size)
6
+ self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True, dropout=dropout, bidirectional=False)
7
+ self.fc = nn.Linear(hidden_size, output_size)
8
+ self.num_layers = num_layers
9
+ self.hidden_size = hidden_size
10
+
11
+ def forward(self, x, hidden):
12
+ x = x.to(torch.long)
13
+ x = self.embedding(x)
14
+ x, hidden = self.lstm(x, hidden)
15
+ x = self.fc(x)
16
+ return x, hidden
17
+
18
+ def init_hidden(self, batch_size):
19
+ return (torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device),
20
+ torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device))
21
+
22
+
23
  class PreTrainedPipeline():
24
  def __init__(self, path=""):
25
  # IMPLEMENT_THIS