Crystalcareai committed
Commit f118366 · verified · 1 Parent(s): e929320

Update generate.py

Files changed (1)
  1. generate.py +14 -18
generate.py CHANGED
@@ -47,7 +47,7 @@ def custom_generate(
     with torch.no_grad():
         batch_size = input_ids.shape[0]
         finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
-        generated_text = [''] * batch_size
+        generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
 
         for cur_token_idx in range(max_new_tokens):
             # Sample the next token
@@ -67,14 +67,13 @@ def custom_generate(
                 last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
 
                 new_ids_sampled = torch.multinomial(
-                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1
-                )
+                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
 
                 # Assign the new id to the last token
                 if last_token_idx + 1 >= len(base_answer_ids):
                     # Add padding everywhere
                     new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                             device=device)
+                                             device=device)
                     input_ids = torch.cat([input_ids, new_padding], dim=-1)
                     if attention_mask is not None:
                         attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
@@ -82,27 +81,23 @@ def custom_generate(
                 if attention_mask is not None:
                     attention_mask[answer_idx, last_token_idx + 1] = 1
                 input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-                generated_token_id = new_ids_sampled.item()
-                generated_text[answer_idx] += self.tokenizer.decode([generated_token_id])
+                generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
 
-                if generated_token_id == self.tokenizer.eos_token_id or generated_token_id == self.tokenizer.bos_token_id or generated_token_id == self.tokenizer.pad_token_id:
+                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
                     finished_generating[answer_idx] = 1
 
                 # Check if the end token is generated
-                if generated_token_id == self.tokenizer.convert_tokens_to_ids("</s>"):
+                if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
                     finished_generating[answer_idx] = 1
-
+
             if finished_generating.all():
                 break
 
             if streamer is not None:
                 streamer.put(new_ids_sampled)
 
-        # Check if dynamic_temperature argument is present
-        if 'dynamic_temperature' in kwargs and kwargs['dynamic_temperature'] is not None:
-            return generated_text
+        return generated_token_ids
 
-        return generated_text
 
 def generate(
     self,
@@ -153,9 +148,10 @@ def generate(
     torch_dtype=torch.bfloat16,
     **model_kwargs,
 ):
-    if max_new_tokens is None:
-        max_new_tokens = 128
 
+    if max_new_tokens is None:
+        max_new_tokens = 128
+
     # Set model attributes
     self.max_thoughts = n_ahead + n_ahead_talk + 1
     self.merged_talk_heads = merged_talk_heads
@@ -187,9 +183,9 @@ def generate(
     if attention_mask is not None:
         attention_mask = attention_mask.to(self.device)
 
-    generated_text = custom_generate(
+    generated_token_ids = custom_generate(
         self,
-        input_ids=input_ids,
+        input_ids=input_ids,
         attention_mask=attention_mask,
         max_new_tokens=max_new_tokens,
         min_length=min_length,
@@ -224,4 +220,4 @@ def generate(
         **model_kwargs,
     )
 
-    return generated_text
+    return generated_token_ids
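
A minimal usage sketch of the changed interface, assuming this generate.py is loaded with the model via trust_remote_code and wired up as the model's generate method; the model id and prompt below are placeholders, not part of this commit. After this change, generate() returns a (batch_size, max_new_tokens) tensor of newly sampled token ids padded with pad_token_id rather than a list of decoded strings, so the caller decodes the ids itself.

# Hypothetical usage sketch; model id and prompt are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-quiet-star-model"  # placeholder, assumed to bundle this generate.py
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
)

inputs = tokenizer("Example prompt", return_tensors="pt").to(model.device)
# generate() now returns a tensor of new token ids padded with pad_token_id,
# not decoded strings.
generated_token_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=128,
)
# skip_special_tokens drops the pad tokens filling unused slots.
print(tokenizer.batch_decode(generated_token_ids, skip_special_tokens=True)[0])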