More efficient NLL implementation
app.py CHANGED
@@ -17,7 +17,7 @@ def get_windows_batched(examples: BatchEncoding, window_len: int, stride: int =
     return BatchEncoding({
         k: [
             t[i][j : j + window_len] + [
-                pad_id if k
+                pad_id if k in ["input_ids", "labels"] else 0
             ] * (j + window_len - len(t[i]))
             for i in range(len(examples["input_ids"]))
             for j in range(0, len(examples["input_ids"][i]) - 1, stride)
@@ -43,7 +43,10 @@ def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
     return result
 
 def nll_score(logprobs, labels):
-
+    if logprobs.shape[-1] == 1:
+        return -logprobs.squeeze(-1)
+    else:
+        return -logprobs[:, torch.arange(len(labels)), labels]
 
 def kl_div_score(logprobs):
     log_p = logprobs[
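A quick check, not part of the diff: when the label log-probabilities have already been gathered into a trailing dimension of size 1, the new fast path (negate and squeeze) matches the full-vocabulary indexing path. The tensor sizes below are made up.

import torch

logprobs = torch.log_softmax(torch.randn(2, 5, 10), dim=-1)  # (windows, positions, vocab)
labels = torch.randint(10, (5,))                             # one target id per position
nll_full = -logprobs[:, torch.arange(5), labels]             # full-vocabulary path
gathered = torch.gather(logprobs, dim=-1, index=labels.expand(2, 5)[..., None])  # (2, 5, 1)
nll_fast = -gathered.squeeze(-1)                             # fast path for pre-gathered values
assert torch.allclose(nll_full, nll_fast)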
@@ -75,8 +78,18 @@ if not compact_layout:
         """
     )
 
+generation_mode = False
+# st.radio("Mode", ["Standard", "Generation"], horizontal=True) == "Generation"
+# st.caption(
+# "In standard mode, we analyze the model's predictions on the input text. "
+# "In generation mode, we generate a continuation of the input text "
+# "and visualize the contributions of different contexts to each generated token."
+# )
+
 model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
-metric_name = st.
+metric_name = st.radio(
+    "Metric", (["KL divergence"] if not generation_mode else []) + ["NLL loss"], index=0, horizontal=True
+)
 
 tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)
 
@@ -84,9 +97,10 @@ tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)
 MAX_MEM = 4e9 / (torch.finfo(torch.float16).bits / 8)
 # Select window lengths such that we are allowed to fill the whole window without running out of memory
 # (otherwise the window length is irrelevant)
+logprobs_dim = tokenizer.vocab_size if metric_name == "KL divergence" else 1
 window_len_options = [
     w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
-    if w == 8 or w * (2 * w) *
+    if w == 8 or w * (2 * w) * logprobs_dim <= MAX_MEM
 ]
 window_len = st.select_slider(
     r"Window size ($c_\text{max}$)",
@@ -95,7 +109,8 @@ window_len = st.select_slider(
 )
 # Now figure out how many tokens we are allowed to use:
 # window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
-max_tokens = int(MAX_MEM / (
+max_tokens = int(MAX_MEM / (logprobs_dim * window_len) - window_len)
+max_tokens = min(max_tokens, 2048)
 
 DEFAULT_TEXT = """
 We present context length probing, a novel explanation technique for causal
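For a rough sense of what logprobs_dim buys here (the 50257-token GPT-2 vocabulary below is an assumption, not stated in the diff): MAX_MEM is a 4 GB float16 budget, i.e. about 2e9 elements. With the KL divergence metric the full vocabulary dimension is kept, so the window filter only admits sizes up to 128 and the token cap lands near 182; with the NLL loss, logprobs_dim is 1, every window size up to 1024 passes, and the cap hits the 2048 limit.

MAX_MEM = 4e9 / 2  # float16 element budget, as in the app
for logprobs_dim in (50257, 1):  # KL divergence vs. NLL loss (vocab size assumed)
    options = [
        w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
        if w == 8 or w * (2 * w) * logprobs_dim <= MAX_MEM
    ]
    max_tokens = min(int(MAX_MEM / (logprobs_dim * options[-1]) - options[-1]), 2048)
    print(logprobs_dim, options[-1], max_tokens)  # 50257 -> 128, 182; 1 -> 1024, 2048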
@@ -117,6 +132,7 @@ if tokenizer.eos_token:
     text += tokenizer.eos_token
 inputs = tokenizer([text])
 [input_ids] = inputs["input_ids"]
+inputs["labels"] = [[*input_ids[1:], tokenizer.eos_token_id]]
 num_user_tokens = len(input_ids) - (1 if tokenizer.eos_token else 0)
 
 if num_user_tokens < 1:
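The new "labels" entry is the input shifted left by one position, with the EOS id filling the final slot, so labels[t] is the token the model should predict at position t. A toy illustration with made-up ids:

input_ids = [10, 11, 12, 13]
eos_token_id = 0
labels = [*input_ids[1:], eos_token_id]
assert labels == [11, 12, 13, 0]  # target for each input position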
@@ -160,13 +176,17 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, metric,
     for i in range(0, num_items, batch_size):
         pbar.progress(i / num_items, f"{i}/{num_items}")
         batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
-
-
-
-
-                cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
-            )
+        batch_logprobs = get_logprobs(
+            _model,
+            batch,
+            cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
         )
+        batch_labels = batch["labels"]
+        if metric != "KL divergence":
+            batch_logprobs = torch.gather(
+                batch_logprobs, dim=-1, index=batch_labels[..., None]
+            )
+        logprobs.append(batch_logprobs)
     logprobs = torch.cat(logprobs, dim=0)
     pbar.empty()
 
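This last hunk is where the saving happens: when the metric is the NLL loss, each batch is reduced to the per-label log-probabilities with torch.gather before being accumulated, so the concatenated tensor keeps one value per position instead of a full vocabulary distribution. A minimal shape sketch (batch, window and vocabulary sizes are illustrative, not taken from the app):

import torch

batch_logprobs = torch.log_softmax(torch.randn(4, 16, 50257), dim=-1)  # (batch, window, vocab)
batch_labels = torch.randint(50257, (4, 16))                           # target id per position
gathered = torch.gather(batch_logprobs, dim=-1, index=batch_labels[..., None])
print(batch_logprobs.shape, gathered.shape)  # (4, 16, 50257) vs. (4, 16, 1)
# Only the gathered tensor is accumulated for the NLL metric, so the stored
# log-probabilities are smaller by a factor of the vocabulary size.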