Commit dca8e6c · cifkao committed · 1 Parent(s): 3102d58

Adjust limit, turn off caching for logprobs

Files changed (1):
  1. app.py +7 -9
app.py CHANGED
@@ -96,11 +96,12 @@ tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)
  # Make sure the logprobs do not use up more than ~4 GB of memory
  MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)
  # Select window lengths such that we are allowed to fill the whole window without running out of memory
- # (otherwise the window length is irrelevant)
- logprobs_dim = tokenizer.vocab_size if metric_name == "KL divergence" else 1
+ # (otherwise the window length is irrelevant); if using NLL, memory is not a consideration, but we want
+ # to limit runtime
+ multiplier = tokenizer.vocab_size if metric_name == "KL divergence" else 16384  # arbitrary number
  window_len_options = [
      w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
-     if w == 8 or w * (2 * w) * logprobs_dim <= MAX_MEM
+     if w == 8 or w * (2 * w) * multiplier <= MAX_MEM
  ]
  window_len = st.select_slider(
      r"Window size ($c_\text{max}$)",
@@ -109,8 +110,7 @@ window_len = st.select_slider(
  )
  # Now figure out how many tokens we are allowed to use:
  # window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
- max_tokens = int(MAX_MEM / (logprobs_dim * window_len) - window_len)
- max_tokens = min(max_tokens, 2048)
+ max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)

  DEFAULT_TEXT = """
  We present context length probing, a novel explanation technique for causal
@@ -151,10 +151,8 @@ with st.spinner("Loading model…"):

  window_len = min(window_len, len(input_ids))

- @st.cache_data(show_spinner=False)
  @torch.inference_mode()
- def get_logprobs(_model, _inputs, cache_key):
-     del cache_key
+ def get_logprobs(_model, _inputs):
      return _model(**_inputs).logits.log_softmax(dim=-1).to(torch.float16)

  @st.cache_data(show_spinner=False)
@@ -179,7 +177,7 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, metric,
          batch_logprobs = get_logprobs(
              _model,
              batch,
-             cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
+             #cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
          )
          batch_labels = batch["labels"]
          if metric != "KL divergence":
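
To make the memory budget above concrete, here is a small standalone sketch of the arithmetic (not part of the commit); the vocabulary size of 50257 is an assumption for illustration — the app reads it from the loaded tokenizer, and the NLL branch would use the arbitrary 16384 multiplier instead.

import torch

# ~4 GB expressed as a number of float16 elements (2 bytes each)
MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)  # = 2e9 elements

vocab_size = 50_257       # assumed vocabulary size; the app uses tokenizer.vocab_size
multiplier = vocab_size   # KL divergence keeps full distributions; NLL only needs one value per position

# Keep only window lengths whose fully filled logprob buffer fits in the budget:
#   window_len * (2 * window_len) * multiplier <= MAX_MEM
window_len_options = [
    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
    if w == 8 or w * (2 * w) * multiplier <= MAX_MEM
]

# For the chosen window length, solve
#   window_len * (num_tokens + window_len) * multiplier <= MAX_MEM
# for num_tokens to get the token limit:
window_len = window_len_options[-1]
max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)

print(window_len_options)  # [8, 16, 32, 64, 128] under these assumptions
print(max_tokens)          # 182 tokens for window_len = 128

With this assumed ~50k vocabulary, the budget caps the window at 128 tokens and the input at roughly 180 tokens, which is why the NLL branch swaps in a much smaller multiplier: there the limit exists only to keep runtime reasonable, not memory.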