import html
import time

import numpy as np
import pandas as pd
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
from transformers import AlbertTokenizer, AlbertForMaskedLM

# from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM

# Representations an intervention can target inside one layer, in the order
# the per-layer intervention dicts are built.
INTERVENTION_TYPES = ('lay', 'qry', 'key', 'val')


def wide_setup():
    """Widen the Streamlit main container and hide table row indices.

    NOTE(review): the CSS below was reconstructed because the literals in this
    file were empty (the margin variables were computed but never used). Two
    container selectors are listed to cover older and newer Streamlit layouts;
    confirm against the installed Streamlit version.
    """
    max_width = 1500
    padding_top = 0
    padding_right = 2
    padding_bottom = 0
    padding_left = 2

    define_margins = f"""
<style>
    .reportview-container .main .block-container,
    .appview-container .main .block-container {{
        max-width: {max_width}px;
        padding-top: {padding_top}rem;
        padding-right: {padding_right}rem;
        padding-left: {padding_left}rem;
        padding-bottom: {padding_bottom}rem;
    }}
</style>
"""
    hide_table_row_index = """
<style>
    tbody th {display:none}
    .blank {display:none}
</style>
"""
    st.markdown(define_margins, unsafe_allow_html=True)
    st.markdown(hide_table_row_index, unsafe_allow_html=True)


def load_css(file_name):
    """Inject the stylesheet stored at *file_name* into the page."""
    with open(file_name, encoding='utf-8') as f:
        css = f.read()
    st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)


@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model():
    """Load (and cache across reruns) the ALBERT-xxlarge-v2 tokenizer and MLM model.

    Returns:
        (tokenizer, model) tuple.
    """
    tokenizer = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    # model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2', from_pt=True)
    model = AlbertForMaskedLM.from_pretrained('albert-xxlarge-v2')
    return tokenizer, model


def clear_data():
    """Remove every key from the Streamlit session state."""
    # Snapshot the keys first: deleting from a mapping while iterating it is unsafe.
    for key in list(st.session_state.keys()):
        del st.session_state[key]


def _decode_tokens(sent):
    """Tokenize *sent* with the module-level ``tokenizer`` and decode each token.

    The first and last ids ([CLS]/[SEP]) are dropped, so index ``i`` of the
    returned list corresponds to ``input_ids[i + 1]``.
    """
    input_ids = tokenizer(sent).input_ids
    return [tokenizer.decode([token]) for token in input_ids[1:-1]]


def _toggle_word_buttons(sent_id, decoded_sent, state_key, button_prefix):
    """Render one button per token; clicking a token toggles its index.

    The selected indices are kept, in click order, in
    ``st.session_state[state_key]``, which is created if missing.

    Returns:
        The (mutable) list stored in the session state.
    """
    if state_key not in st.session_state:
        st.session_state[state_key] = []
    selected = st.session_state[state_key]
    if not decoded_sent:
        # st.columns cannot be called with an empty spec list.
        return selected
    # Column widths proportional to the token length so buttons line up like text.
    cols = st.columns([len(word) + 2 for word in decoded_sent])
    for word_id, (col, word) in enumerate(zip(cols, decoded_sent)):
        with col:
            if st.button(word, key=f'{button_prefix}_{sent_id}_{word_id}'):
                if word_id in selected:
                    selected.remove(word_id)
                else:
                    selected.append(word_id)
    return selected


def annotate_mask(sent_id, sent):
    """Let the user pick the tokens of sentence *sent_id* to mask out.

    Stores the decoded tokens in ``decoded_sent_{sent_id}`` and the selected
    token indices in ``mask_locs_{sent_id}`` of the session state.
    """
    st.write(f'Sentence {sent_id}')
    decoded_sent = _decode_tokens(sent)
    st.session_state[f'decoded_sent_{sent_id}'] = decoded_sent
    mask_locs = _toggle_word_buttons(sent_id, decoded_sent, f'mask_locs_{sent_id}', 'word_mask')
    show_annotated_sentence(decoded_sent, mask_locs=mask_locs)


def annotate_options(sent_id, sent):
    """Let the user pick the two candidate options in sentence *sent_id*.

    Selected token indices are stored in ``option_locs_{sent_id}`` of the
    session state; previously chosen mask sites are shown alongside.
    """
    st.write(f'Sentence {sent_id}')
    decoded_sent = _decode_tokens(sent)
    option_locs = _toggle_word_buttons(sent_id, decoded_sent, f'option_locs_{sent_id}', 'word_option')
    show_annotated_sentence(decoded_sent,
                            option_locs=option_locs,
                            mask_locs=st.session_state.get(f'mask_locs_{sent_id}', []))


def show_annotated_sentence(sent, option_locs=None, mask_locs=None):
    """Render the tokens in *sent* as one paragraph with highlighted positions.

    Masked positions take precedence over option positions. Tokens are
    HTML-escaped because they come from user-typed text.

    NOTE(review): the span styles were reconstructed (the literals in this
    file had their HTML stripped); adjust colors as desired.
    """
    mask_set = set(mask_locs or ())
    option_set = set(option_locs or ())
    disp_style = '"font-family:sans-serif; color:Black; font-size: 20px"'
    prefix = f'<p style={disp_style}>'
    style_list = []
    for i, word in enumerate(sent):
        word = html.escape(word)
        if i in mask_set:
            style_list.append(f'<span style="color:Red; font-weight:bold">{word}</span>')
        elif i in option_set:
            style_list.append(f'<span style="color:Blue; font-weight:bold">{word}</span>')
        else:
            style_list.append(word)
    disp = ' '.join(style_list)
    suffix = '</p>'
    return st.markdown(prefix + disp + suffix, unsafe_allow_html=True)


def show_instruction(sent, fontsize=20):
    """Render the (trusted) instruction text *sent* at *fontsize* px."""
    disp_style = f'"font-family:sans-serif; color:Black; font-size: {fontsize}px"'
    prefix = f'<p style={disp_style}>'
    suffix = '</p>'
    return st.markdown(prefix + sent + suffix, unsafe_allow_html=True)


def _empty_interventions():
    """Return a fresh per-layer intervention dict that intervenes on nothing."""
    return {rep: [] for rep in INTERVENTION_TYPES}


def create_interventions(token_id, interv_types, num_heads):
    """Build one layer's intervention dict.

    For each representation in ``INTERVENTION_TYPES`` that appears in
    *interv_types*, the value is one ``(head_id, token_id, [head_id,
    head_id + num_heads])`` tuple per head; other representations map to [].

    Presumably ``[head_id, head_id + num_heads]`` names the two batch rows
    (the sentence-1 and sentence-2 copies built by ``run``) whose
    representations are exchanged at position *token_id* — confirm against
    skeleton_modeling_albert.
    """
    return {
        rep: ([(head_id, token_id, [head_id, head_id + num_heads]) for head_id in range(num_heads)]
              if rep in interv_types else [])
        for rep in INTERVENTION_TYPES
    }


def separate_options(option_locs):
    """Split selected option token indices into two contiguous spans.

    Indices are de-duplicated and sorted first because the UI stores them in
    click order.

    Returns:
        (option_1_locs, option_2_locs): the left and right contiguous runs.

    Raises:
        ValueError: if the indices do not form exactly two contiguous runs.
    """
    locs = sorted(set(option_locs))
    gaps = [i + 1 for i, (a, b) in enumerate(zip(locs, locs[1:])) if b - a > 1]
    if len(gaps) != 1:
        raise ValueError('Select exactly two non-adjacent options, each a contiguous run of tokens.')
    sep = gaps[0]
    # With sorted unique indices and a single gap, each half is contiguous.
    return locs[:sep], locs[sep:]


def mask_out(input_ids, pron_locs, option_locs, mask_id):
    """Replace the selected span with one *mask_id* per option token.

    *pron_locs* and *option_locs* are indices into the decoded sentence, which
    omits the leading special token, hence the +1 shift into *input_ids*.

    Raises:
        ValueError: if *pron_locs* is empty or not a contiguous run.
    """
    locs = sorted(set(pron_locs))
    if not locs:
        raise ValueError('Select at least one token to mask out.')
    if any(b - a != 1 for a, b in zip(locs, locs[1:])):
        raise ValueError('The tokens to mask out must be contiguous.')
    start = locs[0] + 1
    end = locs[-1] + 1
    return input_ids[:start] + [mask_id] * len(option_locs) + input_ids[end + 1:]


def run(interventions, batch_size, model, masked_ids_option_1, masked_ids_option_2,
        option_1_tokens, option_2_tokens, pron_locs):
    """Score both options for both sentences under the given interventions.

    For each option, a batch holding *batch_size* copies of sentence 1
    followed by *batch_size* copies of sentence 2 is run through
    ``SkeletonAlbertForMaskedLM``. The option's probability is the exponent of
    the mean log-probability of its tokens at the masked positions.

    Returns:
        np.ndarray of shape (2, 2, batch_size): [option][sentence][batch row].

    Raises:
        ValueError: if the two masked sentences differ in token length (they
            must be stacked into one tensor).
    """
    probs = []
    for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],
                                         [option_1_tokens, option_2_tokens]):
        if len(masked_ids['sent_1']) != len(masked_ids['sent_2']):
            raise ValueError('Both masked sentences must have the same number of tokens.')
        input_ids = torch.tensor([masked_ids['sent_1']] * batch_size
                                 + [masked_ids['sent_2']] * batch_size)
        # Inference only: without no_grad the logits require grad and .numpy() fails.
        with torch.no_grad():
            outputs = SkeletonAlbertForMaskedLM(model, input_ids, interventions=interventions)
            logprobs = F.log_softmax(outputs['logits'], dim=-1)
        option_probs = []
        for half, sent_key in ((logprobs[:batch_size], 'sent_1'), (logprobs[batch_size:], 'sent_2')):
            # +1 skips the leading special token; masks start at the first selected site.
            start = min(pron_locs[sent_key]) + 1
            evals = [half[:, start + i, int(token)].cpu().numpy() for i, token in enumerate(option_tokens)]
            option_probs.append(np.exp(np.mean(evals, axis=0)))
        probs.append(option_probs)
    probs = np.array(probs)
    assert probs.shape == (2, 2, batch_size)
    return probs


if __name__ == '__main__':
    wide_setup()
    load_css('style.css')
    tokenizer, model = load_model()
    # 12 layers / 64 heads for albert-xxlarge-v2, read from the config instead of hard-coded.
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
    mask_id = tokenizer.mask_token_id

    main_area = st.empty()

    if 'page_status' not in st.session_state:
        st.session_state['page_status'] = 'type_in'

    if st.session_state['page_status'] == 'type_in':
        show_instruction('1. Type in the sentences and click "Tokenize"')
        sent_1 = st.text_input('Sentence 1', value='It is better to play a prank on Samuel than Craig because he gets angry less often.')
        sent_2 = st.text_input('Sentence 2', value='It is better to play a prank on Samuel than Craig because he gets angry more often.')
        if st.button('Tokenize'):
            st.session_state['page_status'] = 'annotate_mask'
            st.session_state['sent_1'] = sent_1
            st.session_state['sent_2'] = sent_2
            st.experimental_rerun()

    if st.session_state['page_status'] == 'annotate_mask':
        show_instruction('2. Select sites to mask out and click "Confirm"')
        annotate_mask(1, st.session_state['sent_1'])
        annotate_mask(2, st.session_state['sent_2'])
        if st.button('Confirm', key='mask'):
            st.session_state['page_status'] = 'annotate_options'
            st.experimental_rerun()

    if st.session_state['page_status'] == 'annotate_options':
        show_instruction('3. Select options and click "Confirm"')
        annotate_options(1, st.session_state['sent_1'])
        annotate_options(2, st.session_state['sent_2'])
        if st.button('Confirm', key='option'):
            st.session_state['page_status'] = 'analysis'
            st.experimental_rerun()

    if st.session_state['page_status'] == 'analysis':
        with main_area.container():
            for sent_id in (1, 2):
                show_annotated_sentence(st.session_state[f'decoded_sent_{sent_id}'],
                                        option_locs=st.session_state[f'option_locs_{sent_id}'],
                                        mask_locs=st.session_state[f'mask_locs_{sent_id}'])

            option_1_locs, option_2_locs = {}, {}
            pron_locs = {}
            input_ids_dict = {}
            masked_ids_option_1 = {}
            masked_ids_option_2 = {}
            for sent_id in (1, 2):
                key = f'sent_{sent_id}'
                option_1_locs[key], option_2_locs[key] = separate_options(
                    st.session_state[f'option_locs_{sent_id}'])
                # Sorted because the UI stores selections in click order.
                pron_locs[key] = sorted(st.session_state[f'mask_locs_{sent_id}'])
                input_ids_dict[key] = tokenizer(st.session_state[key]).input_ids
                masked_ids_option_1[key] = mask_out(input_ids_dict[key], pron_locs[key],
                                                    option_1_locs[key], mask_id)
                masked_ids_option_2[key] = mask_out(input_ids_dict[key], pron_locs[key],
                                                    option_2_locs[key], mask_id)

            st.write(option_1_locs)
            st.write(option_2_locs)
            st.write(pron_locs)
            for token_ids in (masked_ids_option_1['sent_1'], masked_ids_option_1['sent_2'],
                              masked_ids_option_2['sent_1'], masked_ids_option_2['sent_2']):
                st.write(' '.join(tokenizer.decode([token]) for token in token_ids))

            # +1 converts decoded-sentence indices to input_ids indices.
            ids_1 = np.array(input_ids_dict['sent_1'])
            ids_2 = np.array(input_ids_dict['sent_2'])
            option_1_tokens = ids_1[np.array(option_1_locs['sent_1']) + 1]
            option_2_tokens = ids_1[np.array(option_2_locs['sent_1']) + 1]
            # array_equal also handles options of different lengths (plain == does not).
            if not (np.array_equal(option_1_tokens, ids_2[np.array(option_1_locs['sent_2']) + 1])
                    and np.array_equal(option_2_tokens, ids_2[np.array(option_2_locs['sent_2']) + 1])):
                raise ValueError('The same two options must be selected in both sentences.')

            interventions = [_empty_interventions() for _ in range(num_layers)]
            probs_original = run(interventions, 1, model, masked_ids_option_1, masked_ids_option_2,
                                 option_1_tokens, option_2_tokens, pron_locs)
            st.write(probs_original)
            print(probs_original)

            # NOTE(review): 'finish_debug' is never assigned in this file, so this
            # per-layer sweep is currently unreachable; token position 16 is hard-coded.
            if st.session_state['page_status'] == 'finish_debug':
                for layer_id in range(num_layers):
                    interventions = [create_interventions(16, list(INTERVENTION_TYPES), num_heads)
                                     if i == layer_id else _empty_interventions()
                                     for i in range(num_layers)]
                    probs = run(interventions, num_heads, model, masked_ids_option_1, masked_ids_option_2,
                                option_1_tokens, option_2_tokens, pron_locs)