|
import torch |
|
import torch.nn as nn |
|
import random |
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
class LSTMTextGenerator(nn.Module, PyTorchModelHubMixin):
    """Character-level LSTM language model.

    Maps integer token indices -> embedding -> multi-layer LSTM -> linear
    projection to vocabulary logits. The PyTorchModelHubMixin base adds
    `from_pretrained`/`push_to_hub` support for the Hugging Face Hub.

    Args:
        input_size: vocabulary size for the embedding table (default 45).
        hidden_size: embedding and LSTM hidden dimension (default 512).
        output_size: vocabulary size of the output projection (default 45).
        num_layers: number of stacked LSTM layers (default 2).
        dropout: inter-layer LSTM dropout probability (default 0.5).
    """

    def __init__(self, input_size=45, hidden_size=512, output_size=45, num_layers=2, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers,
                            batch_first=True, dropout=dropout, bidirectional=False)
        self.fc = nn.Linear(hidden_size, output_size)
        self.num_layers = num_layers
        self.hidden_size = hidden_size

    def forward(self, x, hidden):
        """Run one forward pass.

        Args:
            x: (batch, seq) tensor of token indices (any integer-castable dtype).
            hidden: (h0, c0) tuple, each (num_layers, batch, hidden_size).

        Returns:
            (logits, hidden): logits of shape (batch, seq, output_size) and the
            updated (h, c) state tuple.
        """
        x = x.to(torch.long)  # nn.Embedding requires integer indices
        x = self.embedding(x)
        x, hidden = self.lstm(x, hidden)
        x = self.fc(x)
        return x, hidden

    def init_hidden(self, batch_size, device=None):
        """Return zeroed (h0, c0) states for a batch.

        BUGFIX: the original referenced an undefined global `device`
        (NameError at call time). The state now defaults to the device of
        the model's parameters; callers may still override via `device=`.
        """
        if device is None:
            device = next(self.parameters()).device
        shape = (self.num_layers, batch_size, self.hidden_size)
        return (torch.zeros(*shape, device=device),
                torch.zeros(*shape, device=device))
|
|
|
|
|
class PreTrainedPipeline():
    """Text-generation inference pipeline around a pretrained LSTM.

    Downloads the `miittnnss/lstm-textgen-pets` checkpoint from the Hugging
    Face Hub and samples 500 characters continuing a seed string, using
    temperature-scaled multinomial sampling over a fixed character vocabulary.
    """

    def __init__(self, path=""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
        # BUGFIX: the checkpoint loads on CPU; move it so model and input
        # tensors share a device (original crashed under CUDA).
        self.model.to(self.device)
        # BUGFIX: switch to eval mode so LSTM dropout is disabled at inference.
        self.model.eval()
        self.chars = "!',.;ACDFGHIMORSTWabcdefghijklmnopqrstuvwxy"
        self.char_to_index = {char: index for index, char in enumerate(self.chars)}
        self.index_to_char = {index: char for char, index in self.char_to_index.items()}
        # BUGFIX: was `len(chars)` — an undefined name (NameError).
        self.output_size = len(self.chars)

    def __call__(self, inputs: str):
        """Generate 500 characters continuing `inputs`.

        Args:
            inputs: seed string; every character must appear in `self.chars`
                (an unknown character raises KeyError).

        Returns:
            The seed string followed by 500 sampled characters.
        """
        seed_numerical_data = [self.char_to_index[char] for char in inputs]
        with torch.no_grad():
            input_sequence = torch.LongTensor([seed_numerical_data]).to(self.device)
            hidden = self.model.init_hidden(1)

            generated_text = inputs
            temperature = 0.7  # <1.0 sharpens the distribution toward likely chars

            for _ in range(500):
                output, hidden = self.model(input_sequence, hidden)
                # BUGFIX: output is (batch, seq, vocab) with batch_first=True.
                # The original's output[-1, 0] read the FIRST timestep of the
                # seed; generation must continue from the LAST timestep.
                probabilities = nn.functional.softmax(
                    output[0, -1] / temperature, dim=0).cpu().numpy()
                predicted_index = random.choices(
                    range(self.output_size), weights=probabilities, k=1)[0]
                generated_text += self.index_to_char[predicted_index]
                # Feed only the newly sampled character back in; the LSTM
                # state in `hidden` carries the full history.
                input_sequence = torch.LongTensor([[predicted_index]]).to(self.device)

        return generated_text