miittnnss commited on
Commit
d2d0116
1 Parent(s): db8e3e3

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +16 -12
pipeline.py CHANGED
@@ -22,23 +22,27 @@ class LSTMTextGenerator(nn.Module, PyTorchModelHubMixin):
22
 
23
  class PreTrainedPipeline():
24
  def __init__(self, path=""):
25
- self.model = model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
26
- self.char_to_index = {char: index for index, char in enumerate(chars)}
27
-
 
 
 
 
28
  def __call__(self, inputs: str):
29
- seed_numerical_data = [char_to_index[char] for char in inputs]
30
  with torch.no_grad():
31
- input_sequence = torch.LongTensor([seed_numerical_data]).to(device)
32
- hidden = model.init_hidden(1)
33
 
34
  generated_text = inputs # Initialize generated text with seed text
35
- temperature = 0.7 # Temperature for temperature sampling
36
 
37
  for _ in range(500):
38
- output, hidden = model(input_sequence, hidden)
39
  probabilities = nn.functional.softmax(output[-1, 0] / temperature, dim=0).cpu().numpy()
40
- predicted_index = random.choices(range(output_size), weights=probabilities, k=1)[0]
41
- generated_text += index_to_char[predicted_index] # Append the generated character to the text
42
- input_sequence = torch.LongTensor([[predicted_index]]).to(device)
43
 
44
- return output
 
22
 
23
  class PreTrainedPipeline():
24
  def __init__(self, path=""):
25
+ self.model = LSTMTextGenerator.from_pretrained("miittnnss/lstm-textgen-pets")
26
+ self.chars = ""
27
+ self.char_to_index = {char: index for index, char in enumerate(self.chars)}
28
+ self.index_to_char = {index: char for char, index in self.char_to_index.items()}
29
+ self.output_size = len(chars)
30
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
  def __call__(self, inputs: str):
33
+ seed_numerical_data = [self.char_to_index[char] for char in inputs]
34
  with torch.no_grad():
35
+ input_sequence = torch.LongTensor([seed_numerical_data]).to(self.device)
36
+ hidden = self.model.init_hidden(1)
37
 
38
  generated_text = inputs # Initialize generated text with seed text
39
+ temperature = 0.7 # Temperature for temperature sampling
40
 
41
  for _ in range(500):
42
+ output, hidden = self.model(input_sequence, hidden)
43
  probabilities = nn.functional.softmax(output[-1, 0] / temperature, dim=0).cpu().numpy()
44
+ predicted_index = random.choices(range(self.output_size), weights=probabilities, k=1)[0]
45
+ generated_text += self.index_to_char[predicted_index] # Append the generated character to the text
46
+ input_sequence = torch.LongTensor([[predicted_index]]).to(self.device)
47
 
48
+ return generated_text