cifkao commited on
Commit
ffa03c8
1 Parent(s): ebb68fb

Add sampling options

Browse files
Files changed (1) hide show
  1. app.py +47 -12
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from pathlib import Path
2
- from typing import Dict, Hashable
3
 
4
  import streamlit as st
5
  import streamlit.components.v1 as components
@@ -7,6 +7,7 @@ import numpy as np
7
  import torch
8
  import torch.nn.functional as F
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding, GPT2LMHeadModel, PreTrainedTokenizer
 
10
 
11
  root_dir = Path(__file__).resolve().parent
12
  highlighted_text_component = components.declare_component(
@@ -118,12 +119,30 @@ window_len = st.select_slider(
118
  max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)
119
  max_tokens = min(max_tokens, 4096)
120
 
121
- max_new_tokens = None
122
  if generation_mode:
123
- max_new_tokens = st.slider(
124
- "Max. number of generated tokens",
125
- min_value=8, max_value=min(1024, max_tokens), value=min(128, max_tokens)
126
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  DEFAULT_TEXT = """
129
  We present context length probing, a novel explanation technique for causal
@@ -180,13 +199,27 @@ def get_logprobs(model, inputs, metric):
180
  pbar.empty()
181
  return logprobs
182
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  @torch.inference_mode()
184
- def generate(model, inputs, metric, window_len, max_new_tokens):
185
  assert metric == "NLL loss"
186
  start = max(0, inputs["input_ids"].shape[1] - window_len + 1)
187
  inputs_window = {k: v[:, start:] for k, v in inputs.items()}
188
  del inputs_window["labels"]
189
 
 
 
190
  new_ids, logprobs = [], []
191
  eos_idx = None
192
  pbar = st.progress(0)
@@ -194,9 +227,11 @@ def generate(model, inputs, metric, window_len, max_new_tokens):
194
  for i in range(max_steps):
195
  pbar.progress(i / max_steps, f"{i}/{max_steps}")
196
  inputs_window["attention_mask"] = torch.ones_like(inputs_window["input_ids"], dtype=torch.long)
197
- logprobs_window = model(**inputs_window).logits.log_softmax(dim=-1).squeeze(0)
 
198
  if eos_idx is None:
199
- next_token = torch.multinomial(logprobs_window[-1].exp(), num_samples=1).item()
 
200
  if next_token == tokenizer.eos_token_id or i >= max_new_tokens - 1:
201
  eos_idx = i
202
  else:
@@ -225,7 +260,7 @@ def run_context_length_probing(
225
  window_len: int,
226
  metric: str,
227
  generation_mode: bool,
228
- max_new_tokens: int,
229
  cache_key: Hashable
230
  ):
231
  del cache_key
@@ -240,7 +275,7 @@ def run_context_length_probing(
240
  inputs=_inputs.convert_to_tensors("pt"),
241
  metric=metric,
242
  window_len=window_len,
243
- max_new_tokens=max_new_tokens
244
  )
245
  output_ids = [*input_ids, *new_ids]
246
  window_len = logprobs.shape[1]
@@ -288,7 +323,7 @@ output_ids, scores = run_context_length_probing(
288
  window_len=window_len,
289
  metric=metric_name,
290
  generation_mode=generation_mode,
291
- max_new_tokens=max_new_tokens,
292
  cache_key=(model_name, text),
293
  )
294
  tokens = ids_to_readable_tokens(tokenizer, output_ids)
 
1
  from pathlib import Path
2
+ from typing import Any, Dict, Hashable
3
 
4
  import streamlit as st
5
  import streamlit.components.v1 as components
 
7
  import torch
8
  import torch.nn.functional as F
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding, GPT2LMHeadModel, PreTrainedTokenizer
10
+ from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper
11
 
12
  root_dir = Path(__file__).resolve().parent
13
  highlighted_text_component = components.declare_component(
 
119
  max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)
120
  max_tokens = min(max_tokens, 4096)
121
 
122
+ generate_kwargs = {}
123
  if generation_mode:
124
+ with st.expander("Generation options", expanded=False):
125
+ generate_kwargs["max_new_tokens"] = st.slider(
126
+ "Max. number of generated tokens",
127
+ min_value=8, max_value=min(1024, max_tokens), value=min(128, max_tokens)
128
+ )
129
+ col1, col2, col3, col4 = st.columns(4)
130
+ with col1:
131
+ generate_kwargs["temperature"] = st.number_input(
132
+ min_value=0.01, value=0.9, step=0.05, label="`temperature`"
133
+ )
134
+ with col2:
135
+ generate_kwargs["top_p"] = st.number_input(
136
+ min_value=0., value=0.95, max_value=1., step=0.05, label="`top_p`"
137
+ )
138
+ with col3:
139
+ generate_kwargs["typical_p"] = st.number_input(
140
+ min_value=0., value=1., max_value=1., step=0.05, label="`typical_p`"
141
+ )
142
+ with col4:
143
+ generate_kwargs["repetition_penalty"] = st.number_input(
144
+ min_value=1., value=1., step=0.05, label="`repetition_penalty`"
145
+ )
146
 
147
  DEFAULT_TEXT = """
148
  We present context length probing, a novel explanation technique for causal
 
199
  pbar.empty()
200
  return logprobs
201
 
202
+ def get_logits_processor(temperature, top_p, typical_p, repetition_penalty) -> LogitsProcessorList:
203
+ processor = LogitsProcessorList()
204
+ if repetition_penalty != 1.0:
205
+ processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
206
+ if temperature != 1.0:
207
+ processor.append(TemperatureLogitsWarper(temperature))
208
+ if top_p < 1.0:
209
+ processor.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
210
+ if typical_p < 1.0:
211
+ processor.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=1))
212
+ return processor
213
+
214
  @torch.inference_mode()
215
+ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
216
  assert metric == "NLL loss"
217
  start = max(0, inputs["input_ids"].shape[1] - window_len + 1)
218
  inputs_window = {k: v[:, start:] for k, v in inputs.items()}
219
  del inputs_window["labels"]
220
 
221
+ logits_warper = get_logits_processor(**kwargs)
222
+
223
  new_ids, logprobs = [], []
224
  eos_idx = None
225
  pbar = st.progress(0)
 
227
  for i in range(max_steps):
228
  pbar.progress(i / max_steps, f"{i}/{max_steps}")
229
  inputs_window["attention_mask"] = torch.ones_like(inputs_window["input_ids"], dtype=torch.long)
230
+ logits_window = model(**inputs_window).logits.squeeze(0)
231
+ logprobs_window = logits_window.log_softmax(dim=-1)
232
  if eos_idx is None:
233
+ probs_next = logits_warper(inputs_window["input_ids"], logits_window[[-1]]).softmax(dim=-1)
234
+ next_token = torch.multinomial(probs_next, num_samples=1).item()
235
  if next_token == tokenizer.eos_token_id or i >= max_new_tokens - 1:
236
  eos_idx = i
237
  else:
 
260
  window_len: int,
261
  metric: str,
262
  generation_mode: bool,
263
+ generate_kwargs: Dict[str, Any],
264
  cache_key: Hashable
265
  ):
266
  del cache_key
 
275
  inputs=_inputs.convert_to_tensors("pt"),
276
  metric=metric,
277
  window_len=window_len,
278
+ **generate_kwargs
279
  )
280
  output_ids = [*input_ids, *new_ids]
281
  window_len = logprobs.shape[1]
 
323
  window_len=window_len,
324
  metric=metric_name,
325
  generation_mode=generation_mode,
326
+ generate_kwargs=generate_kwargs,
327
  cache_key=(model_name, text),
328
  )
329
  tokens = ids_to_readable_tokens(tokenizer, output_ids)