cifkao committed on
Commit
535e574
1 Parent(s): e5222c4

Better progress indication

Browse files
Files changed (1) hide show
  1. app.py +26 -25
app.py CHANGED
@@ -56,8 +56,9 @@ if metric_name == "KL divergence":
56
  st.error("KL divergence is not supported yet. Stay tuned!", icon="😭")
57
  st.stop()
58
 
59
- tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
60
- model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
 
61
 
62
  inputs = tokenizer([text])
63
  [input_ids] = inputs["input_ids"]
@@ -80,29 +81,29 @@ def run_context_length_probing(model_name, text, window_len):
80
  ).convert_to_tensors("pt")
81
 
82
  logits = []
83
- pbar = st.progress(0.)
84
- batch_size = 8
85
- num_items = len(inputs_sliding["input_ids"])
86
- for i in range(0, num_items, batch_size):
87
- pbar.progress(i / num_items * 0.9, f"Running model… ({i}/{num_items})")
88
- batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
89
- logits.append(model(**batch).logits.to(torch.float16))
90
- pbar.progress(0.9, "Computing scores…")
91
- logits = torch.cat(logits, dim=0)
92
-
93
- logits = logits.permute(1, 0, 2)
94
- logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
95
- logits = logits.view(-1, logits.shape[-1])[:-window_len]
96
- logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
97
-
98
- scores = logits.to(torch.float32).log_softmax(dim=-1)
99
- scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
100
- scores = scores.diff(dim=0).transpose(0, 1)
101
- scores = scores.nan_to_num()
102
- scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
103
- scores = scores.to(torch.float16)
104
-
105
- pbar.progress(1., "Done!")
106
 
107
  return scores
108
 
 
56
  st.error("KL divergence is not supported yet. Stay tuned!", icon="😭")
57
  st.stop()
58
 
59
+ with st.spinner("Loading model…"):
60
+ tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
61
+ model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
62
 
63
  inputs = tokenizer([text])
64
  [input_ids] = inputs["input_ids"]
 
81
  ).convert_to_tensors("pt")
82
 
83
  logits = []
84
+ with st.spinner("Running model…"):
85
+ batch_size = 8
86
+ num_items = len(inputs_sliding["input_ids"])
87
+ pbar = st.progress(0)
88
+ for i in range(0, num_items, batch_size):
89
+ pbar.progress(i / num_items, f"{i}/{num_items}")
90
+ batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
91
+ logits.append(model(**batch).logits.to(torch.float16))
92
+ logits = torch.cat(logits, dim=0)
93
+ pbar.empty()
94
+
95
+ with st.spinner("Computing scores…"):
96
+ logits = logits.permute(1, 0, 2)
97
+ logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
98
+ logits = logits.view(-1, logits.shape[-1])[:-window_len]
99
+ logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
100
+
101
+ scores = logits.to(torch.float32).log_softmax(dim=-1)
102
+ scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
103
+ scores = scores.diff(dim=0).transpose(0, 1)
104
+ scores = scores.nan_to_num()
105
+ scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
106
+ scores = scores.to(torch.float16)
107
 
108
  return scores
109