# codebook-features / Code_Browser.py
# Last commit 63b5bc1 by taufeeque: "Update code" (file size 14.4 kB).
"""Web App for the Codebook Features project."""
import argparse
import glob
import html
import os

import streamlit as st

import code_search_utils
import utils
import webapp_utils
# --- Parse command line arguments ---
def _str_to_bool(value):
    """Parse a command-line boolean such as "true", "False", "1" or "no".

    argparse hands option values over as plain strings and every non-empty
    string is truthy, so without this converter ``--deploy False`` would
    still enable deploy mode.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"invalid boolean value: {value!r} (expected true/false)"
    )


parser = argparse.ArgumentParser()
parser.add_argument(
    "--deploy",
    type=_str_to_bool,
    default=True,
    help="Deploy mode (true/false); hides developer-only sections. Default: true.",
)
parser.add_argument(
    "--cache_dir",
    type=str,
    default="cache/",
    help="Path to directory containing cache for codebook models.",
)
try:
    args = parser.parse_args()
except SystemExit as e:
    # This exception will be raised if --help or invalid command line arguments
    # are used. Currently streamlit prevents the program from exiting normally
    # so we have to do a hard exit with argparse's exit status.
    os._exit(e.code if isinstance(e.code, int) else 1)
deploy = args.deploy
# Streamlit documents set_page_config as the first Streamlit command of a
# page, so configure the page before restoring persisted widget state.
st.set_page_config(
    page_title="Codebook Features",
    page_icon="📚",
)
# Restore widget values saved through webapp_utils.persist (used below for
# the model selector) before any widget is created.
webapp_utils.load_widget_state()
st.title("Codebook Features")
# --- Load model info and cache ---
# Maps model names (as derived from cache directory names) to the labels
# shown in the model selector; unmapped names are shown as-is.
pretty_model_names = {
    "TinyStories-1Layer-21M#100ksteps_vcb_mlp": "TinyStories-1L-21M-MLP",
    "TinyStories-1Layer-21M_ccb_attn_preproj": "TinyStories 1 Layer Attention Codebook",
    "TinyStories-33M_ccb_attn_preproj": "TinyStories 4 Layer Attention Codebook",
    "TinyStories-1Layer-21M_vcb_mlp": "TinyStories 1 Layer MLP Codebook",
}
orig_model_name = {v: k for k, v in pretty_model_names.items()}
# Normalise to a trailing separator so "--cache_dir cache" behaves like
# "--cache_dir cache/" for every path built from it.
base_cache_dir = os.path.join(args.cache_dir, "")
# The trailing "" makes glob return directories only.
dirs = glob.glob(os.path.join(glob.escape(base_cache_dir), "models", "*", ""))

# A model's name is its cache directory name with the last two
# underscore-separated fields removed.
model_name_options = set()
for model_dir in dirs:
    dir_name = os.path.basename(os.path.normpath(model_dir))
    derived_name = "_".join(dir_name.split("_")[:-2])
    if derived_name:
        model_name_options.add(derived_name)
model_name_options = sorted(model_name_options)
if not model_name_options:
    st.error(f"No cached models found in {os.path.join(base_cache_dir, 'models')}.")
    st.stop()

# Default to the first attention-codebook model, else the first model.
def_model_idx = next(
    (i for i, m in enumerate(model_name_options) if "attn" in m.lower()),
    0,
)
# The selector shows friendly labels; the choice is mapped back to the
# underlying cache model name afterwards.
model_labels = [pretty_model_names.get(name, name) for name in model_name_options]
p_model_name = st.selectbox(
    "Model",
    options=model_labels,
    index=def_model_idx,
    key=webapp_utils.persist("model_name"),
)
model_name = orig_model_name.get(p_model_name, p_model_name)
# Passed on to webapp_utils.get_code_acts further down.
is_fsm = "FSM" in p_model_name
# Several cache directories may exist for one model; pick the most recently
# modified one. Names are glob-escaped, and matches are restricted to
# directories whose derived model name (directory name minus its last two
# "_"-separated fields) equals model_name exactly. A bare "{model_name}_*"
# glob would also match models whose names merely start with model_name.
model_dir_pattern = os.path.join(
    glob.escape(base_cache_dir), "models", glob.escape(model_name) + "_*"
)
dirs = [
    d
    for d in glob.glob(model_dir_pattern)
    if os.path.isdir(d)
    and "_".join(os.path.basename(os.path.normpath(d)).split("_")[:-2]) == model_name
]
if not dirs:
    st.error(f"No cache directory found for model {model_name!r}.")
    st.stop()
dirs.sort(key=os.path.getmtime)
# Keep a trailing separator: later code appends file names to this path.
codes_cache_path = os.path.join(dirs[-1], "")
model_info = utils.ModelInfoForWebapp.load(codes_cache_path)
num_codes = model_info.num_codes
num_layers = model_info.n_layers
num_heads = model_info.n_heads
cb_at = model_info.cb_at
# Stored as a "_gcb" suffix (or ""); passed to the code search helpers.
gcb = "_gcb" if model_info.gcb else ""
is_attn = "attn" in cb_at
dataset_cache_path = os.path.join(
    base_cache_dir, "datasets", model_info.dataset_name, ""
)
# Unpack the cached search data for the selected model and its dataset.
tokens_str, tokens_text, token_byte_pos, cb_acts, act_count_ft_tkns, metrics = (
    webapp_utils.load_code_search_cache(codes_cache_path, dataset_cache_path)
)
# Sequence length is read off the first tokenized example.
seq_len = len(tokens_str[0])
# Keep only these metrics; a metric key matches when the part before its
# first "/" is one of them.
metric_keys = ["eval_loss", "eval_accuracy", "eval_dead_code_fraction"]
metrics = {
    metric: value
    for metric, value in metrics.items()
    if metric.partition("/")[0] in metric_keys
}
# --- Set the session states ---
# Publish the loaded model data through st.session_state.
for _state_key, _state_value in (
    ("model_name_id", model_name),
    ("cb_acts", cb_acts),
    ("tokens_text", tokens_text),
    ("tokens_str", tokens_str),
    ("act_count_ft_tkns", act_count_ft_tkns),
    ("num_codes", num_codes),
    ("gcb", gcb),
    ("cb_at", cb_at),
    ("is_attn", is_attn),
    ("seq_len", seq_len),
):
    st.session_state[_state_key] = _state_value
# Evaluation metrics are a developer aid: never shown in deploy mode, and
# collapsed behind a checkbox otherwise.
show_metrics_section = not deploy
if show_metrics_section:
    st.markdown("## Metrics")
    metrics_visible = st.checkbox("Show Model Metrics")
    if metrics_visible:
        st.write(metrics)
st.markdown("## Demo Codes")
demo_codes_desc = (
    "This section contains codes that we've found to be interpretable along "
    "with a description of the feature we think they are capturing. "
    "Click on the 🔍 search button for a code to see the tokens that code activates on."
)
st.write(demo_codes_desc)
demo_file_path = os.path.join(codes_cache_path, "demo_codes.txt")
# Demo file format, one entry per non-blank line:
#   "## ..."              - hides the code lines that follow it,
#   "# <desc>: <regex>"   - sets the description/regex; the next code line is shown,
#   anything else         - a code spec parsed by utils.CodeInfo.from_str.
if st.checkbox("Show Demo Codes"):
    try:
        with open(demo_file_path, "r") as f:
            demo_codes = f.readlines()
    except FileNotFoundError:
        demo_codes = []
    code_desc, code_regex = "", ""
    demo_codes = [line.strip() for line in demo_codes if line.strip()]
    num_cols = 6 if is_attn else 5
    # Equal-width columns except a double-width "Description" column.
    col_widths = [1] * (num_cols - 1) + [2]
    cols = st.columns(col_widths)
    cols[0].markdown("Search", help="Button to see token activations for the code.")
    cols[1].write("Code")
    cols[2].write("Layer")
    if is_attn:
        cols[3].write("Head")
    cols[-2].markdown(
        "Num Acts",
        help="Number of tokens that the code activates on in the acts dataset.",
    )
    cols[-1].markdown("Description", help="Interpreted description of the code.")
    if not demo_codes:
        # The path is escaped because it is embedded in raw HTML.
        st.markdown(
            "\n"
            '<div style="font-size: 1.0rem; color: red;">\n'
            f"No demo codes found in file {html.escape(demo_file_path)}\n"
            "</div>\n",
            unsafe_allow_html=True,
        )
    skip = True
    for code_txt in demo_codes:
        if code_txt.startswith("##"):
            skip = True
            continue
        if code_txt.startswith("#"):
            # Split on the first ":" only, so regexes containing ":" (e.g.
            # "(?:...)") are kept intact and a missing ":" does not crash.
            code_desc, _, code_regex = code_txt[1:].partition(":")
            code_desc, code_regex = code_desc.strip(), code_regex.strip()
            skip = False
            continue
        if skip:
            continue
        code_info = utils.CodeInfo.from_str(code_txt, regex=code_regex)
        # Component key in the "layer<L>" / "layer<L>_head<H>" form used for
        # act_count_ft_tkns elsewhere in this file (no trailing "_" when
        # there is no head).
        comp_info = f"layer{code_info.layer}" + (
            f"_head{code_info.head}" if code_info.head is not None else ""
        )
        button_key = (
            f"demo_search_code{code_info.code}_layer{code_info.layer}_desc-{code_info.description}"
            + (f"head{code_info.head}" if code_info.head is not None else "")
        )
        cols = st.columns(col_widths)
        if cols[0].button("🔍", key=button_key):
            webapp_utils.set_ct_acts(
                code_info.code, code_info.layer, code_info.head, None, is_attn
            )
        cols[1].write(code_info.code)
        cols[2].write(str(code_info.layer))
        if is_attn:
            cols[3].write(str(code_info.head))
        cols[-2].write(str(act_count_ft_tkns[comp_info][code_info.code]))
        cols[-1].write(code_desc)
        # Only the first code line after each "#" description is shown.
        skip = True
# --- Code Search ---
st.markdown("## Code Search")
code_search_desc = (
    "If you want to find whether the codebooks model has captured a relevant feature from the data,"
    " you can specify a regex pattern for your feature and find whether any code activating on the regex pattern"
    " exists. The first group in the regex pattern is the token that the code activates on. If the group contains"
    " multiple tokens, we search for codes that will activate on the first token in the group followed by the"
    " subsequent tokens in the group. For example, the search term 'New (York)' will try to find codes that"
    " activate on the bigram feature 'New York' at the York token."
)
if st.checkbox("Search with Regex"):
    st.write(code_search_desc)
    regex_pattern = st.text_input(
        "Enter a regex pattern",
        help="Wrap code token in the first group. E.g. New (York)",
        key="regex_pattern",
    )
    prec_col, sort_col = st.columns(2)
    prec_threshold = prec_col.slider(
        "Precision Threshold",
        0.0,
        1.0,
        0.9,
        help="Shows codes with precision on the regex pattern above the threshold.",
    )
    sort_by_options = ["Precision", "Recall", "Num Acts"]
    sort_by_name = sort_col.radio(
        "Sort By",
        sort_by_options,
        index=0,
        horizontal=True,
        help="Sorts the codes by the selected metric.",
    )
    sort_by = sort_by_options.index(sort_by_name)

    @st.cache_data(ttl=3600)
    def get_codebook_wise_codes_for_regex(
        regex_pattern, prec_threshold, gcb, model_name
    ):
        """Get codebook wise codes for a given regex pattern.

        The search data (tokens_text, cb_acts, ...) is read from module
        globals, not passed as arguments. ``model_name`` is therefore unused
        in the body: it is an argument only so that st.cache_data keys cached
        results per model. At most 8 codes per codebook are requested
        (``topk=8``).

        Raises:
            ValueError: if ``model_name`` is None (the cache key would be
                ambiguous across models).
        """
        if model_name is None:
            raise ValueError("model_name is required to key the search cache")
        return code_search_utils.get_codes_from_pattern(
            regex_pattern,
            tokens_text,
            token_byte_pos,
            cb_acts,
            act_count_ft_tkns,
            gcb=gcb,
            topk=8,
            prec_threshold=prec_threshold,
        )

    if regex_pattern:
        codebook_wise_codes, re_token_matches = get_codebook_wise_codes_for_regex(
            regex_pattern,
            prec_threshold,
            gcb,
            model_name,
        )
        st.markdown(
            f"Found <span style='color:green;'>{re_token_matches}</span> matches",
            unsafe_allow_html=True,
        )
        num_search_cols = 7 if is_attn else 6
        # Outside deploy mode an extra trailing "Save to Demos" column is
        # added, shifting the metric columns one place to the left.
        non_deploy_offset = 0 if deploy else 1
        num_search_cols += non_deploy_offset
        cols = st.columns(num_search_cols)
        cols[0].markdown("Search", help="Button to see token activations for the code.")
        cols[1].write("Layer")
        if is_attn:
            cols[2].write("Head")
        cols[-4 - non_deploy_offset].write("Code")
        cols[-3 - non_deploy_offset].write("Precision")
        cols[-2 - non_deploy_offset].write("Recall")
        cols[-1 - non_deploy_offset].markdown(
            "Num Acts",
            help="Number of tokens that the code activates on in the acts dataset.",
        )
        if not deploy:
            cols[-1].markdown(
                "Save to Demos",
                help="Button to save the code to demos along with the regex pattern.",
            )
        # Flatten {codebook name: [(code, prec, recall, num_acts), ...]} into
        # one list of (codebook name, info) pairs.
        all_codes = [
            (cb_name, code_pr_info)
            for cb_name, code_pr_infos in codebook_wise_codes.items()
            for code_pr_info in code_pr_infos
        ]
        # sort_by indexes the metrics, which start at tuple position 1.
        all_codes.sort(key=lambda x: x[1][1 + sort_by], reverse=True)
        for cb_name, (code, prec, rec, code_acts) in all_codes:
            # Codebook names are "layer<L>" or "layer<L>_head<H>".
            layer_head = cb_name.split("_")
            layer = layer_head[0][len("layer"):]
            head = layer_head[1][len("head"):] if len(layer_head) > 1 else None
            button_key = f"search_code{code}_layer{layer}" + (
                f"head{head}" if head is not None else ""
            )
            cols = st.columns(num_search_cols)
            extra_args = {
                "prec": prec,
                "recall": rec,
                "num_acts": code_acts,
                "regex": regex_pattern,
            }
            if cols[0].button("🔍", key=button_key):
                webapp_utils.set_ct_acts(code, layer, head, extra_args, is_attn)
            cols[1].write(layer)
            if is_attn:
                cols[2].write(head)
            cols[-4 - non_deploy_offset].write(code)
            cols[-3 - non_deploy_offset].write(f"{prec*100:.2f}%")
            cols[-2 - non_deploy_offset].write(f"{rec*100:.2f}%")
            cols[-1 - non_deploy_offset].write(str(code_acts))
            if not deploy:
                webapp_utils.add_save_code_button(
                    demo_file_path,
                    num_acts=code_acts,
                    save_regex=True,
                    prec=prec,
                    recall=rec,
                    button_st_container=cols[-1],
                    button_key_suffix=f"_code{code}_layer{layer}_head{head}",
                )
        if not all_codes:
            # The user's regex routinely contains "<", ">" or "&"; escape it
            # before embedding it in raw HTML.
            st.markdown(
                "\n"
                '<div style="font-size: 1.0rem; color: red;">\n'
                f"No codes found for pattern {html.escape(regex_pattern)}"
                f" at precision threshold: {prec_threshold}\n"
                "</div>\n",
                unsafe_allow_html=True,
            )
# --- Display Code Token Activations ---
st.markdown("## Code Token Activations")
filter_codes = st.checkbox("Show filters", key="filter_codes", value=True)
# Filter state; both stay None whenever filtering is switched off.
act_range = None
layer_code_acts = None
if filter_codes:
    act_range = st.slider(
        "Minimum number of activations",
        min_value=0,
        max_value=10_000,
        value=100,
        key="ct_act_range",
        help="Filter codes by the number of tokens they activate on.",
    )

# Attention codebooks get an additional "Head" column.
cols = st.columns(5 if is_attn else 4)
layer = cols[0].number_input(
    "Layer", min_value=0, max_value=num_layers - 1, value=0, key="ct_act_layer"
)
head = (
    cols[1].number_input(
        "Head", min_value=0, max_value=num_heads - 1, value=0, key="ct_act_head"
    )
    if is_attn
    else None
)

def_code = st.session_state.get("ct_act_code", 0)
if filter_codes:
    component_key = f"layer{layer}" if head is None else f"layer{layer}_head{head}"
    layer_code_acts = act_count_ft_tkns[component_key]
    # NOTE(review): presumably advances def_code to the next code meeting the
    # act_range filter - see webapp_utils.find_next_code.
    def_code = webapp_utils.find_next_code(def_code, layer_code_acts, act_range)
    # Overwrite the stored value of the "Code" widget (key "ct_act_code") so
    # it displays def_code on this run.
    if "ct_act_code" in st.session_state:
        st.session_state["ct_act_code"] = def_code
code = cols[-3].number_input(
    "Code",
    min_value=0,
    max_value=num_codes - 1,
    value=def_code,
    key="ct_act_code",
)
num_examples = cols[-2].number_input(
    "Max Results",
    min_value=-1,
    max_value=1000,  # capped at 1000 for efficiency, although more examples may exist
    value=100,
    help="Number of examples to show in the results. Set to -1 to show all examples.",
)
ctx_size = cols[-1].number_input(
    "Context Size",
    min_value=1,
    max_value=10,
    value=5,
    help="Number of tokens to show before and after the code token.",
)

acts, acts_count = webapp_utils.get_code_acts(
    model_name,
    tokens_str,
    code,
    layer,
    head,
    ctx_size,
    num_examples,
    is_fsm=is_fsm,
)
head_label = "" if head is None else f" Head {head}"
st.write(
    f"Token Activations for Layer {layer}{head_label} Code {code} | "
    f"Activates on {acts_count[0]} tokens on the acts dataset",
)
if not deploy:
    webapp_utils.add_save_code_button(
        demo_file_path,
        acts_count[0],
        save_regex=False,
        button_text=True,
        button_key_suffix="_token_acts",
    )
st.markdown(webapp_utils.escape_markdown(acts), unsafe_allow_html=True)