taka-yamakoshi committed
Commit: ebfe870
1 Parent(s): 5958ae4

add model options

Files changed:
- app.py +26 -15
- skeleton_modeling_bert.py +73 -0
- skeleton_modeling_roberta.py +73 -0
app.py CHANGED

@@ -10,10 +10,7 @@ import seaborn as sns
 import torch
 import torch.nn.functional as F

-from transformers import AlbertTokenizer, AlbertForMaskedLM
-
 #from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
-from skeleton_modeling_albert import SkeletonAlbertForMaskedLM

 def wide_setup():
     max_width = 1500
@@ -48,10 +45,23 @@ def load_css(file_name):

 @st.cache(show_spinner=True,allow_output_mutation=True)
 def load_model(model_name):
-
-
-
-
+    if model_name.startswith('albert'):
+        from transformers import AlbertTokenizer, AlbertForMaskedLM
+        from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
+        tokenizer = AlbertTokenizer.from_pretrained(model_name)
+        model = AlbertForMaskedLM.from_pretrained(model_name)
+        skeleton_model = SkeletonAlbertForMaskedLM
+    elif model_name.startswith('bert'):
+        from transformers import BertTokenizer, BertForMaskedLM
+        from skeleton_modeling_bert import SkeletonBertForMaskedLM
+        tokenizer = BertTokenizer.from_pretrained(model_name)
+        model = BertForMaskedLM.from_pretrained(model_name)
+    elif model_name.startswith('roberta'):
+        from transformers import RobertaTokenizer, RobertaForMaskedLM
+        from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM
+        tokenizer = RobertaTokenizer.from_pretrained(model_name)
+        model = RobertaForMaskedLM.from_pretrained(model_name)
+    return tokenizer,model,skeleton_model

 def clear_data():
     for key in st.session_state:
@@ -147,14 +157,14 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
     return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]


-def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
+def run_intervention(interventions,batch_size,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
     probs = []
     for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
         input_ids = torch.tensor([
             *[masked_ids['sent_1'] for _ in range(batch_size)],
             *[masked_ids['sent_2'] for _ in range(batch_size)]
             ])
-        outputs =
+        outputs = skeleton_model(model,input_ids,interventions=interventions)
         logprobs = F.log_softmax(outputs['logits'], dim = -1)
         logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
         evals_1 = [logprobs_1[:,pron_locs['sent_1'][0]+1+i,token].numpy() for i,token in enumerate(option_tokens)]
@@ -181,9 +191,10 @@ if __name__=='__main__':
         st.session_state['page_status'] = 'type_in'
         st.experimental_rerun()

-
-
-
+    if st.session_state['page_status']!='model_selection':
+        tokenizer,model,skeleton_model = load_model(st.session_state['model_name'])
+        num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
+        mask_id = tokenizer('[MASK]').input_ids[1:-1][0]

     if st.session_state['page_status']=='type_in':
         show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
@@ -263,7 +274,7 @@ if __name__=='__main__':
         option_2_tokens = option_2_tokens_1

         interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
-        probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
+        probs_original = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
         df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
                                 [probs_original[0,1][0],probs_original[1,1][0]]],
                           columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
@@ -292,9 +303,9 @@ if __name__=='__main__':
         for layer_id in range(num_layers):
             interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
             if multihead:
-                probs = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
+                probs = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
             else:
-                probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
+                probs = run_intervention(interventions,num_heads,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
             effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
             effect_list.append(effect)
         effect_array.append(effect_list)
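The reworked load_model dispatches on the model-name prefix, but as committed only the albert branch assigns skeleton_model before the shared return, so loading a bert or roberta checkpoint would fail with an unbound-variable error; the later mask_id lookup also hard-codes '[MASK]', which does not match RoBERTa's '<mask>' token. A minimal sketch of the presumably intended dispatch, written here as a lookup table rather than the commit's if/elif chain (load_model_sketch and MODEL_FAMILIES are hypothetical names, not part of the repo):

# Hypothetical sketch, not the committed code: assign the skeleton forward
# function for every model family and read the mask id from the tokenizer.
from transformers import (AlbertTokenizer, AlbertForMaskedLM,
                          BertTokenizer, BertForMaskedLM,
                          RobertaTokenizer, RobertaForMaskedLM)
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
from skeleton_modeling_bert import SkeletonBertForMaskedLM
from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM

MODEL_FAMILIES = {  # hypothetical helper table
    'albert':  (AlbertTokenizer,  AlbertForMaskedLM,  SkeletonAlbertForMaskedLM),
    'roberta': (RobertaTokenizer, RobertaForMaskedLM, SkeletonRobertaForMaskedLM),
    'bert':    (BertTokenizer,    BertForMaskedLM,    SkeletonBertForMaskedLM),
}

def load_model_sketch(model_name):
    for prefix, (tok_cls, model_cls, skeleton_model) in MODEL_FAMILIES.items():
        if model_name.startswith(prefix):
            tokenizer = tok_cls.from_pretrained(model_name)
            model = model_cls.from_pretrained(model_name)
            # mask_token_id covers both '[MASK]' (BERT/ALBERT) and '<mask>' (RoBERTa)
            return tokenizer, model, skeleton_model, tokenizer.mask_token_id
    raise ValueError(f'unsupported model: {model_name}')

Reading the mask id from tokenizer.mask_token_id keeps the pronoun-masking code independent of each tokenizer's particular mask string.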
skeleton_modeling_bert.py ADDED

@@ -0,0 +1,73 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import math
+
+@torch.no_grad()
+def SkeletonBertLayer(layer_id,layer,hidden,interventions):
+    attention_layer = layer.attention.self
+    num_heads = attention_layer.num_attention_heads
+    head_dim = attention_layer.attention_head_size
+    assert num_heads*head_dim == hidden.shape[2]
+
+    qry = attention_layer.query(hidden)
+    key = attention_layer.key(hidden)
+    val = attention_layer.value(hidden)
+
+    assert qry.shape == hidden.shape
+    assert key.shape == hidden.shape
+    assert val.shape == hidden.shape
+
+    # swap representations
+    reps = {
+        'lay': hidden,
+        'qry': qry,
+        'key': key,
+        'val': val,
+    }
+    for rep_type in ['lay','qry','key','val']:
+        interv_rep = interventions[layer_id][rep_type]
+        new_state = reps[rep_type].clone()
+        for head_id, pos, swap_ids in interv_rep:
+            new_state[swap_ids[0],:,head_dim*head_id:head_dim*(head_id+1)][pos,:] = reps[rep_type][swap_ids[1],:,head_dim*head_id:head_dim*(head_id+1)][pos,:]
+            new_state[swap_ids[1],:,head_dim*head_id:head_dim*(head_id+1)][pos,:] = reps[rep_type][swap_ids[0],:,head_dim*head_id:head_dim*(head_id+1)][pos,:]
+        reps[rep_type] = new_state.clone()
+
+    hidden = reps['lay'].clone()
+    qry = reps['qry'].clone()
+    key = reps['key'].clone()
+    val = reps['val'].clone()
+
+
+    #split into multiple heads
+    split_qry = qry.view(*(qry.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_key = key.view(*(key.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_val = val.view(*(val.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+
+    #calculate the attention matrix
+    attn_mat = F.softmax(split_qry@split_key.permute(0,1,3,2)/math.sqrt(head_dim),dim=-1)
+
+    z_rep_indiv = attn_mat@split_val
+    z_rep = z_rep_indiv.permute(0,2,1,3).reshape(*hidden.size())
+
+    hidden_post_attn_res = layer.attention.output.dense(z_rep)+hidden # residual connection
+    hidden_post_attn = layer.attention.output.LayerNorm(hidden_post_attn_res) # layer_norm
+
+    hidden_post_interm = layer.intermediate(hidden_post_attn) # massive feed forward
+    hidden_post_interm_res = layer.output.dense(hidden_post_interm)+hidden_post_attn # residual connection
+    new_hidden = layer.output.LayerNorm(hidden_post_interm_res) # layer_norm
+    return new_hidden
+
+def SkeletonBertForMaskedLM(model,input_ids,interventions):
+    core_model = model.bert
+    lm_head = model.cls
+    output_hidden = []
+    with torch.no_grad():
+        hidden = core_model.embeddings(input_ids)
+        output_hidden.append(hidden)
+        for layer_id in range(model.config.num_hidden_layers):
+            layer = core_model.encoder.layer[layer_id]
+            hidden = SkeletonBertLayer(layer_id,layer,hidden,interventions)
+            output_hidden.append(hidden)
+        logits = lm_head(hidden)
+        return {'logits':logits,'hidden_states':output_hidden}
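SkeletonBertForMaskedLM re-implements the encoder stack layer by layer so that query, key, value, and layer-input representations can be swapped between batch elements before attention is computed. With empty intervention lists it should reproduce the stock masked-LM logits for an unpadded batch, since the only omissions are the attention mask and dropout (inactive in eval mode). A quick sanity-check sketch, with the checkpoint name, example sentence, and tolerance chosen here purely for illustration:

# Sanity-check sketch (not part of the commit): with empty interventions the
# skeleton forward pass should match BertForMaskedLM's logits for an unpadded batch.
import torch
from transformers import BertTokenizer, BertForMaskedLM
from skeleton_modeling_bert import SkeletonBertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()  # dropout off so the comparison is deterministic

input_ids = tokenizer('Time flies like an [MASK].', return_tensors='pt').input_ids

# one empty intervention dict per layer, exactly as app.py builds them
interventions = [{'lay': [], 'qry': [], 'key': [], 'val': []}
                 for _ in range(model.config.num_hidden_layers)]

skeleton_out = SkeletonBertForMaskedLM(model, input_ids, interventions)
with torch.no_grad():
    reference = model(input_ids).logits

print(torch.allclose(skeleton_out['logits'], reference, atol=1e-4))  # expect True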
skeleton_modeling_roberta.py ADDED

@@ -0,0 +1,73 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import math
+
+@torch.no_grad()
+def SkeletonRobertaLayer(layer_id,layer,hidden,interventions):
+    attention_layer = layer.attention.self
+    num_heads = attention_layer.num_attention_heads
+    head_dim = attention_layer.attention_head_size
+    assert num_heads*head_dim == hidden.shape[2]
+
+    qry = attention_layer.query(hidden)
+    key = attention_layer.key(hidden)
+    val = attention_layer.value(hidden)
+
+    assert qry.shape == hidden.shape
+    assert key.shape == hidden.shape
+    assert val.shape == hidden.shape
+
+    # swap representations
+    reps = {
+        'lay': hidden,
+        'qry': qry,
+        'key': key,
+        'val': val,
+    }
+    for rep_type in ['lay','qry','key','val']:
+        interv_rep = interventions[layer_id][rep_type]
+        new_state = reps[rep_type].clone()
+        for head_id, pos, swap_ids in interv_rep:
+            new_state[swap_ids[0],:,head_dim*head_id:head_dim*(head_id+1)][pos,:] = reps[rep_type][swap_ids[1],:,head_dim*head_id:head_dim*(head_id+1)][pos,:]
+            new_state[swap_ids[1],:,head_dim*head_id:head_dim*(head_id+1)][pos,:] = reps[rep_type][swap_ids[0],:,head_dim*head_id:head_dim*(head_id+1)][pos,:]
+        reps[rep_type] = new_state.clone()
+
+    hidden = reps['lay'].clone()
+    qry = reps['qry'].clone()
+    key = reps['key'].clone()
+    val = reps['val'].clone()
+
+
+    #split into multiple heads
+    split_qry = qry.view(*(qry.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_key = key.view(*(key.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+    split_val = val.view(*(val.size()[:-1]+(num_heads,head_dim))).permute(0,2,1,3)
+
+    #calculate the attention matrix
+    attn_mat = F.softmax(split_qry@split_key.permute(0,1,3,2)/math.sqrt(head_dim),dim=-1)
+
+    z_rep_indiv = attn_mat@split_val
+    z_rep = z_rep_indiv.permute(0,2,1,3).reshape(*hidden.size())
+
+    hidden_post_attn_res = layer.attention.output.dense(z_rep)+hidden # residual connection
+    hidden_post_attn = layer.attention.output.LayerNorm(hidden_post_attn_res) # layer_norm
+
+    hidden_post_interm = layer.intermediate(hidden_post_attn) # massive feed forward
+    hidden_post_interm_res = layer.output.dense(hidden_post_interm)+hidden_post_attn # residual connection
+    new_hidden = layer.output.LayerNorm(hidden_post_interm_res) # layer_norm
+    return new_hidden
+
+def SkeletonRobertaForMaskedLM(model,input_ids,interventions):
+    core_model = model.roberta
+    lm_head = model.lm_head
+    output_hidden = []
+    with torch.no_grad():
+        hidden = core_model.embeddings(input_ids)
+        output_hidden.append(hidden)
+        for layer_id in range(model.config.num_hidden_layers):
+            layer = core_model.encoder.layer[layer_id]
+            hidden = SkeletonRobertaLayer(layer_id,layer,hidden,interventions)
+            output_hidden.append(hidden)
+        logits = lm_head(hidden)
+        return {'logits':logits,'hidden_states':output_hidden}
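The interventions argument consumed by both skeleton models is one dict per layer, mapping 'lay', 'qry', 'key', and 'val' to lists of (head_id, positions, swap_ids) entries; each entry swaps that head's slice of the chosen representation at the given positions between two batch elements. app.py builds these via create_interventions (not part of this commit), but the structure can also be populated by hand. An illustrative sketch for the RoBERTa skeleton, with the sentences, layer, and head chosen arbitrarily:

# Illustrative sketch (not from the repo): hand-built intervention that swaps
# head 0's query representation at the <mask> position between two sentences
# in layer 5, then runs the RoBERTa skeleton forward pass.
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM
from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForMaskedLM.from_pretrained('roberta-base')

sent_1 = 'The doctor told the patient that <mask> was healthy.'
sent_2 = 'The patient told the doctor that <mask> was healthy.'
# same words in both sentences, so they tokenize to the same length
input_ids = tokenizer([sent_1, sent_2], return_tensors='pt', padding=True).input_ids

mask_pos = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()

# one dict per layer; each entry is (head_id, positions, (batch_idx_a, batch_idx_b))
interventions = [{'lay': [], 'qry': [], 'key': [], 'val': []}
                 for _ in range(model.config.num_hidden_layers)]
interventions[5]['qry'].append((0, [mask_pos], (0, 1)))

outputs = SkeletonRobertaForMaskedLM(model, input_ids, interventions)
swapped_logits = outputs['logits'][:, mask_pos]  # mask predictions after the swap
print(tokenizer.decode([swapped_logits[0].argmax().item()]))  # top prediction for sentence 1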