taka-yamakoshi
commited on
Commit
•
77d2a77
1
Parent(s):
3052d18
debug
Browse files
app.py
CHANGED
@@ -117,15 +117,13 @@ def show_instruction(sent,fontsize=20):
|
|
117 |
suffix = '</span></p>'
|
118 |
return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
|
119 |
|
120 |
-
def create_interventions(token_id,
|
121 |
interventions = {}
|
122 |
-
for
|
123 |
-
|
124 |
-
|
125 |
-
for rep in ['lay','qry','key','val']:
|
126 |
-
interventions[layer_id][rep] = [(head_id,token_id,[head_id,head_id+num_heads]) for head_id in range(num_heads)]
|
127 |
else:
|
128 |
-
interventions[
|
129 |
return interventions
|
130 |
|
131 |
def separate_options(option_locs):
|
@@ -195,7 +193,7 @@ if __name__=='__main__':
|
|
195 |
mask_locs=st.session_state['mask_locs_2'])
|
196 |
|
197 |
option_1_locs, option_2_locs = {}, {}
|
198 |
-
|
199 |
input_ids_dict = {}
|
200 |
masked_ids_option_1 = {}
|
201 |
masked_ids_option_2 = {}
|
@@ -215,14 +213,15 @@ if __name__=='__main__':
|
|
215 |
st.write(' '.join([tokenizer.decode([token]) for toke in token_ids]))
|
216 |
|
217 |
if st.session_state['page_status'] == 'finish_debug':
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
|
|
226 |
|
227 |
|
228 |
preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
|
|
|
117 |
suffix = '</span></p>'
|
118 |
return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
|
119 |
|
120 |
+
def create_interventions(token_id,interv_types,num_heads):
|
121 |
interventions = {}
|
122 |
+
for rep in ['lay','qry','key','val']:
|
123 |
+
if rep in interv_types:
|
124 |
+
interventions[rep] = [(head_id,token_id,[head_id,head_id+num_heads]) for head_id in range(num_heads)]
|
|
|
|
|
125 |
else:
|
126 |
+
interventions[rep] = []
|
127 |
return interventions
|
128 |
|
129 |
def separate_options(option_locs):
|
|
|
193 |
mask_locs=st.session_state['mask_locs_2'])
|
194 |
|
195 |
option_1_locs, option_2_locs = {}, {}
|
196 |
+
pron_locs = {}
|
197 |
input_ids_dict = {}
|
198 |
masked_ids_option_1 = {}
|
199 |
masked_ids_option_2 = {}
|
|
|
213 |
st.write(' '.join([tokenizer.decode([token]) for toke in token_ids]))
|
214 |
|
215 |
if st.session_state['page_status'] == 'finish_debug':
|
216 |
+
for layer_id in range(num_layers):
|
217 |
+
interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
218 |
+
for masked_ids in [masked_ids_option_1, masked_ids_option_2]:
|
219 |
+
input_ids = torch.tensor([
|
220 |
+
*[masked_ids['sent_1'] for _ in range(num_heads)],
|
221 |
+
*[masked_ids['sent_2'] for _ in range(num_heads)]
|
222 |
+
])
|
223 |
+
outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
|
224 |
+
logprobs = F.log_softmax(outputs['logits'], dim = -1)
|
225 |
|
226 |
|
227 |
preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
|