|
import numpy as np |
|
import pandas as pd |
|
import streamlit as st |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
def wide_setup():
    """Widen the page layout and hide table row indices.

    Emits two raw ``<style>`` blocks through ``st.markdown`` with
    ``unsafe_allow_html=True``:

    * one that sets the maximum width and the padding of the main block
      container;
    * one that hides ``tbody th`` cells and elements with class ``blank``.
    """
    # Layout constants: width in pixels, paddings in rem.
    width_px = 1500
    pad_top, pad_right, pad_bottom, pad_left = 0, 2, 0, 2

    margins_css = "\n".join([
        "",
        "<style>",
        ".appview-container .main .block-container{",
        f"max-width: {width_px}px;",
        f"padding-top: {pad_top}rem;",
        f"padding-right: {pad_right}rem;",
        f"padding-left: {pad_left}rem;",
        f"padding-bottom: {pad_bottom}rem;",
        "}",
        "</style>",
        "",
    ])

    hide_index_css = "\n".join([
        "",
        "<style>",
        "tbody th {display:none}",
        ".blank {display:none}",
        "</style>",
        "",
    ])

    # Margins first, then the row-index rule, in separate markdown elements.
    for css in (margins_css, hide_index_css):
        st.markdown(css, unsafe_allow_html=True)
|
|
|
def load_css(file_name):
    """Inject the contents of a CSS file into the page as a ``<style>`` block.

    Args:
        file_name: Path of the stylesheet to read.

    Raises:
        OSError: if the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    # Decode explicitly as UTF-8: the platform default encoding (e.g. cp1252
    # on Windows) can mis-decode or reject non-ASCII characters in the CSS.
    with open(file_name, encoding='utf-8') as css_file:
        css = css_file.read()
    st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)
|
|
|
@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model(model_name):
    """Load the tokenizer, masked-LM model and skeleton class for a checkpoint.

    The model family is chosen from the prefix of ``model_name``. The
    ``transformers`` classes and the project's ``skeleton_modeling_*`` module
    are imported inside each branch, so only the selected family is imported.
    The function is wrapped in ``st.cache``.

    Args:
        model_name: Checkpoint name passed to ``from_pretrained``; must start
            with ``'albert'``, ``'bert'`` or ``'roberta'``.

    Returns:
        ``(tokenizer, model, skeleton_model)``, where ``skeleton_model`` is the
        skeleton *class* (not an instance) matching the model family.

    Raises:
        ValueError: if ``model_name`` does not start with a supported prefix.
    """
    if model_name.startswith('albert'):
        from transformers import AlbertTokenizer, AlbertForMaskedLM
        from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
        tokenizer = AlbertTokenizer.from_pretrained(model_name)
        model = AlbertForMaskedLM.from_pretrained(model_name)
        skeleton_model = SkeletonAlbertForMaskedLM
    elif model_name.startswith('bert'):
        from transformers import BertTokenizer, BertForMaskedLM
        from skeleton_modeling_bert import SkeletonBertForMaskedLM
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertForMaskedLM.from_pretrained(model_name)
        skeleton_model = SkeletonBertForMaskedLM
    elif model_name.startswith('roberta'):
        from transformers import RobertaTokenizer, RobertaForMaskedLM
        from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM
        tokenizer = RobertaTokenizer.from_pretrained(model_name)
        model = RobertaForMaskedLM.from_pretrained(model_name)
        skeleton_model = SkeletonRobertaForMaskedLM
    else:
        # Reject unknown names explicitly rather than reaching the return
        # with unbound locals (UnboundLocalError).
        raise ValueError(
            f'Unsupported model name {model_name!r}: expected a name starting '
            "with 'albert', 'bert' or 'roberta'.")
    return tokenizer, model, skeleton_model
|
|
|
def clear_data():
    """Remove every entry from ``st.session_state``.

    The keys are copied into a list first, so no key is deleted while the
    mapping itself is being iterated.
    """
    stale_keys = list(st.session_state)
    for stale_key in stale_keys:
        del st.session_state[stale_key]
|
|
|
def annotate_mask(sent_id, sent):
    """Let the user toggle which words of a sentence are to be masked.

    Renders a heading and one button per decoded token; the first and last
    token ids returned by the tokenizer are skipped. Clicking a button adds
    that word index to ``st.session_state['mask_locs_<sent_id>']``, or removes
    it if it is already present. The decoded words are stored under
    ``'decoded_sent_<sent_id>'``. Finally the sentence is re-rendered with the
    selected words highlighted.

    Uses the module-level ``tokenizer`` created by the app's main flow.

    Args:
        sent_id: Sentence number used in the heading and in session-state keys.
        sent: Raw sentence text to tokenize.
    """
    show_instruction(f'Sentence {sent_id}', fontsize=16)

    token_ids = tokenizer(sent).input_ids
    # Drop the first and last ids (presumably the tokenizer's special tokens).
    words = [tokenizer.decode([tid]) for tid in token_ids[1:-1]]
    st.session_state[f'decoded_sent_{sent_id}'] = words

    # Column widths proportional to word length so each button roughly fits.
    columns = st.columns([len(w) + 2 for w in words])

    locs_key = f'mask_locs_{sent_id}'
    if locs_key not in st.session_state:
        st.session_state[locs_key] = []

    for position, (column, word) in enumerate(zip(columns, words)):
        with column:
            if st.button(word, key=f'word_mask_{sent_id}_{position}'):
                selected = st.session_state[locs_key]
                # Toggle membership of this word index.
                if position in selected:
                    selected.remove(position)
                else:
                    selected.append(position)

    show_annotated_sentence(words, mask_locs=st.session_state[locs_key])
|
|
|
def annotate_options(sent_id, sent):
    """Let the user toggle which words of a sentence are candidate options.

    Renders a heading and one button per decoded token; the first and last
    token ids returned by the tokenizer are skipped. Clicking a button adds
    that word index to ``st.session_state['option_locs_<sent_id>']``, or
    removes it if already present. The sentence is then re-rendered with
    options and previously chosen mask positions highlighted. Finally both the
    option and mask index lists are stored back in ascending order.

    Uses the module-level ``tokenizer`` created by the app's main flow, and
    expects ``st.session_state['mask_locs_<sent_id>']`` to exist already.

    Args:
        sent_id: Sentence number used in the heading and in session-state keys.
        sent: Raw sentence text to tokenize.
    """
    show_instruction(f'Sentence {sent_id}', fontsize=16)

    token_ids = tokenizer(sent).input_ids
    # Drop the first and last ids (presumably the tokenizer's special tokens).
    words = [tokenizer.decode([tid]) for tid in token_ids[1:-1]]

    # Column widths proportional to word length so each button roughly fits.
    columns = st.columns([len(w) + 2 for w in words])

    option_key = f'option_locs_{sent_id}'
    mask_key = f'mask_locs_{sent_id}'
    if option_key not in st.session_state:
        st.session_state[option_key] = []

    for position, (column, word) in enumerate(zip(columns, words)):
        with column:
            if st.button(word, key=f'word_option_{sent_id}_{position}'):
                selected = st.session_state[option_key]
                # Toggle membership of this word index.
                if position in selected:
                    selected.remove(position)
                else:
                    selected.append(position)

    show_annotated_sentence(words,
                            option_locs=st.session_state[option_key],
                            mask_locs=st.session_state[mask_key])

    # Keep both index lists sorted for downstream span handling.
    for locs_key in (option_key, mask_key):
        st.session_state[locs_key] = list(np.sort(st.session_state[locs_key]))
|
|
|
def show_annotated_sentence(sent, option_locs=None, mask_locs=None):
    """Render a tokenized sentence in bold, highlighting selected positions.

    Args:
        sent: Sequence of decoded token strings, joined with single spaces.
        option_locs: Indices into ``sent`` to color blue. ``None`` means none.
        mask_locs: Indices into ``sent`` to color red. If an index is in both
            lists, red wins. ``None`` means none.

    Returns:
        The value returned by ``st.markdown``.
    """
    import html

    disp_style = '"font-family:sans-serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    suffix = '</span></p>'

    # None defaults avoid shared mutable default lists; sets make the
    # per-token membership test O(1).
    option_set = set() if option_locs is None else set(option_locs)
    mask_set = set() if mask_locs is None else set(mask_locs)

    style_list = []
    for i, word in enumerate(sent):
        # Tokens come from user-typed text and are inserted into raw HTML;
        # escape '<', '>' and '&' so they display literally.
        safe_word = html.escape(word, quote=False)
        if i in mask_set:
            style_list.append(f'<span style="color:Red">{safe_word}</span>')
        elif i in option_set:
            style_list.append(f'<span style="color:Blue">{safe_word}</span>')
        else:
            style_list.append(safe_word)
    disp = ' '.join(style_list)
    return st.markdown(prefix + disp + suffix, unsafe_allow_html=True)
|
|
|
def show_instruction(sent, fontsize=20):
    """Render ``sent`` as a bold paragraph at the given font size.

    ``sent`` is inserted into the HTML unescaped (``unsafe_allow_html=True``),
    so callers may include markup; do not pass untrusted text.

    Args:
        sent: Text (or HTML fragment) to display.
        fontsize: Font size in pixels.

    Returns:
        The value returned by ``st.markdown``.
    """
    # 'sans-serif' is the CSS generic family; 'san serif' would be treated as
    # an unknown family name and the browser would fall back to its default.
    disp_style = f'"font-family:sans-serif; color:Black; font-size: {fontsize}px"'
    prefix = f'<p style={disp_style}><span style="font-weight:bold">'
    suffix = '</span></p>'
    return st.markdown(prefix + sent + suffix, unsafe_allow_html=True)
|
|
|
def create_interventions(token_id, interv_types, num_heads, multihead=False, heads=None):
    """Build the intervention spec for one layer at one token position.

    For each representation type ``'lay'``, ``'qry'``, ``'key'`` and ``'val'``:

    * If the type is not in ``interv_types``, the result is an empty list.
    * If ``multihead`` is true, the result holds one tuple
      ``(head_id, token_id, [0, 1])`` for every head in ``range(num_heads)``.
    * Otherwise, the result holds one tuple ``(head_id, token_id, [i, i + k])``
      per entry of ``heads``, where ``i`` is the entry's position in ``heads``
      and ``k = len(heads)``.

    The index pairs presumably name the two batch rows whose representations
    are exchanged. With batch size ``k``, rows ``i`` and ``i + k`` are the
    i-th copies of sentence 1 and sentence 2 in ``run_intervention``'s batch.
    Confirm against the skeleton model.

    Args:
        token_id: Position in the token sequence to intervene on.
        interv_types: Collection of type codes to include.
        num_heads: Number of attention heads (used when ``multihead``).
        multihead: Intervene on all heads jointly instead of per selected head.
        heads: 0-based head indices for per-head mode. ``None`` means none.

    Returns:
        Dict with keys ``'lay'``, ``'qry'``, ``'key'``, ``'val'`` (in that order).
    """
    heads = [] if heads is None else list(heads)
    num_selected = len(heads)
    interventions = {}
    for rep in ('lay', 'qry', 'key', 'val'):
        if rep not in interv_types:
            interventions[rep] = []
        elif multihead:
            interventions[rep] = [(head_id, token_id, [0, 1]) for head_id in range(num_heads)]
        else:
            interventions[rep] = [(head_id, token_id, [i, i + num_selected])
                                  for i, head_id in enumerate(heads)]
    return interventions
|
|
|
def separate_options(option_locs):
    """Split sorted option word indices into two contiguous spans.

    The indices must form exactly two runs of consecutive integers separated
    by a gap, e.g. ``[0, 1, 5, 6]`` -> ``([0, 1], [5, 6])``.

    Args:
        option_locs: Ascending sequence of word indices.

    Returns:
        ``(option_1_locs, option_2_locs)``: slices of ``option_locs`` before
        and after the single gap.

    Raises:
        ValueError: if the indices do not form exactly two contiguous runs.
            The old code used ``assert``, which gives no message and is
            stripped under ``python -O``.
    """
    gaps = np.diff(option_locs) > 1
    if np.count_nonzero(gaps) != 1:
        raise ValueError(
            'Options must form exactly two separate groups of consecutive '
            f'tokens; got positions {[int(x) for x in option_locs]}.')
    # Position right after the single gap starts the second option.
    sep = int(np.flatnonzero(gaps)[0]) + 1
    option_1_locs, option_2_locs = option_locs[:sep], option_locs[sep:]
    for group in (option_1_locs, option_2_locs):
        if len(group) > 1 and not np.all(np.diff(group) == 1):
            raise ValueError(
                'Each option must be a run of consecutive tokens; got '
                f'positions {[int(x) for x in group]}.')
    return option_1_locs, option_2_locs
|
|
|
def mask_out(input_ids, pron_locs, option_locs, mask_id):
    """Replace a contiguous span of tokens with one mask per option token.

    ``pron_locs`` are word indices that exclude the first token id, so word
    index ``k`` is ``input_ids[k + 1]``. The tokens for
    ``pron_locs[0] .. pron_locs[-1]`` are removed and replaced by
    ``len(option_locs)`` copies of ``mask_id``.

    Args:
        input_ids: Token id sequence of the full sentence.
        pron_locs: Non-empty, consecutive word indices of the span to mask.
        option_locs: Word indices of the option whose length sets the number
            of mask tokens.
        mask_id: Token id inserted for each masked position.

    Returns:
        A new list of token ids; ``input_ids`` is not modified.

    Raises:
        ValueError: if ``pron_locs`` is empty or not consecutive.
    """
    if len(pron_locs) == 0:
        raise ValueError('At least one position must be selected for masking.')
    if len(pron_locs) > 1 and not np.all(np.diff(pron_locs) == 1):
        raise ValueError(
            'Masked positions must be consecutive; got '
            f'{[int(x) for x in pron_locs]}.')
    start = pron_locs[0] + 1   # first token position to drop
    stop = pron_locs[-1] + 2   # one past the last token position to drop
    return list(input_ids[:start]) + [mask_id] * len(option_locs) + list(input_ids[stop:])
|
|
|
|
|
def run_intervention(interventions,batch_size,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
    """Score both options in both sentences under the given interventions.

    For each option, the batch fed to ``skeleton_model`` has ``2 * batch_size``
    rows: rows ``[0, batch_size)`` are copies of the sentence-1 ids and rows
    ``[batch_size, 2 * batch_size)`` are copies of the sentence-2 ids. The
    score of an option is ``exp(mean log p(token))`` over its tokens, i.e. the
    geometric mean of per-token probabilities. Those probabilities are read at
    the consecutive positions starting at ``pron_locs[sent][0] + 1``.

    Args:
        interventions: Passed unchanged to ``skeleton_model``.
        batch_size: Number of copies of each sentence in the batch.
        skeleton_model: Callable ``(model, input_ids, interventions=...)``
            returning a mapping with ``'logits'``.
        model: Underlying model passed to ``skeleton_model``.
        masked_ids_option_1, masked_ids_option_2: Dicts with keys ``'sent_1'``
            and ``'sent_2'`` holding the masked token-id lists for each option.
        option_1_tokens, option_2_tokens: Token ids of each option.
        pron_locs: Dict with keys ``'sent_1'``/``'sent_2'`` of masked word indices.

    Returns:
        ``np.ndarray`` of shape ``(2, 2, batch_size)`` indexed
        ``[option, sentence, batch_row]``.
    """
    probs = []
    for masked_ids, option_tokens in ((masked_ids_option_1, option_1_tokens),
                                      (masked_ids_option_2, option_2_tokens)):
        input_ids = torch.tensor([masked_ids['sent_1']] * batch_size
                                 + [masked_ids['sent_2']] * batch_size)
        # Inference only: skip building an autograd graph. This also makes
        # .numpy() legal, since outputs no longer require grad.
        with torch.no_grad():
            outputs = skeleton_model(model, input_ids, interventions=interventions)
            logprobs = F.log_softmax(outputs['logits'], dim=-1)
        logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
        # First masked position in input_ids (word index + 1).
        start_1 = pron_locs['sent_1'][0] + 1
        start_2 = pron_locs['sent_2'][0] + 1
        evals_1 = [logprobs_1[:, start_1 + i, token].cpu().numpy() for i, token in enumerate(option_tokens)]
        evals_2 = [logprobs_2[:, start_2 + i, token].cpu().numpy() for i, token in enumerate(option_tokens)]
        probs.append([np.exp(np.mean(evals_1, axis=0)), np.exp(np.mean(evals_2, axis=0))])
    probs = np.array(probs)
    assert probs.shape[0] == 2 and probs.shape[1] == 2 and probs.shape[2] == batch_size
    return probs
|
|
|
def show_results(effect_array,masked_sent,token_id_list,num_layers):
    """Show each word in its own column, with an effect strip under intervened words.

    One column is created per entry of ``masked_sent``, excluding its first
    and last entries. The column shows the decoded token. When the column
    index is in ``token_id_list``, a single-column heat strip of
    ``effect_array[:, j]`` is drawn below it, where
    ``j = token_id_list.index(col)``. All strips share one color scale
    (global min/max of ``effect_array``).

    Uses the module-level ``tokenizer`` created by the app's main flow.

    Args:
        effect_array: 2-D array whose column ``j`` belongs to
            ``token_id_list[j]``. Rows are presumably layers (the strip's box
            aspect is set to ``num_layers``); confirm with the caller.
        masked_sent: Token ids of the masked sentence.
        token_id_list: Word indices that were intervened on.
        num_layers: Number of layers; used as the strip's box aspect ratio.
    """
    # Loop-invariant plotting settings: shared colormap and color scale.
    cmap = sns.color_palette("light:r", as_cmap=True)
    has_effects = effect_array.size > 0
    vmin = effect_array.min() if has_effects else None
    vmax = effect_array.max() if has_effects else None

    cols = st.columns(len(masked_sent) - 2)
    for col_id, col in enumerate(cols):
        with col:
            st.write(tokenizer.decode([masked_sent[col_id + 1]]))
            if col_id not in token_id_list:
                continue
            interv_id = token_id_list.index(col_id)
            fig, ax = plt.subplots()
            ax.set_box_aspect(num_layers)
            ax.imshow(effect_array[:, interv_id:interv_id + 1], cmap=cmap,
                      vmin=vmin, vmax=vmax)
            ax.set_xticks([])
            ax.set_yticks([])
            for side in ('top', 'bottom', 'right', 'left'):
                ax.spines[side].set_visible(False)
            st.pyplot(fig)
            # pyplot keeps every figure alive until it is closed; without this,
            # figures accumulate on each Streamlit rerun.
            plt.close(fig)
|
|
|
if __name__=='__main__':
    # Page-level styling: widen the layout, then inject the app stylesheet.
    wide_setup()
    load_css('style.css')

    # The app is a multi-step wizard driven by st.session_state['page_status'].
    # Status moves 'model_selection' -> 'type_in' -> 'annotate_mask' ->
    # 'annotate_options' -> 'analysis' -> 'results'. Each step sets the next
    # status and calls st.experimental_rerun() to redraw the page.
    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'model_selection'

    # Step 0: choose a pretrained masked-LM checkpoint.
    if st.session_state['page_status']=='model_selection':
        show_instruction('0. Select the model and click "Confirm"',fontsize=16)
        model_name = st.selectbox('Please select the model from below.',
                                  ('bert-base-uncased','bert-large-cased',
                                   'roberta-base','roberta-large',
                                   'albert-base-v2','albert-large-v2','albert-xlarge-v2','albert-xxlarge-v2'),
                                  index=3,label_visibility='collapsed')
        st.session_state['model_name'] = model_name
        if st.button('Confirm',key='confirm_models'):
            st.session_state['page_status'] = 'type_in'
            st.experimental_rerun()

    # Every later step needs the model; load_model is wrapped in st.cache.
    if st.session_state['page_status']!='model_selection':
        tokenizer,model,skeleton_model = load_model(st.session_state['model_name'])
        num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
        # Mask token id: tokenize the mask string, drop the first and last ids
        # (presumably special tokens -- verify per tokenizer), and keep the first remaining id.
        mask_id = tokenizer(tokenizer.mask_token).input_ids[1:-1][0]

    # Step 1: enter the two sentences to compare.
    if st.session_state['page_status']=='type_in':
        show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
        sent_1 = st.text_input('Sentence 1',value="Paul tried to call George on the phone, but he wasn't successful.")
        sent_2 = st.text_input('Sentence 2',value="Paul tried to call George on the phone, but he wasn't available.")
        if st.button('Tokenize'):
            st.session_state['page_status'] = 'annotate_mask'
            st.session_state['sent_1'] = sent_1
            st.session_state['sent_2'] = sent_2
            st.experimental_rerun()

    # Step 2: pick the words to mask in each sentence (stored as mask_locs_<id>).
    if st.session_state['page_status']=='annotate_mask':
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']

        show_instruction('2. Select sites to mask out and click "Confirm"',fontsize=16)

        annotate_mask(1,sent_1)
        show_instruction('------------------------------',fontsize=24)
        annotate_mask(2,sent_2)
        if st.button('Confirm',key='confirm_mask'):
            st.session_state['page_status'] = 'annotate_options'
            st.experimental_rerun()

    # Step 3: pick the two candidate options in each sentence (option_locs_<id>).
    if st.session_state['page_status'] == 'annotate_options':
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']

        show_instruction('3. Select options and click "Confirm"',fontsize=16)

        annotate_options(1,sent_1)
        show_instruction('------------------------------',fontsize=24)
        annotate_options(2,sent_2)
        if st.button('Confirm',key='confirm_option'):
            st.session_state['page_status'] = 'analysis'
            st.experimental_rerun()

    # Step 4: choose representation types and (optionally) individual heads.
    if st.session_state['page_status']=='analysis':
        interv_reps = st.multiselect('Select the types of representations to intervene.',['layer','query','key','value'])
        # UI labels -> short codes used by create_interventions.
        rep_dict = {'layer':'lay','query':'qry','key':'key','value':'val'}
        multihead = not st.checkbox('Perform individual head analysis (takes time)')
        if not multihead:
            # Head numbers are shown 1-based in the UI.
            heads = st.multiselect('Select heads to intervene.',list(np.arange(1,num_heads+1)))
        else:
            heads = []

        if st.button('Run',key='run'):
            st.session_state['reps'] = [rep_dict[rep] for rep in interv_reps]
            st.session_state['multihead'] = multihead
            st.session_state['heads'] = heads
            st.session_state['page_status'] = 'results'
            st.experimental_rerun()

    # Step 5: run the interventions and plot their effects.
    if st.session_state['page_status']=='results':
        sent_1 = st.session_state['sent_1']
        sent_2 = st.session_state['sent_2']
        multihead = st.session_state['multihead']
        heads = st.session_state['heads']
        reps = st.session_state['reps']

        # Per-sentence bookkeeping keyed by 'sent_1' / 'sent_2'. Locations are
        # word indices into the decoded sentence, which omits the first token
        # id, so they are offset by +1 when indexing input_ids.
        option_1_locs, option_2_locs = {}, {}
        pron_locs = {}
        input_ids_dict = {}
        masked_ids_option_1 = {}
        masked_ids_option_2 = {}
        for sent_id in [1,2]:
            option_1_locs[f'sent_{sent_id}'], option_2_locs[f'sent_{sent_id}'] = separate_options(st.session_state[f'option_locs_{sent_id}'])
            pron_locs[f'sent_{sent_id}'] = st.session_state[f'mask_locs_{sent_id}']
            input_ids_dict[f'sent_{sent_id}'] = tokenizer(st.session_state[f'sent_{sent_id}']).input_ids

            # Replace the masked span with one mask token per token of each
            # option, giving one masked input per option.
            masked_ids_option_1[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
                                                             pron_locs[f'sent_{sent_id}'],
                                                             option_1_locs[f'sent_{sent_id}'],mask_id)
            masked_ids_option_2[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
                                                             pron_locs[f'sent_{sent_id}'],
                                                             option_2_locs[f'sent_{sent_id}'],mask_id)

        # Token ids of each option, read from the unmasked inputs (+1 offset).
        option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
        option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
        option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
        option_2_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_2_locs['sent_2'])+1]
        # Both sentences must contain the same option tokens.
        assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
        option_1_tokens = option_1_tokens_1
        option_2_tokens = option_2_tokens_1

        # Baseline run: one dict of empty intervention lists per layer, batch size 1.
        # NOTE(review): presumably empty lists mean "no intervention" in the skeleton model -- confirm.
        interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
        probs_original = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
        # probs_original is indexed [option, sentence, batch_row]; the table
        # puts sentences in rows and options in columns.
        df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
                                [probs_original[0,1][0],probs_original[1,1][0]]],
                          columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
                          index=['Sentence 1','Sentence 2'])
        cols = st.columns(3)
        with cols[1]:
            show_instruction('Probability of predicting each option in each sentence',fontsize=12)
            st.dataframe(df.style.highlight_max(axis=1),use_container_width=True)

        # Context positions: token positions where the two masked sentences
        # differ, shifted by -1 back to word indices.
        # NOTE(review): elementwise `!=` assumes both masked sentences have the
        # same number of tokens -- inputs of different length are not handled.
        compare_1 = np.array(masked_ids_option_1['sent_1'])!=np.array(masked_ids_option_1['sent_2'])
        compare_2 = np.array(masked_ids_option_2['sent_1'])!=np.array(masked_ids_option_2['sent_2'])
        assert np.all(compare_1.astype(int)==compare_2.astype(int))
        context_locs = list(np.arange(len(masked_ids_option_1['sent_1']))[compare_1]-1)

        # Mask and option positions must coincide in both sentences.
        assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
        assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
        assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
        # Word positions to intervene on: masked span, both options, then context.
        token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1'] + context_locs

        # For every token position and every layer, intervene at that single
        # layer only and record the resulting effect.
        effect_array = []
        for token_id in token_id_list:
            # Word index -> position in input_ids.
            token_id += 1
            effect_list = []
            for layer_id in range(num_layers):
                # Only layer_id gets interventions; head numbers from the UI
                # are 1-based, hence head_id-1.
                interventions = [create_interventions(token_id,reps,num_heads,multihead,[head_id-1 for head_id in heads])
                                 if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
                if multihead:
                    # All heads jointly: one batch row per sentence.
                    probs = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
                else:
                    # One batch row per selected head.
                    # NOTE(review): if no heads were selected, the batch size is 0 here -- confirm this is handled.
                    probs = run_intervention(interventions,len(heads),skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
                # Mean of four terms: the drop in p(option 1 | sentence 1), the
                # drop in p(option 2 | sentence 2), the rise in
                # p(option 1 | sentence 2), and the rise in p(option 2 | sentence 1).
                effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
                effect_list.append(effect)
            effect_array.append(effect_list)
        # (tokens, layers, batch_row) -> (layers, tokens, batch_row)
        effect_array = np.transpose(np.array(effect_array),(1,0,2))

        if multihead:
            show_results(effect_array[:,:,0],masked_ids_option_1['sent_1'],token_id_list,num_layers)
        else:
            # One tab per selected head, labelled with its 1-based number.
            tabs = st.tabs([str(head_id) for head_id in heads])
            for i,tab in enumerate(tabs):
                with tab:
                    show_results(effect_array[:,:,i],masked_ids_option_1['sent_1'],token_id_list,num_layers)
|
|