taka-yamakoshi committed
Commit ce466e4 · 1 Parent(s): d1e605d
check masking
app.py
CHANGED
@@ -111,16 +111,39 @@ def show_annotated_sentence(sent,option_locs=[],mask_locs=[]):
     suffix = '</span></p>'
     return st.markdown(prefix + disp + suffix, unsafe_allow_html = True)
 
-def show_instruction(sent):
-    disp_style = '"font-family:san serif; color:Black; font-size: 20px"'
+def show_instruction(sent,fontsize=20):
+    disp_style = f'"font-family:san serif; color:Black; font-size: {fontsize}px"'
     prefix = f'<p style={disp_style}><span style="font-weight:bold">'
     suffix = '</span></p>'
     return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
 
+def create_interventions(token_id,interv_type,num_layers,num_heads):
+    interventions = {}
+    for layer_id in range(num_layers):
+        interventions[layer_id] = {}
+        if interv_type == 'all':
+            for rep in ['lay','qry','key','val']:
+                interventions[layer_id][rep] = [(head_id,token_id,[head_id,head_id+num_heads]) for head_id in range(num_heads)]
+        else:
+            interventions[layer_id][interv_type] = [(head_id,token_id,[head_id,head_id+num_heads]) for head_id in range(num_heads)]
+    return interventions
+
+def separate_options(option_locs):
+    assert np.sum(np.diff(option_locs)>1)==1
+    sep = list(np.diff(option_locs)>1).index(1)+1
+    option_1_locs, option_2_locs = option_locs[:sep], option_locs[sep:]
+    assert np.all(np.diff(option_1_locs)==1) and np.all(np.diff(option_2_loc)==1)
+    return option_1_locs, option_2_locs
+
+def mask_out(input_ids,pron_locs,option_locs,mask_id):
+    assert np.all(np.diff(pron_locs)==1)
+    return input_ids[:pron_locs[0]] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+1:]
+
 if __name__=='__main__':
     wide_setup()
     load_css('style.css')
     tokenizer,model = load_model()
+    num_layers, num_heads = 12, 64
     mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
 
     main_area = st.empty()
@@ -171,16 +194,37 @@ if __name__=='__main__':
                                 option_locs=st.session_state['option_locs_2'],
                                 mask_locs=st.session_state['mask_locs_2'])
 
-
-
-
-
-
-
-
-
-
+        option_1_locs, option_2_locs = {}, {}
+        pron_id = {}
+        input_ids_dict = {}
+        masked_ids_option_1 = {}
+        masked_ids_option_2 = {}
+        for sent_id in range(2):
+            option_1_locs[f'sent_{sent_id+1}'], option_2_locs[f'sent_{sent_id+1}'] = separate_options(st.session_state[f'option_locs_{sent_id}'])
+            pron_locs[f'sent_{sent_id+1}'] = st.session_state[f'mask_locs_{sent_id+1}']
+            input_ids_dict[f'sent_{sent_id+1}'] = tokenizer(st.session_state[f'sent_{sent_id+1}']).input_ids
+
+            masked_ids_option_1[f'sent_{sent_id+1}'] = mask_out(input_ids_dict[f'sent_{sent_id+1}'],
+                                                                pron_locs[f'sent_{sent_id+1}'],
+                                                                option_1_locs[f'sent_{sent_id+1}'],mask_id)
+            masked_ids_option_2[f'sent_{sent_id+1}'] = mask_out(input_ids_dict[f'sent_{sent_id+1}'],
+                                                                pron_locs[f'sent_{sent_id+1}'],
+                                                                option_2_locs[f'sent_{sent_id+1}'],mask_id)
+
+        for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
+            st.write(' '.join([tokenizer.decode([token]) for toke in token_ids]))
+
+        if st.session_state['page_status'] == 'finish_debug':
+            try:
+                assert len(input_ids_1) == len(input_ids_2)
+            except AssertionError:
+                show_instruction('Please make sure the number of tokens match between Sentence 1 and Sentence 2', fontsize=12)
+            input_ids = torch.tensor([*[input_ids_1 for _ in range(num_heads)],*[input_ids_2 for _ in range(num_heads)]])
+            interventions = create_interventions(16,'all',num_layers=num_layers,num_heads=num_heads)
+            outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
         logprobs = F.log_softmax(outputs['logits'], dim = -1)
+
+
         preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
         preds_1 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[1][1:-1]]
         st.write([tokenizer.decode([token]) for token in preds_0])