Update pipeline.py
Browse files- pipeline.py +16 -12
pipeline.py
CHANGED
@@ -22,23 +22,27 @@ class LSTMTextGenerator(nn.Module, PyTorchModelHubMixin):
|
|
22 |
|
23 |
class PreTrainedPipeline():
|
24 |
def __init__(self, path=""):
|
25 |
-
self.model =
|
26 |
-
self.
|
27 |
-
|
|
|
|
|
|
|
|
|
28 |
def __call__(self, inputs: str):
|
29 |
-
seed_numerical_data = [char_to_index[char] for char in inputs]
|
30 |
with torch.no_grad():
|
31 |
-
input_sequence = torch.LongTensor([seed_numerical_data]).to(device)
|
32 |
-
hidden = model.init_hidden(1)
|
33 |
|
34 |
generated_text = inputs # Initialize generated text with seed text
|
35 |
-
temperature = 0.7
|
36 |
|
37 |
for _ in range(500):
|
38 |
-
output, hidden = model(input_sequence, hidden)
|
39 |
probabilities = nn.functional.softmax(output[-1, 0] / temperature, dim=0).cpu().numpy()
|
40 |
-
predicted_index = random.choices(range(output_size), weights=probabilities, k=1)[0]
|
41 |
-
generated_text += index_to_char[predicted_index] # Append the generated character to the text
|
42 |
-
input_sequence = torch.LongTensor([[predicted_index]]).to(device)
|
43 |
|
44 |
-
return
|
|
|
22 |
|
23 |
class PreTrainedPipeline():
|
24 |
def __init__(self, path=""):
|
25 |
+
self.model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
|
26 |
+
self.chars = ""
|
27 |
+
self.char_to_index = {char: index for index, char in enumerate(self.chars)}
|
28 |
+
self.index_to_char = {index: char for char, index in self.char_to_index.items()}
|
29 |
+
self.output_size = len(chars)
|
30 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
31 |
+
|
32 |
def __call__(self, inputs: str):
|
33 |
+
seed_numerical_data = [self.char_to_index[char] for char in inputs]
|
34 |
with torch.no_grad():
|
35 |
+
input_sequence = torch.LongTensor([seed_numerical_data]).to(self.device)
|
36 |
+
hidden = self.model.init_hidden(1)
|
37 |
|
38 |
generated_text = inputs # Initialize generated text with seed text
|
39 |
+
temperature = 0.7 # Temperature for temperature sampling
|
40 |
|
41 |
for _ in range(500):
|
42 |
+
output, hidden = self.model(input_sequence, hidden)
|
43 |
probabilities = nn.functional.softmax(output[-1, 0] / temperature, dim=0).cpu().numpy()
|
44 |
+
predicted_index = random.choices(range(self.output_size), weights=probabilities, k=1)[0]
|
45 |
+
generated_text += self.index_to_char[predicted_index] # Append the generated character to the text
|
46 |
+
input_sequence = torch.LongTensor([[predicted_index]]).to(self.device)
|
47 |
|
48 |
+
return generated_text
|