Commit 55d98bf by Crystalcareai (verified) · Parent: a5c6c16

Update generate.py

Files changed (1): generate.py (+11 −28)
generate.py CHANGED
```diff
@@ -6,6 +6,7 @@ from transformers.generation.utils import (
 )
 from transformers import TextStreamer
 
+
 def custom_generate(
     self,
     input_ids,
@@ -42,17 +43,12 @@ def custom_generate(
     synced_gpus=None,
     **kwargs,
 ):
-    if input_ids is None or input_ids.nelement() == 0:
-        # If input_ids is None or an empty tensor, create a default input tensor
-        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
     device = input_ids.device
     with torch.no_grad():
-        batch_size = input_ids.shape[0]
-        finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
-        generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
+        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
 
+        if max_new_tokens is None:
+            max_new_tokens = 50  # Default value if not specified
         for cur_token_idx in range(max_new_tokens):
             # Sample the next token
             new_ids = self(
```
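The batch size is now read as `len(input_ids)` instead of `input_ids.shape[0]`; for the 2-D `(batch, seq_len)` tensor used here the two are interchangeable, since `len()` on a tensor returns the size of its first dimension. A minimal sketch of the equivalence (values are illustrative):

```python
import torch

input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])  # (batch=2, seq_len=3)
# len() of a tensor is the size of dim 0, so len(input_ids)
# is a drop-in replacement for input_ids.shape[0] here.
assert len(input_ids) == input_ids.shape[0] == 2
```

Note also that the `max_new_tokens` fallback now lives here (defaulting to 50) rather than in `generate()` (which previously defaulted to 128), and that the guard synthesizing a BOS-only prompt for empty input is gone, so callers must pass a non-empty `input_ids` tensor.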
```diff
@@ -76,7 +72,7 @@ def custom_generate(
             # Assign the new id to the last token
             if last_token_idx + 1 >= len(base_answer_ids):
                 # Add padding everywhere
-                new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
                                          device=device)
                 input_ids = torch.cat([input_ids, new_padding], dim=-1)
                 if attention_mask is not None:
@@ -85,7 +81,6 @@ def custom_generate(
             if attention_mask is not None:
                 attention_mask[answer_idx, last_token_idx + 1] = 1
             input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-            generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
 
             if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
                 finished_generating[answer_idx] = 1
```
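With the preallocated `generated_token_ids` buffer removed, sampled tokens are written only into `input_ids`, and the stop check still compares the sampled id against EOS, BOS, and PAD individually. A hedged equivalent of that three-way check as a set-membership test (the token ids below are made up for illustration):

```python
eos_token_id, bos_token_id, pad_token_id = 2, 1, 0  # hypothetical ids

stop_ids = {eos_token_id, bos_token_id, pad_token_id}
new_ids_sampled = 2  # pretend the model just sampled EOS

# Same truth value as the chained `or` comparison in the hunk above.
assert (new_ids_sampled in stop_ids) == (
    new_ids_sampled == eos_token_id
    or new_ids_sampled == bos_token_id
    or new_ids_sampled == pad_token_id
)
```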
```diff
@@ -100,7 +95,8 @@ def custom_generate(
         if streamer is not None:
             streamer.put(new_ids_sampled)
 
-    return generated_token_ids
+    generated_token_ids = input_ids.tolist()
+    return generated_token_ids, attention_mask
 
 
 def generate(
```
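`custom_generate` now returns the full `input_ids` buffer as nested Python lists plus the (possibly padded) attention mask, so the result contains the prompt tokens as well as the completion. A hedged decoding sketch (the GPT-2 tokenizer and the hard-coded ids are placeholders, not from this repo):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

# Shape of the new return value: one list of ids per batch row,
# prompt and generated tokens together.
generated_token_ids = [[15496, 11, 995, 0]]  # "Hello, world!" in GPT-2 ids
print(tokenizer.batch_decode(generated_token_ids, skip_special_tokens=True))
```

Callers that only want the newly generated text must slice off the prompt length themselves, since the per-step buffer that held only new tokens was removed.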
```diff
@@ -137,7 +133,7 @@ def generate(
     forced_eos_token_id=None,
     remove_invalid_values=None,
     synced_gpus=None,
-    n_ahead=8,
+    n_ahead=12,
     n_ahead_talk=4,
     merged_talk_heads=True,
     merged_lm_and_talk_heads=False,
@@ -152,10 +148,6 @@ def generate(
     torch_dtype=torch.bfloat16,
     **model_kwargs,
 ):
-
-    if max_new_tokens is None:
-        max_new_tokens = 128
-
     # Set model attributes
     self.max_thoughts = n_ahead + n_ahead_talk + 1
     self.merged_talk_heads = merged_talk_heads
```
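Since `self.max_thoughts = n_ahead + n_ahead_talk + 1`, bumping the default `n_ahead` from 8 to 12 raises the default thought budget from 13 to 17:

```python
n_ahead_talk = 4
assert 8 + n_ahead_talk + 1 == 13   # old default max_thoughts
assert 12 + n_ahead_talk + 1 == 17  # new default max_thoughts
```

The `max_new_tokens` fallback of 128 removed here is superseded by the fallback of 50 now applied inside `custom_generate` (first hunk).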
```diff
@@ -178,18 +170,9 @@ def generate(
     self.rm_initialized = True
     self.original_mode = False
 
-    # Check if the input is a string (for compatibility with text-generation-webui)
-    if isinstance(input_ids, str):
-        input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
-
-    # Move input_ids and attention_mask to the same device as the model
-    input_ids = input_ids.to(self.device)
-    if attention_mask is not None:
-        attention_mask = attention_mask.to(self.device)
-
-    generated_token_ids = custom_generate(
+    generated_token_ids, attention_mask = custom_generate(
         self,
-        input_ids=input_ids,
+        input_ids=input_ids,
         attention_mask=attention_mask,
         max_new_tokens=max_new_tokens,
         min_length=min_length,
@@ -224,4 +207,4 @@
         **model_kwargs,
     )
 
-    return generated_token_ids
+    return generated_token_ids, attention_mask
```
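Because the string-input shim and the `.to(self.device)` moves were deleted, `generate()` now expects tokenized tensors already placed on the model's device, and it returns a `(generated_token_ids, attention_mask)` tuple instead of a bare tensor. A hedged caller-side sketch (the tokenizer name is a placeholder; `model` stands for an instance whose class binds this `generate`):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder
device = "cuda" if torch.cuda.is_available() else "cpu"

# The caller now tokenizes and places tensors on the device itself;
# raw strings and CPU tensors are no longer handled inside generate().
enc = tokenizer("Why is the sky blue?", return_tensors="pt")
input_ids = enc["input_ids"].to(device)
attention_mask = enc["attention_mask"].to(device)

# generated_token_ids, attention_mask = model.generate(
#     input_ids, attention_mask=attention_mask, max_new_tokens=64
# )  # unpack the tuple; `model` is assumed, not defined in this sketch
```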
 