hvlgo committed on
Commit
51e990c
·
verified ·
1 Parent(s): 370aa4b

Update ts_generation_mixin.py

Browse files
Files changed (1) hide show
  1. ts_generation_mixin.py +34 -17
ts_generation_mixin.py CHANGED
@@ -6,8 +6,38 @@ from transformers.generation import validate_stopping_criteria, EosTokenCriteria
6
  from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput
7
  from transformers.utils import ModelOutput
8
 
 
9
  class TSGenerationMixin(GenerationMixin):
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def _greedy_search(
12
  self,
13
  input_ids: torch.Tensor,
@@ -26,19 +56,7 @@ class TSGenerationMixin(GenerationMixin):
26
  **model_kwargs,
27
  ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
28
  input_ids = input_ids.to(self.device)
29
- initial_input_length = input_ids.shape[1]
30
- if len(input_ids.shape) == 2:
31
- batch_size, cur_len = input_ids.shape
32
- if cur_len < self.config.input_token_len:
33
- raise ValueError(
34
- f"Input length must be at least {self.config.input_token_len}")
35
- elif cur_len % self.config.input_token_len != 0:
36
- new_len = (cur_len // self.config.input_token_len) * \
37
- self.config.input_token_len
38
- input_ids = input_ids[:, -new_len:]
39
- else:
40
- raise ValueError('Input shape must be: [batch_size, seq_len]')
41
-
42
  # init values
43
  logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
44
  stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -106,9 +124,8 @@ class TSGenerationMixin(GenerationMixin):
106
  batch_size, dtype=torch.long, device=input_ids.device)
107
  model_kwargs["cache_position"] = torch.arange(
108
  cur_len, device=input_ids.device)
109
- true_seq_len = input_ids.shape[1] // self.config.input_token_len
110
  model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
111
-
112
  max_length = stopping_criteria.max_length
113
  while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
114
  # prepare model inputs
@@ -129,7 +146,7 @@ class TSGenerationMixin(GenerationMixin):
129
  if synced_gpus and this_peer_finished:
130
  continue # don't waste resources running the code we don't need
131
 
132
- next_token_logits = outputs.logits[:, -1, :]
133
 
134
  # pre-process distribution
135
  next_tokens_scores = logits_processor(input_ids, next_token_logits)
@@ -212,7 +229,7 @@ class TSGenerationMixin(GenerationMixin):
212
  past_key_values=model_kwargs.get("past_key_values"),
213
  )
214
  else:
215
- return input_ids[:, -(max_length - initial_input_length):]
216
 
217
  def _update_model_kwargs_for_generation(
218
  self,
 
6
  from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput
7
  from transformers.utils import ModelOutput
8
 
9
+
10
  class TSGenerationMixin(GenerationMixin):
11
 
12
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    """Validate and align the input series, then defer to the parent ``generate``.

    The input must be two-dimensional (``[batch_size, seq_len]``) and at least
    ``self.config.input_token_len`` long. When ``seq_len`` is not a multiple of
    ``self.config.input_token_len``, the oldest (leading) values are dropped so
    that the remaining length is the largest such multiple.

    Args:
        inputs: Input series of shape ``[batch_size, seq_len]``. If omitted,
            the ``input_ids`` keyword argument is used instead.
        generation_config, logits_processor, stopping_criteria,
        prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer,
        negative_prompt_ids, negative_prompt_attention_mask, **kwargs:
            Forwarded unchanged to ``super().generate``.

    Returns:
        Whatever ``super().generate`` returns for the aligned input.

    Raises:
        ValueError: If no input is supplied, if the input is not
            two-dimensional, or if it is shorter than
            ``self.config.input_token_len``.
    """
    if inputs is None:
        # Support the keyword spelling `generate(input_ids=...)`; popping it
        # avoids handing the same tensor to the parent under two names.
        inputs = kwargs.pop("input_ids", None)
    if inputs is None:
        raise ValueError(
            "`inputs` (or `input_ids`) must be provided with shape [batch_size, seq_len]")

    if len(inputs.shape) != 2:
        raise ValueError('Input shape must be: [batch_size, seq_len]')

    token_len = self.config.input_token_len
    cur_len = inputs.shape[1]
    if cur_len < token_len:
        raise ValueError(
            f"Input length must be at least {token_len}")

    # Keep only the trailing part whose length is a whole number of input
    # tokens; cur_len >= token_len guarantees at least one token remains.
    remainder = cur_len % token_len
    if remainder != 0:
        inputs = inputs[:, remainder:]

    return super().generate(
        inputs=inputs,
        generation_config=generation_config,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        synced_gpus=synced_gpus,
        assistant_model=assistant_model,
        streamer=streamer,
        negative_prompt_ids=negative_prompt_ids,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        **kwargs,
    )
39
+
40
+
41
  def _greedy_search(
42
  self,
43
  input_ids: torch.Tensor,
 
56
  **model_kwargs,
57
  ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
58
  input_ids = input_ids.to(self.device)
59
+ batch_size, cur_len = input_ids.shape
 
 
 
 
 
 
 
 
 
 
 
 
60
  # init values
61
  logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
62
  stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
 
124
  batch_size, dtype=torch.long, device=input_ids.device)
125
  model_kwargs["cache_position"] = torch.arange(
126
  cur_len, device=input_ids.device)
127
+ true_seq_len = cur_len // self.config.input_token_len
128
  model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
 
129
  max_length = stopping_criteria.max_length
130
  while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
131
  # prepare model inputs
 
146
  if synced_gpus and this_peer_finished:
147
  continue # don't waste resources running the code we don't need
148
 
149
+ next_token_logits = outputs.logits
150
 
151
  # pre-process distribution
152
  next_tokens_scores = logits_processor(input_ids, next_token_logits)
 
229
  past_key_values=model_kwargs.get("past_key_values"),
230
  )
231
  else:
232
+ return input_ids[:, -(max_length - cur_len):]
233
 
234
  def _update_model_kwargs_for_generation(
235
  self,