Update generate.py
generate.py CHANGED  (+14 -18)
@@ -47,7 +47,7 @@ def custom_generate(
    with torch.no_grad():
        batch_size = input_ids.shape[0]
        finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
-
+       generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)

        for cur_token_idx in range(max_new_tokens):
            # Sample the next token
@@ -67,14 +67,13 @@ def custom_generate(
                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()

                new_ids_sampled = torch.multinomial(
-                   torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1
-               )
+                   torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)

                # Assign the new id to the last token
                if last_token_idx + 1 >= len(base_answer_ids):
                    # Add padding everywhere
                    new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
-
+                                            device=device)
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    if attention_mask is not None:
                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
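The sampling step in this hunk draws a single token id from a temperature-scaled softmax over the logits at the last non-padding position. A minimal standalone sketch of that step, assuming only that logits is a 1-D vector of vocabulary logits (the function name and argument are illustrative, not part of the file):

import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # logits: 1-D tensor of vocabulary logits for the last non-padding position
    probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
    # Draw one sample from the distribution; returns a tensor of shape (1,) holding the token id
    return torch.multinomial(probs, 1)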
@@ -82,27 +81,23 @@ def custom_generate(
                if attention_mask is not None:
                    attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-
-               generated_text[answer_idx] += self.tokenizer.decode([generated_token_id])
+               generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled

-               if
+               if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
                    finished_generating[answer_idx] = 1

                # Check if the end token is generated
-               if
+               if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
                    finished_generating[answer_idx] = 1
-
+
            if finished_generating.all():
                break

            if streamer is not None:
                streamer.put(new_ids_sampled)

-
-       if 'dynamic_temperature' in kwargs and kwargs['dynamic_temperature'] is not None:
-           return generated_text
+       return generated_token_ids

-       return generated_text

def generate(
    self,
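With this hunk, custom_generate stops accumulating decoded strings and instead records the sampled ids into the pre-allocated, pad-filled generated_token_ids buffer, marking a sequence finished once it emits an EOS/BOS/pad token or the "</s>" token. A simplified sketch of that bookkeeping pattern only, with a hypothetical sample_fn standing in for the model's per-step sampling (this is not the repository's function):

import torch

def collect_token_ids(sample_fn, batch_size, max_new_tokens, pad_token_id, eos_token_id, device="cpu"):
    # Pad-filled buffer that will hold the sampled ids, one row per sequence
    generated_token_ids = torch.full((batch_size, max_new_tokens), pad_token_id,
                                     dtype=torch.long, device=device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    for step in range(max_new_tokens):
        for i in range(batch_size):
            if finished[i]:
                continue
            token_id = sample_fn(i, step)  # hypothetical: returns a Python int token id
            generated_token_ids[i, step] = token_id
            # Stop this sequence once a terminating token appears
            if token_id in (eos_token_id, pad_token_id):
                finished[i] = True
        if finished.all():
            break
    return generated_token_ids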
@@ -153,9 +148,10 @@ def generate(
    torch_dtype=torch.bfloat16,
    **model_kwargs,
):
-   if max_new_tokens is None:
-       max_new_tokens = 128

+   if max_new_tokens is None:
+       max_new_tokens = 128
+
    # Set model attributes
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
@@ -187,9 +183,9 @@ def generate(
    if attention_mask is not None:
        attention_mask = attention_mask.to(self.device)

-
+   generated_token_ids = custom_generate(
        self,
-       input_ids=input_ids,
+       input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
@@ -224,4 +220,4 @@ def generate(
        **model_kwargs,
    )

-   return
+   return generated_token_ids
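Since generate now returns the generated_token_ids tensor (rows padded with pad_token_id past the stopping point) rather than decoded text, a caller decodes the ids itself. A hedged usage sketch; model, tokenizer, and the exact call signature are assumptions about the surrounding code, not taken from this diff:

# Assumes the patched generate() is reachable as model.generate and tokenizer matches the model
token_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=64)
texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True)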