cifkao committed
Commit 3102d58
Parent: 6f46ddf

More efficient NLL implementation

Files changed (1)
  1. app.py +31 -11
app.py CHANGED
@@ -17,7 +17,7 @@ def get_windows_batched(examples: BatchEncoding, window_len: int, stride: int =
     return BatchEncoding({
         k: [
            t[i][j : j + window_len] + [
-                pad_id if k == "input_ids" else 0
+                pad_id if k in ["input_ids", "labels"] else 0
             ] * (j + window_len - len(t[i]))
             for i in range(len(examples["input_ids"]))
             for j in range(0, len(examples["input_ids"][i]) - 1, stride)
@@ -43,7 +43,10 @@ def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
     return result
 
 def nll_score(logprobs, labels):
-    return -logprobs[:, torch.arange(len(labels)), labels]
+    if logprobs.shape[-1] == 1:
+        return -logprobs.squeeze(-1)
+    else:
+        return -logprobs[:, torch.arange(len(labels)), labels]
 
 def kl_div_score(logprobs):
     log_p = logprobs[
@@ -75,8 +78,18 @@ if not compact_layout:
         """
     )
 
+generation_mode = False
+# st.radio("Mode", ["Standard", "Generation"], horizontal=True) == "Generation"
+# st.caption(
+#     "In standard mode, we analyze the model's predictions on the input text. "
+#     "In generation mode, we generate a continuation of the input text "
+#     "and visualize the contributions of different contexts to each generated token."
+# )
+
 model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
-metric_name = st.selectbox("Metric", ["KL divergence", "NLL loss"], index=0)
+metric_name = st.radio(
+    "Metric", (["KL divergence"] if not generation_mode else []) + ["NLL loss"], index=0, horizontal=True
+)
 
 tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)
 
@@ -84,9 +97,10 @@ tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)
 MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)
 # Select window lengths such that we are allowed to fill the whole window without running out of memory
 # (otherwise the window length is irrelevant)
+logprobs_dim = tokenizer.vocab_size if metric_name == "KL divergence" else 1
 window_len_options = [
     w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
-    if w == 8 or w * (2 * w) * tokenizer.vocab_size <= MAX_MEM
+    if w == 8 or w * (2 * w) * logprobs_dim <= MAX_MEM
 ]
 window_len = st.select_slider(
     r"Window size ($c_\text{max}$)",
@@ -95,7 +109,8 @@ window_len = st.select_slider(
 )
 # Now figure out how many tokens we are allowed to use:
 # window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
-max_tokens = int(MAX_MEM / (tokenizer.vocab_size * window_len) - window_len)
+max_tokens = int(MAX_MEM / (logprobs_dim * window_len) - window_len)
+max_tokens = min(max_tokens, 2048)
 
 DEFAULT_TEXT = """
 We present context length probing, a novel explanation technique for causal
@@ -117,6 +132,7 @@ if tokenizer.eos_token:
     text += tokenizer.eos_token
 inputs = tokenizer([text])
 [input_ids] = inputs["input_ids"]
+inputs["labels"] = [[*input_ids[1:], tokenizer.eos_token_id]]
 num_user_tokens = len(input_ids) - (1 if tokenizer.eos_token else 0)
 
 if num_user_tokens < 1:
@@ -160,13 +176,17 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, metric,
     for i in range(0, num_items, batch_size):
         pbar.progress(i / num_items, f"{i}/{num_items}")
         batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
-        logprobs.append(
-            get_logprobs(
-                _model,
-                batch,
-                cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
-            )
+        batch_logprobs = get_logprobs(
+            _model,
+            batch,
+            cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
         )
+        batch_labels = batch["labels"]
+        if metric != "KL divergence":
+            batch_logprobs = torch.gather(
+                batch_logprobs, dim=-1, index=batch_labels[..., None]
+            )
+        logprobs.append(batch_logprobs)
     logprobs = torch.cat(logprobs, dim=0)
     pbar.empty()
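For context, a minimal standalone sketch of the trick this commit relies on. The toy shapes and the names `num_windows`, `label_logprobs`, and `full` are made up for illustration; only the `torch.gather(..., index=labels[..., None])` call and the squeeze branch in `nll_score` come from the diff itself. When the metric is NLL, only the log-probability of the observed next token is needed at each position, so each batch's full-vocabulary tensor can be reduced right away instead of being accumulated.

```python
import torch

# Toy shapes for illustration only (not taken from the app).
num_windows, window_len, vocab_size = 4, 8, 50257
logprobs = torch.log_softmax(torch.randn(num_windows, window_len, vocab_size), dim=-1)
labels = torch.randint(vocab_size, (num_windows, window_len))

# Keep only the label's log-probability at each position:
# (num_windows, window_len, vocab_size) -> (num_windows, window_len, 1).
label_logprobs = torch.gather(logprobs, dim=-1, index=labels[..., None])

# The new branch in nll_score() then just negates and drops the singleton dim.
nll = -label_logprobs.squeeze(-1)

# Same values as indexing into the full-vocabulary tensor, at ~1/vocab_size the memory.
full = -logprobs[torch.arange(num_windows)[:, None], torch.arange(window_len), labels]
assert torch.allclose(nll, full)
```

Because only one value per position survives, `logprobs_dim` drops from `tokenizer.vocab_size` to 1 for the NLL metric, so the same `MAX_MEM` budget admits much longer windows and inputs; that is presumably why the commit also adds the explicit `max_tokens = min(max_tokens, 2048)` cap.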