Crystalcareai
committed
Update generate.py
generate.py CHANGED (+11 -28)
@@ -6,6 +6,7 @@ from transformers.generation.utils import (
 )
 from transformers import TextStreamer
 
+
 def custom_generate(
     self,
     input_ids,
@@ -42,17 +43,12 @@ def custom_generate(
     synced_gpus=None,
     **kwargs,
 ):
-    if input_ids is None or input_ids.nelement() == 0:
-        # If input_ids is None or an empty tensor, create a default input tensor
-        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
     device = input_ids.device
     with torch.no_grad():
-
-        finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
-        generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
+        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
 
+        if max_new_tokens is None:
+            max_new_tokens = 50  # Default value if not specified
         for cur_token_idx in range(max_new_tokens):
             # Sample the next token
             new_ids = self(
@@ -76,7 +72,7 @@ def custom_generate(
                 # Assign the new id to the last token
                 if last_token_idx + 1 >= len(base_answer_ids):
                     # Add padding everywhere
-                    new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                    new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
                                              device=device)
                     input_ids = torch.cat([input_ids, new_padding], dim=-1)
                     if attention_mask is not None:
@@ -85,7 +81,6 @@ def custom_generate(
                 if attention_mask is not None:
                     attention_mask[answer_idx, last_token_idx + 1] = 1
                 input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-                generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
 
                 if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
                     finished_generating[answer_idx] = 1
@@ -100,7 +95,8 @@ def custom_generate(
             if streamer is not None:
                 streamer.put(new_ids_sampled)
 
-    return generated_token_ids
+    generated_token_ids = input_ids.tolist()
+    return generated_token_ids, attention_mask
 
 
 def generate(
@@ -137,7 +133,7 @@ def generate(
     forced_eos_token_id=None,
     remove_invalid_values=None,
     synced_gpus=None,
-    n_ahead=
+    n_ahead=12,
     n_ahead_talk=4,
     merged_talk_heads=True,
     merged_lm_and_talk_heads=False,
@@ -152,10 +148,6 @@ def generate(
     torch_dtype=torch.bfloat16,
     **model_kwargs,
 ):
-
-    if max_new_tokens is None:
-        max_new_tokens = 128
-
     # Set model attributes
     self.max_thoughts = n_ahead + n_ahead_talk + 1
     self.merged_talk_heads = merged_talk_heads
@@ -178,18 +170,9 @@ def generate(
     self.rm_initialized = True
     self.original_mode = False
 
-
-    if isinstance(input_ids, str):
-        input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
-
-    # Move input_ids and attention_mask to the same device as the model
-    input_ids = input_ids.to(self.device)
-    if attention_mask is not None:
-        attention_mask = attention_mask.to(self.device)
-
-    generated_token_ids = custom_generate(
+    generated_token_ids, attention_mask = custom_generate(
         self,
-        input_ids=input_ids,
+        input_ids=input_ids,
         attention_mask=attention_mask,
         max_new_tokens=max_new_tokens,
         min_length=min_length,
@@ -224,4 +207,4 @@ def generate(
         **model_kwargs,
     )
 
-    return generated_token_ids
+    return generated_token_ids, attention_mask
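Two of the fixes concern the batch dimension: the removed lines sized finished_generating and new_padding with a batch_size variable that is not defined anywhere in the visible hunks, while the new code derives the batch size directly from the input tensor via len(input_ids). A minimal standalone sketch of the equivalence (illustration only, not part of the repo):

import torch

# For a 2-D input_ids tensor, len(...) is the batch dimension, i.e. shape[0].
input_ids = torch.zeros((3, 7), dtype=torch.long)  # batch of 3 sequences of length 7
assert len(input_ids) == input_ids.shape[0] == 3

# Per-row completion flags sized from the tensor itself, mirroring the new code.
finished_generating = torch.zeros(len(input_ids), dtype=torch.bool)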
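The calling contract changes as well: custom_generate now ends with generated_token_ids = input_ids.tolist() and returns a (generated_token_ids, attention_mask) tuple, which generate passes through, so the first element is a plain Python list of token-id rows rather than a tensor. The commit also drops the convenience paths that encoded string prompts, moved tensors to the model's device, and substituted a BOS token for empty input, and the max_new_tokens fallback moves from generate (previously 128) into custom_generate (now 50), so callers should tokenize, place tensors, and set lengths themselves. A hedged usage sketch under those assumptions; the model id and prompt are placeholders, and it assumes this generate.py is wired in as the model's remote code:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "example/quiet-star-model"  # placeholder; use the repo this file ships with
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
)

# The caller now tokenizes and moves tensors to the device; generate() no
# longer accepts raw strings or relocates inputs itself.
inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt").to(model.device)

# generate() now returns a tuple; the first element is a list of token-id
# rows built via input_ids.tolist(), not a tensor.
generated_token_ids, attention_mask = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,  # pass explicitly; the in-function fallback is now 50
)
print(tokenizer.decode(generated_token_ids[0], skip_special_tokens=True))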
|