Spaces:
Sleeping
Sleeping
Use caching when possible
Browse files
app.py
CHANGED
@@ -90,7 +90,7 @@ generation_mode = st.radio(
|
|
90 |
horizontal=True, label_visibility="collapsed"
|
91 |
) == "Generation mode"
|
92 |
st.caption(
|
93 |
-
"In standard mode, we analyze the model's predictions on the input text. "
|
94 |
"In generation mode, we generate a continuation of the input text (prompt) "
|
95 |
"and visualize the contributions of different contexts to each generated token."
|
96 |
)
|
@@ -128,7 +128,7 @@ with st.empty():
|
|
128 |
with st.expander("Generation options", expanded=False):
|
129 |
generate_kwargs["max_new_tokens"] = st.slider(
|
130 |
"Max. number of generated tokens",
|
131 |
-
min_value=8, max_value=min(1024, max_tokens), value=min(128, max_tokens)
|
132 |
)
|
133 |
col1, col2, col3, col4 = st.columns(4)
|
134 |
with col1:
|
@@ -222,8 +222,7 @@ def get_logits_processor(temperature, top_p, typical_p, repetition_penalty) -> L
|
|
222 |
def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
223 |
assert metric == "NLL loss"
|
224 |
start = max(0, inputs["input_ids"].shape[1] - window_len + 1)
|
225 |
-
|
226 |
-
del inputs_window["labels"]
|
227 |
|
228 |
logits_warper = get_logits_processor(**kwargs)
|
229 |
|
@@ -231,13 +230,16 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
|
231 |
eos_idx = None
|
232 |
pbar = st.progress(0)
|
233 |
max_steps = max_new_tokens + window_len - 1
|
|
|
234 |
for i in range(max_steps):
|
235 |
pbar.progress(i / max_steps, f"{i}/{max_steps}")
|
236 |
-
|
237 |
-
|
|
|
|
|
238 |
logprobs_window = logits_window.log_softmax(dim=-1)
|
239 |
if eos_idx is None:
|
240 |
-
probs_next = logits_warper(
|
241 |
next_token = torch.multinomial(probs_next, num_samples=1).item()
|
242 |
if next_token == tokenizer.eos_token_id or i >= max_new_tokens - 1:
|
243 |
eos_idx = i
|
@@ -245,12 +247,13 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
|
245 |
next_token = tokenizer.eos_token_id
|
246 |
new_ids.append(next_token)
|
247 |
|
248 |
-
|
249 |
-
if
|
250 |
-
|
|
|
251 |
if logprobs_window.shape[0] == window_len:
|
252 |
logprobs.append(
|
253 |
-
logprobs_window[torch.arange(window_len),
|
254 |
)
|
255 |
|
256 |
if eos_idx is not None and i - eos_idx >= window_len - 1:
|
|
|
90 |
horizontal=True, label_visibility="collapsed"
|
91 |
) == "Generation mode"
|
92 |
st.caption(
|
93 |
+
"In standard mode, we analyze the model's one-step-ahead predictions on the input text. "
|
94 |
"In generation mode, we generate a continuation of the input text (prompt) "
|
95 |
"and visualize the contributions of different contexts to each generated token."
|
96 |
)
|
|
|
128 |
with st.expander("Generation options", expanded=False):
|
129 |
generate_kwargs["max_new_tokens"] = st.slider(
|
130 |
"Max. number of generated tokens",
|
131 |
+
min_value=8, max_value=min(1024, max_tokens), step=8, value=min(128, max_tokens)
|
132 |
)
|
133 |
col1, col2, col3, col4 = st.columns(4)
|
134 |
with col1:
|
|
|
222 |
def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
|
223 |
assert metric == "NLL loss"
|
224 |
start = max(0, inputs["input_ids"].shape[1] - window_len + 1)
|
225 |
+
input_ids = inputs["input_ids"][:, start:]
|
|
|
226 |
|
227 |
logits_warper = get_logits_processor(**kwargs)
|
228 |
|
|
|
230 |
eos_idx = None
|
231 |
pbar = st.progress(0)
|
232 |
max_steps = max_new_tokens + window_len - 1
|
233 |
+
model_kwargs = dict(use_cache=True)
|
234 |
for i in range(max_steps):
|
235 |
pbar.progress(i / max_steps, f"{i}/{max_steps}")
|
236 |
+
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
237 |
+
model_outputs = model(**model_inputs)
|
238 |
+
model_kwargs = model._update_model_kwargs_for_generation(model_outputs, model_kwargs, is_encoder_decoder=False)
|
239 |
+
logits_window = model_outputs.logits.squeeze(0)
|
240 |
logprobs_window = logits_window.log_softmax(dim=-1)
|
241 |
if eos_idx is None:
|
242 |
+
probs_next = logits_warper(input_ids, logits_window[[-1]]).softmax(dim=-1)
|
243 |
next_token = torch.multinomial(probs_next, num_samples=1).item()
|
244 |
if next_token == tokenizer.eos_token_id or i >= max_new_tokens - 1:
|
245 |
eos_idx = i
|
|
|
247 |
next_token = tokenizer.eos_token_id
|
248 |
new_ids.append(next_token)
|
249 |
|
250 |
+
input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
|
251 |
+
if input_ids.shape[1] > window_len:
|
252 |
+
input_ids = input_ids[:, 1:]
|
253 |
+
model_kwargs.update(use_cache=False, past_key_values=None)
|
254 |
if logprobs_window.shape[0] == window_len:
|
255 |
logprobs.append(
|
256 |
+
logprobs_window[torch.arange(window_len), input_ids.squeeze(0)]
|
257 |
)
|
258 |
|
259 |
if eos_idx is not None and i - eos_idx >= window_len - 1:
|