Better use of memory; limit window size and number of tokens
app.py CHANGED
@@ -54,11 +54,25 @@ if not compact_layout:
 
 model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
 metric_name = st.selectbox("Metric", ["KL divergence", "Cross entropy"], index=1)
+
+tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)
+
+# Make sure the logprobs do not use up more than ~6 GB of memory
+MAX_MEM = 6e9 / (torch.finfo(torch.float16).bits / 8)
+# Select window lengths such that we are allowed to fill the whole window without running out of memory
+# (otherwise the window length is irrelevant)
+window_len_options = [
+    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
+    if w == 8 or w * (2 * w) * tokenizer.vocab_size <= MAX_MEM
+]
 window_len = st.select_slider(
     r"Window size ($c_\text{max}$)",
-    options=
-    value=
+    options=window_len_options,
+    value=min(128, window_len_options[-1])
 )
+# Now figure out how many tokens we are allowed to use:
+# window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
+max_tokens = int(MAX_MEM / (tokenizer.vocab_size * window_len) - window_len)
 
 DEFAULT_TEXT = """
 We present context length probing, a novel explanation technique for causal
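The arithmetic behind the new memory budget, as a standalone sketch (not part of the commit; vocab_size = 50257 is assumed as an example, matching the gpt2 tokenizer):

import torch

# ~6 GB expressed as a number of float16 elements (2 bytes each) -> 3e9 elements
MAX_MEM = 6e9 / (torch.finfo(torch.float16).bits / 8)

vocab_size = 50257  # assumption for illustration (gpt2-family tokenizers)
for w in [8, 16, 32, 64, 128, 256, 512, 1024]:
    # A fully filled window stores w * (num_tokens + w) * vocab_size logprobs;
    # with num_tokens == w this is w * (2 * w) * vocab_size elements.
    print(w, w * (2 * w) * vocab_size <= MAX_MEM)

With this budget, 128 is the largest window that passes (128 * 256 * 50257 ≈ 1.6e9 ≤ 3e9, while 256 * 512 * 50257 ≈ 6.6e9 exceeds it), consistent with the slider default of min(128, window_len_options[-1]).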
@@ -71,31 +85,38 @@ dependencies.
 """.replace("\n", " ").strip()
 
 text = st.text_area(
-    "Input text",
+    f"Input text (≤{max_tokens} tokens)",
     DEFAULT_TEXT,
 )
 
+inputs = tokenizer([text])
+[input_ids] = inputs["input_ids"]
+
+if len(input_ids) < 2:
+    st.error("Please enter at least 2 tokens.", icon="🚨")
+    st.stop()
+if len(input_ids) > max_tokens:
+    st.error(
+        f"Your input has {len(input_ids)} tokens. Please enter at most {max_tokens} tokens "
+        f"or try reducing the window size.",
+        icon="🚨"
+    )
+    st.stop()
+
 if metric_name == "KL divergence":
     st.error("KL divergence is not supported yet. Stay tuned!", icon="😭")
     st.stop()
 
 with st.spinner("Loading model…"):
-    tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
     model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
 
-inputs = tokenizer([text])
-[input_ids] = inputs["input_ids"]
 window_len = min(window_len, len(input_ids))
 
-if len(input_ids) < 2:
-    st.error("Please enter at least 2 tokens.", icon="🚨")
-    st.stop()
-
 @st.cache_data(show_spinner=False)
 @torch.inference_mode()
-def
+def get_logprobs(_model, _inputs, cache_key):
     del cache_key
-    return _model(**_inputs).logits.to(torch.float16)
+    return _model(**_inputs).logits.log_softmax(dim=-1).to(torch.float16)
 
 @st.cache_data(show_spinner=False)
 @torch.inference_mode()
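For readers unfamiliar with the caching pattern used by get_logprobs: st.cache_data hashes a function's arguments to build its cache key, but skips parameters whose names start with an underscore, so the unhashable model and tensor batch are passed as _model and _inputs while an explicit cache_key (model name plus the raw input-ID bytes) stands in for them. A minimal sketch of the same pattern, with an illustrative function name (expensive_forward is not in the Space):

import streamlit as st
import torch

@st.cache_data(show_spinner=False)
@torch.inference_mode()
def expensive_forward(_model, _inputs, cache_key):
    # Underscore-prefixed parameters are excluded from Streamlit's hashing,
    # so only cache_key determines cache hits; it is not needed otherwise.
    del cache_key
    return _model(**_inputs).logits.log_softmax(dim=-1).to(torch.float16)

Because log_softmax is applied inside the cached call, the cached value is already the float16 log-probabilities rather than raw logits.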
@@ -108,7 +129,7 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_key):
         pad_id=_tokenizer.eos_token_id
     ).convert_to_tensors("pt")
 
-
+    logprobs = []
     with st.spinner("Running model…"):
         batch_size = 8
         num_items = len(inputs_sliding["input_ids"])
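The inputs_sliding batch iterated over below holds one fixed-length window per position of the input. A deliberately simplified per-target sketch of the idea (the helper name, padding side, and generic pad_id are assumptions, not the Space's actual construction):

# Simplified illustration only: one window of at most window_len tokens
# ending just before each target position, left-padded to a fixed length.
def make_sliding_windows(input_ids, window_len, pad_id):
    windows = []
    for end in range(1, len(input_ids)):  # token at `end` is the prediction target
        window = input_ids[max(0, end - window_len):end]
        windows.append([pad_id] * (window_len - len(window)) + window)
    return windows

print(make_sliding_windows([5, 7, 11, 13, 17], window_len=3, pad_id=0))
# [[0, 0, 5], [0, 5, 7], [5, 7, 11], [7, 11, 13]]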
@@ -116,27 +137,26 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_key):
         for i in range(0, num_items, batch_size):
             pbar.progress(i / num_items, f"{i}/{num_items}")
             batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
-
-
+            logprobs.append(
+                get_logprobs(
                     _model,
                     batch,
                     cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
                 )
             )
-
+        logprobs = torch.cat(logprobs, dim=0)
         pbar.empty()
 
     with st.spinner("Computing scores…"):
-
-
-
-
+        logprobs = logprobs.permute(1, 0, 2)
+        logprobs = F.pad(logprobs, (0, 0, 0, window_len, 0, 0), value=torch.nan)
+        logprobs = logprobs.view(-1, logprobs.shape[-1])[:-window_len]
+        logprobs = logprobs.view(window_len, len(input_ids) + window_len - 2, logprobs.shape[-1])
 
-        scores =
-        scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
+        scores = logprobs[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
         scores = scores.diff(dim=0).transpose(0, 1)
         scores = scores.nan_to_num()
-        scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-
+        scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-6
         scores = scores.to(torch.float16)
 
     return scores
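A toy version of the scoring step at the end of run_context_length_probing (made-up numbers; the realignment that produces the per-context-length log-probabilities is omitted): rows are context lengths, columns are target positions, and the row-wise difference measures how much each additional context token improves the target token's log-probability before per-token normalization.

import torch

target_logprobs = torch.tensor([
    [-5.0, -4.0],   # log p(target) with context length 1
    [-4.0, -3.9],   # ... with context length 2
    [-2.5, -3.8],   # ... with context length 3
])

scores = target_logprobs.diff(dim=0).transpose(0, 1)  # (targets, context increments)
scores = scores.nan_to_num()
scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-6
print(scores)
# For the first target, the context token that takes the window from length 2 to 3
# produces the largest jump, so it receives the highest normalized score.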