Spaces:
Running
Running
Fix sneaky problem caused by caching
Browse files
app.py
CHANGED
@@ -233,6 +233,11 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
|
233 |
model_kwargs = dict(use_cache=True)
|
234 |
for i in range(max_steps):
|
235 |
pbar.progress(i / max_steps, f"{i}/{max_steps}")
|
|
|
|
|
|
|
|
|
|
|
236 |
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
237 |
model_outputs = model(**model_inputs)
|
238 |
model_kwargs = model._update_model_kwargs_for_generation(model_outputs, model_kwargs, is_encoder_decoder=False)
|
@@ -250,7 +255,6 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
|
250 |
input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
|
251 |
if input_ids.shape[1] > window_len:
|
252 |
input_ids = input_ids[:, 1:]
|
253 |
-
model_kwargs.update(use_cache=False, past_key_values=None)
|
254 |
if logprobs_window.shape[0] == window_len:
|
255 |
logprobs.append(
|
256 |
logprobs_window[torch.arange(window_len), input_ids.squeeze(0)]
|
|
|
233 |
model_kwargs = dict(use_cache=True)
|
234 |
for i in range(max_steps):
|
235 |
pbar.progress(i / max_steps, f"{i}/{max_steps}")
|
236 |
+
|
237 |
+
if input_ids.shape[1] == window_len:
|
238 |
+
model_kwargs.update(use_cache=False)
|
239 |
+
if "past_key_values" in model_kwargs:
|
240 |
+
del model_kwargs["past_key_values"]
|
241 |
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
242 |
model_outputs = model(**model_inputs)
|
243 |
model_kwargs = model._update_model_kwargs_for_generation(model_outputs, model_kwargs, is_encoder_decoder=False)
|
|
|
255 |
input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
|
256 |
if input_ids.shape[1] > window_len:
|
257 |
input_ids = input_ids[:, 1:]
|
|
|
258 |
if logprobs_window.shape[0] == window_len:
|
259 |
logprobs.append(
|
260 |
logprobs_window[torch.arange(window_len), input_ids.squeeze(0)]
|