cifkao committed on
Commit
fe36eff
1 Parent(s): b253e66

Cache models, add gpt-neo-125m

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -44,12 +44,9 @@ def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
44
  st.header("Context length probing")
45
 
46
  with st.form("form"):
47
- model_name = st.selectbox("Model", ["distilgpt2", "gpt2"])
48
  metric_name = st.selectbox("Metric", ["Cross entropy"])
49
 
50
- tokenizer = AutoTokenizer.from_pretrained(model_name)
51
- model = AutoModelForCausalLM.from_pretrained(model_name)
52
-
53
  window_len = st.select_slider("Window size", options=[8, 16, 32, 64, 128, 256, 512, 1024], value=512)
54
  text = st.text_area(
55
  "Input text",
@@ -58,6 +55,9 @@ with st.form("form"):
58
 
59
  st.form_submit_button("Submit")
60
 
 
 
 
61
  inputs = tokenizer([text])
62
  [input_ids] = inputs["input_ids"]
63
  window_len = min(window_len, len(input_ids))
 
44
  st.header("Context length probing")
45
 
46
  with st.form("form"):
47
+ model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
48
  metric_name = st.selectbox("Metric", ["Cross entropy"])
49
 
 
 
 
50
  window_len = st.select_slider("Window size", options=[8, 16, 32, 64, 128, 256, 512, 1024], value=512)
51
  text = st.text_area(
52
  "Input text",
 
55
 
56
  st.form_submit_button("Submit")
57
 
58
+ tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
59
+ model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
60
+
61
  inputs = tokenizer([text])
62
  [input_ids] = inputs["input_ids"]
63
  window_len = min(window_len, len(input_ids))