Балаганский Никита Николаевич
commited on
Commit
•
e2f0b3b
1
Parent(s):
116ed83
add logging
Browse files- generator.py +2 -2
generator.py
CHANGED
@@ -199,12 +199,12 @@ class Generator:
|
|
199 |
return input_ids, past, ended_sequences
|
200 |
|
201 |
def get_input_ids(self, input_prompt, num_samples):
|
202 |
-
input_ids = torch.tensor([[self.lm.config.bos_token_id]])
|
203 |
if input_prompt is not None:
|
204 |
input_prompt = self.tokenizer(
|
205 |
input_prompt, return_tensors="pt"
|
206 |
).input_ids
|
207 |
-
input_ids =
|
208 |
input_ids = input_ids.repeat(num_samples, 1).to(self.device)
|
209 |
past = None
|
210 |
ended_sequences = torch.zeros(
|
|
|
199 |
return input_ids, past, ended_sequences
|
200 |
|
201 |
def get_input_ids(self, input_prompt, num_samples):
|
202 |
+
#input_ids = torch.tensor([[self.lm.config.bos_token_id]])
|
203 |
if input_prompt is not None:
|
204 |
input_prompt = self.tokenizer(
|
205 |
input_prompt, return_tensors="pt"
|
206 |
).input_ids
|
207 |
+
input_ids = input_prompt
|
208 |
input_ids = input_ids.repeat(num_samples, 1).to(self.device)
|
209 |
past = None
|
210 |
ended_sequences = torch.zeros(
|