taka-yamakoshi committed
Commit 10ced5b • 1 Parent(s): 779710f
debug
app.py
CHANGED
@@ -76,6 +76,7 @@ if __name__=='__main__':
     if st.session_state['page_status']=='tokenized':
         tokenizer,model = load_model()
         mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
+
         sent_1 = st.session_state['sent_1']
         sent_2 = st.session_state['sent_2']
         if 'masked_pos_1' not in st.session_state:
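A note on the unchanged mask_id line above: slicing away the [CLS]/[SEP] ids that the ALBERT tokenizer wraps around '[MASK]' works, but Hugging Face tokenizers also expose the same id directly. A minimal sketch, assuming the standard transformers tokenizer API:

# tokenizer('[MASK]').input_ids is [CLS] [MASK] [SEP] for ALBERT-style
# tokenizers, so [1:-1][0] recovers the mask id.
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]

# Equivalent, and clearer: the tokenizer stores the id of its mask token.
mask_id = tokenizer.mask_token_id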
@@ -85,29 +86,26 @@ if __name__=='__main__':
 
         st.write('2. Select sites to mask out and click "Confirm"')
         input_sent = tokenizer(sent_1).input_ids
-        decoded_sent = [tokenizer.decode([token]) for token in input_sent]
+        decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
         char_nums = [len(word)+2 for word in decoded_sent]
+        st.write(char_nums)
         cols = st.columns(char_nums)
-
-        st.write(decoded_sent[0])
-        with cols[-1]:
-            st.write(decoded_sent[-1])
-        for word_id,(col,word) in enumerate(zip(cols[1:-1],decoded_sent[1:-1])):
+        for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
             with col:
                 if st.button(word,key=f'word_{word_id}'):
-                    if
+                    if word_id not in st.session_state['masked_pos_1']:
                         st.session_state['masked_pos_1'].append(word_id)
-        st.write(f'Masked words: {", ".join([decoded_sent[word_id
-
-
-
-
-
-
-
-
-
-
-
-
-
+        st.write(f'Masked words: {", ".join([decoded_sent[word_id] for word_id in np.sort(st.session_state["masked_pos_1"])])}')
+
+
+    if st.session_state['page_status']=='analysis':
+        sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
+        sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
+        input_ids_1 = tokenizer(sent_1).input_ids
+        input_ids_2 = tokenizer(sent_2).input_ids
+        input_ids = torch.tensor([input_ids_1,input_ids_2])
+
+        outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
+        logprobs = F.log_softmax(outputs['logits'], dim = -1)
+        preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
+        st.write([tokenizer.decode([token]) for token in preds])
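The rewritten loop strips the special tokens before laying out the buttons, so columns and decoded tokens align one-to-one and the first and last tokens no longer need special-casing. A self-contained sketch of the same toggle pattern (the function name and setup are illustrative, not part of app.py):

import streamlit as st

def render_mask_picker(decoded_sent):
    # One column per token, sized roughly by the token's character width.
    char_nums = [len(word) + 2 for word in decoded_sent]
    cols = st.columns(char_nums)
    if 'masked_pos_1' not in st.session_state:
        st.session_state['masked_pos_1'] = []
    for word_id, (col, word) in enumerate(zip(cols, decoded_sent)):
        with col:
            # Unique key per button; a click records the position at most once.
            if st.button(word, key=f'word_{word_id}') and \
               word_id not in st.session_state['masked_pos_1']:
                st.session_state['masked_pos_1'].append(word_id)
    masked = sorted(st.session_state['masked_pos_1'])
    st.write(f'Masked words: {", ".join(decoded_sent[i] for i in masked)}')

The built-in sorted() gives the same ordering as the diff's np.sort without needing NumPy for a plain list of ints.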
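On the new analysis branch: torch.tensor([input_ids_1,input_ids_2]) only forms a rectangular batch if both sentences tokenize to the same length (which should hold for the defaults, since they differ only in the common word less/more), and the last three lines sample one token per position from the model's output distribution. A runnable sketch of the same decode step, with a plain ALBERT masked LM standing in for load_model() and the repo's SkeletonAlbertForMaskedLM (whose interventions argument is specific to this repo and not reproduced here):

import torch
import torch.nn.functional as F
from transformers import AlbertForMaskedLM, AlbertTokenizer

# Stand-ins for the app's load_model(); any ALBERT checkpoint works here.
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = AlbertForMaskedLM.from_pretrained('albert-base-v2')

sent_1 = 'It is better to play a prank on Samuel than Craig because he gets angry less often.'
sent_2 = 'It is better to play a prank on Samuel than Craig because he gets angry more often.'

# Padding the batch removes the equal-length assumption made by
# torch.tensor([input_ids_1,input_ids_2]) in the diff.
enc = tokenizer([sent_1, sent_2], return_tensors='pt', padding=True)

with torch.no_grad():
    logits = model(**enc).logits            # (batch, seq_len, vocab)

# Drawing from one categorical per position is what the diff's
# torch.multinomial(torch.exp(probs), num_samples=1) loop does.
logprobs = F.log_softmax(logits, dim=-1)
preds = torch.distributions.Categorical(logits=logprobs[0]).sample()
print([tokenizer.decode([token]) for token in preds])

Argmax decoding (logprobs.argmax(-1)) would give deterministic fills instead of per-position samples.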