cifkao committed
Commit 10ac8e4 · Parent(s): ab89a9d

Fix sneaky problem caused by caching

Files changed (1):
  app.py  +5 -1
app.py CHANGED
@@ -233,6 +233,11 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
     model_kwargs = dict(use_cache=True)
     for i in range(max_steps):
         pbar.progress(i / max_steps, f"{i}/{max_steps}")
+
+        if input_ids.shape[1] == window_len:
+            model_kwargs.update(use_cache=False)
+            if "past_key_values" in model_kwargs:
+                del model_kwargs["past_key_values"]
         model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
         model_outputs = model(**model_inputs)
         model_kwargs = model._update_model_kwargs_for_generation(model_outputs, model_kwargs, is_encoder_decoder=False)
@@ -250,7 +255,6 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
         input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
         if input_ids.shape[1] > window_len:
            input_ids = input_ids[:, 1:]
-           model_kwargs.update(use_cache=False, past_key_values=None)
         if logprobs_window.shape[0] == window_len:
             logprobs.append(
                 logprobs_window[torch.arange(window_len), input_ids.squeeze(0)]
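What presumably made the old placement sneaky: caching was only disabled after the first truncation, so on the step where the window first fills the model still ran with the cache on, returning logits for just the newest token rather than the full window, and a stale past_key_values entry was left sitting in model_kwargs. The fix switches the cache off before the forward pass as soon as input_ids reaches window_len, and deletes the entry outright.

Below is a minimal sketch of the fixed loop in isolation, not the app's exact code: the gpt2 checkpoint, window_len value, step count, and greedy decoding are illustrative stand-ins, and it assumes a transformers version where prepare_inputs_for_generation and _update_model_kwargs_for_generation behave as they do in this app.py (newer releases may expect extra arguments such as cache_position).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

window_len = 8  # illustrative; the app takes this as a parameter
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

model_kwargs = dict(use_cache=True)
for _ in range(16):
    # The fix: once the window is full, every later step drops the oldest
    # token, so cached keys/values would go stale. Disable caching *before*
    # this forward pass and drop any cache left over from earlier steps.
    if input_ids.shape[1] == window_len:
        model_kwargs.update(use_cache=False)
        model_kwargs.pop("past_key_values", None)  # mirrors the diff's del
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    with torch.no_grad():
        model_outputs = model(**model_inputs)
    model_kwargs = model._update_model_kwargs_for_generation(
        model_outputs, model_kwargs, is_encoder_decoder=False
    )
    # Without the cache, logits now cover the whole window, which is what the
    # app needs to fill logprobs_window; a greedy pick stands in for its metric.
    next_token = model_outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=1)
    if input_ids.shape[1] > window_len:
        input_ids = input_ids[:, 1:]  # slide the window; no cache to invalidate now

print(tokenizer.decode(input_ids[0]))

Popping the key rather than assigning None matches the diff's behavior and keeps later iterations from tripping over whatever the kwargs-update helper stored while the cache was still enabled.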