cifkao committed on
Commit
f962dd0
1 Parent(s): b6ab215

Better caching

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -62,13 +62,14 @@ if metric_name == "KL divergence":
62
  tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
63
  model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
64
 
 
 
 
 
65
  @st.cache_data(show_spinner=False)
66
  def run_context_length_probing(model_name, text, window_len):
67
  assert model.name_or_path == model_name
68
-
69
- inputs = tokenizer([text])
70
- [input_ids] = inputs["input_ids"]
71
- window_len = min(window_len, len(input_ids))
72
 
73
  inputs_sliding = get_windows_batched(
74
  inputs,
@@ -89,9 +90,9 @@ def run_context_length_probing(model_name, text, window_len):
89
  scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
90
  scores = scores.to(torch.float16)
91
 
92
- return input_ids, scores
93
 
94
- input_ids, scores = run_context_length_probing(
95
  model_name=model_name,
96
  text=text,
97
  window_len=window_len
 
62
  tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
63
  model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
64
 
65
+ inputs = tokenizer([text])
66
+ [input_ids] = inputs["input_ids"]
67
+ window_len = min(window_len, len(input_ids))
68
+
69
  @st.cache_data(show_spinner=False)
70
  def run_context_length_probing(model_name, text, window_len):
71
  assert model.name_or_path == model_name
72
+ del text # needed as a cache key but for the computation we access inputs directly
 
 
 
73
 
74
  inputs_sliding = get_windows_batched(
75
  inputs,
 
90
  scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
91
  scores = scores.to(torch.float16)
92
 
93
+ return scores
94
 
95
+ scores = run_context_length_probing(
96
  model_name=model_name,
97
  text=text,
98
  window_len=window_len