Adjust limit, turn off caching for logprobs
app.py (CHANGED)
@@ -96,11 +96,12 @@ tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)
 # Make sure the logprobs do not use up more than ~4 GB of memory
 MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)
 # Select window lengths such that we are allowed to fill the whole window without running out of memory
-# (otherwise the window length is irrelevant)
-
+# (otherwise the window length is irrelevant); if using NLL, memory is not a consideration, but we want
+# to limit runtime
+multiplier = tokenizer.vocab_size if metric_name == "KL divergence" else 16384 # arbitrary number
 window_len_options = [
     w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
-    if w == 8 or w * (2 * w) * tokenizer.vocab_size <= MAX_MEM
+    if w == 8 or w * (2 * w) * multiplier <= MAX_MEM
 ]
 window_len = st.select_slider(
     r"Window size ($c_\text{max}$)",
@@ -109,8 +110,7 @@ window_len = st.select_slider(
 )
 # Now figure out how many tokens we are allowed to use:
 # window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
-max_tokens = int(MAX_MEM / (tokenizer.vocab_size * window_len) - window_len)
-max_tokens = min(max_tokens, 2048)
+max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)
 
 DEFAULT_TEXT = """
 We present context length probing, a novel explanation technique for causal
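The new max_tokens line is the same bound solved for the number of input tokens: from window_len * (num_tokens + window_len) * multiplier <= MAX_MEM it follows that num_tokens <= MAX_MEM / (multiplier * window_len) - window_len. The previous hard cap of 2048 tokens is dropped, so the limit now comes entirely from this budget. Illustrative numbers only (the 50257 vocabulary size is an assumption):

# Token limit implied by the memory/runtime budget above (illustrative values).
MAX_MEM = 4e9 / 2  # float16 values in ~4 GB
window_len = 128

for multiplier in (50257, 16384):  # assumed vocab size vs. the fixed non-KL constant
    max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)
    print(multiplier, max_tokens)  # -> 182 and 825 with these numbers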
@@ -151,10 +151,8 @@ with st.spinner("Loading model…"):
 
 window_len = min(window_len, len(input_ids))
 
-@st.cache_data(show_spinner=False)
 @torch.inference_mode()
-def get_logprobs(_model, _inputs, cache_key):
-    del cache_key
+def get_logprobs(_model, _inputs):
     return _model(**_inputs).logits.log_softmax(dim=-1).to(torch.float16)
 
 @st.cache_data(show_spinner=False)
@@ -179,7 +177,7 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, metric,
 batch_logprobs = get_logprobs(
     _model,
     batch,
-    cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
+    #cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
 )
 batch_labels = batch["labels"]
 if metric != "KL divergence":
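With the @st.cache_data decorator removed from get_logprobs, the cache_key plumbing becomes dead weight: st.cache_data excludes parameters whose names start with an underscore from hashing, so _model and _inputs never contributed to the cache lookup and the explicit cache_key stood in for the model and inputs. A generic sketch of that (now removed) pattern, not the app's code:

import streamlit as st

@st.cache_data(show_spinner=False)
def expensive(_big_object, cache_key):
    # Underscore-prefixed parameters are excluded from st.cache_data's hashing,
    # so cache_key alone identifies this call in the cache.
    del cache_key
    return str(_big_object)[:100]  # placeholder computation for the sketch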
|