Spaces:
Sleeping
Sleeping
Better progress indication
Browse files
app.py
CHANGED
@@ -56,8 +56,9 @@ if metric_name == "KL divergence":
|
|
56 |
st.error("KL divergence is not supported yet. Stay tuned!", icon="😭")
|
57 |
st.stop()
|
58 |
|
59 |
-
|
60 |
-
|
|
|
61 |
|
62 |
inputs = tokenizer([text])
|
63 |
[input_ids] = inputs["input_ids"]
|
@@ -80,29 +81,29 @@ def run_context_length_probing(model_name, text, window_len):
|
|
80 |
).convert_to_tensors("pt")
|
81 |
|
82 |
logits = []
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
|
107 |
return scores
|
108 |
|
|
|
56 |
st.error("KL divergence is not supported yet. Stay tuned!", icon="😭")
|
57 |
st.stop()
|
58 |
|
59 |
+
with st.spinner("Loading model…"):
|
60 |
+
tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
|
61 |
+
model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
|
62 |
|
63 |
inputs = tokenizer([text])
|
64 |
[input_ids] = inputs["input_ids"]
|
|
|
81 |
).convert_to_tensors("pt")
|
82 |
|
83 |
logits = []
|
84 |
+
with st.spinner("Running model…"):
|
85 |
+
batch_size = 8
|
86 |
+
num_items = len(inputs_sliding["input_ids"])
|
87 |
+
pbar = st.progress(0)
|
88 |
+
for i in range(0, num_items, batch_size):
|
89 |
+
pbar.progress(i / num_items, f"{i}/{num_items}")
|
90 |
+
batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
|
91 |
+
logits.append(model(**batch).logits.to(torch.float16))
|
92 |
+
logits = torch.cat(logits, dim=0)
|
93 |
+
pbar.empty()
|
94 |
+
|
95 |
+
with st.spinner("Computing scores…"):
|
96 |
+
logits = logits.permute(1, 0, 2)
|
97 |
+
logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
|
98 |
+
logits = logits.view(-1, logits.shape[-1])[:-window_len]
|
99 |
+
logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
|
100 |
+
|
101 |
+
scores = logits.to(torch.float32).log_softmax(dim=-1)
|
102 |
+
scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
|
103 |
+
scores = scores.diff(dim=0).transpose(0, 1)
|
104 |
+
scores = scores.nan_to_num()
|
105 |
+
scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
|
106 |
+
scores = scores.to(torch.float16)
|
107 |
|
108 |
return scores
|
109 |
|