"""Character-level LSTM text generator with a Hugging Face inference pipeline."""

import random

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class LSTMTextGenerator(nn.Module, PyTorchModelHubMixin):
    """Stacked LSTM mapping character indices to next-character logits.

    Args:
        input_size: vocabulary size for the embedding table.
        hidden_size: embedding and LSTM hidden dimension.
        output_size: number of output classes (vocabulary size).
        num_layers: number of stacked LSTM layers.
        dropout: inter-layer LSTM dropout probability.
    """

    def __init__(self, input_size=45, hidden_size=512, output_size=45,
                 num_layers=2, dropout=0.5):
        super(LSTMTextGenerator, self).__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers,
                            batch_first=True, dropout=dropout,
                            bidirectional=False)
        self.fc = nn.Linear(hidden_size, output_size)
        self.num_layers = num_layers
        self.hidden_size = hidden_size

    def forward(self, x, hidden):
        """Run one step of the generator.

        Args:
            x: (batch, seq) tensor of character indices (any numeric dtype).
            hidden: (h0, c0) LSTM state tuple.

        Returns:
            Tuple of logits with shape (batch, seq, output_size) and the
            updated (h, c) state.
        """
        x = x.to(torch.long)  # nn.Embedding requires integer indices
        x = self.embedding(x)
        x, hidden = self.lstm(x, hidden)
        x = self.fc(x)
        return x, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed (h0, c0) state for `batch_size` sequences.

        BUGFIX: the original referenced an undefined module-level `device`
        (NameError at runtime); the device is now taken from the model's own
        parameters so the state always lands where the weights live.
        """
        device = next(self.parameters()).device
        shape = (self.num_layers, batch_size, self.hidden_size)
        return (torch.zeros(*shape, device=device),
                torch.zeros(*shape, device=device))


class PreTrainedPipeline():
    """Temperature-sampling text-generation pipeline around the LSTM model."""

    def __init__(self, path=""):
        self.model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
        self.chars = "!',.;ACDFGHIMORSTWabcdefghijklmnopqrstuvwxy"
        self.char_to_index = {char: index for index, char in enumerate(self.chars)}
        self.index_to_char = {index: char for char, index in self.char_to_index.items()}
        # BUGFIX: original read bare `chars` (NameError); use the attribute.
        self.output_size = len(self.chars)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # BUGFIX: inputs are moved to self.device below but the model never
        # was — on a CUDA machine every forward pass would crash with a
        # device-mismatch error. Keep model and inputs on the same device.
        self.model.to(self.device)

    def __call__(self, inputs: str):
        """Generate 500 characters of text seeded by `inputs`.

        Raises:
            KeyError: if `inputs` contains a character outside the vocabulary.
        """
        seed_numerical_data = [self.char_to_index[char] for char in inputs]
        with torch.no_grad():
            input_sequence = torch.LongTensor([seed_numerical_data]).to(self.device)
            hidden = self.model.init_hidden(1)
            generated_text = inputs  # seed text is included in the output
            temperature = 0.7  # softening factor for temperature sampling

            for _ in range(500):
                output, hidden = self.model(input_sequence, hidden)
                # BUGFIX: with batch_first=True the logits are
                # (batch, seq, vocab); the next-char distribution comes from
                # the LAST timestep of batch row 0, i.e. output[0, -1].
                # The original's output[-1, 0] selected the FIRST timestep.
                probabilities = nn.functional.softmax(
                    output[0, -1] / temperature, dim=0).cpu().numpy()
                predicted_index = random.choices(
                    range(self.output_size), weights=probabilities, k=1)[0]
                generated_text += self.index_to_char[predicted_index]
                input_sequence = torch.LongTensor([[predicted_index]]).to(self.device)
        return generated_text