taka-yamakoshi committed on
Commit
10ced5b
1 Parent(s): 779710f
Files changed (1) hide show
  1. app.py +19 -21
app.py CHANGED
@@ -76,6 +76,7 @@ if __name__=='__main__':
76
  if st.session_state['page_status']=='tokenized':
77
  tokenizer,model = load_model()
78
  mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
 
79
  sent_1 = st.session_state['sent_1']
80
  sent_2 = st.session_state['sent_2']
81
  if 'masked_pos_1' not in st.session_state:
@@ -85,29 +86,26 @@ if __name__=='__main__':
85
 
86
  st.write('2. Select sites to mask out and click "Confirm"')
87
  input_sent = tokenizer(sent_1).input_ids
88
- decoded_sent = [tokenizer.decode([token]) for token in input_sent]
89
  char_nums = [len(word)+2 for word in decoded_sent]
 
90
  cols = st.columns(char_nums)
91
- with cols[0]:
92
- st.write(decoded_sent[0])
93
- with cols[-1]:
94
- st.write(decoded_sent[-1])
95
- for word_id,(col,word) in enumerate(zip(cols[1:-1],decoded_sent[1:-1])):
96
  with col:
97
  if st.button(word,key=f'word_{word_id}'):
98
- if word_in not in st.session_state['masked_pos_1']:
99
  st.session_state['masked_pos_1'].append(word_id)
100
- st.write(f'Masked words: {", ".join([decoded_sent[word_id+1] for word_id in np.sort(st.session_state["masked_pos_1"])])}')
101
-
102
- '''
103
- sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
104
- sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
105
- input_ids_1 = tokenizer(sent_1).input_ids
106
- input_ids_2 = tokenizer(sent_2).input_ids
107
- input_ids = torch.tensor([input_ids_1,input_ids_2])
108
-
109
- outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
110
- logprobs = F.log_softmax(outputs['logits'], dim = -1)
111
- preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
112
- st.write([tokenizer.decode([token]) for token in preds])
113
- '''
 
76
  if st.session_state['page_status']=='tokenized':
77
  tokenizer,model = load_model()
78
  mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
79
+
80
  sent_1 = st.session_state['sent_1']
81
  sent_2 = st.session_state['sent_2']
82
  if 'masked_pos_1' not in st.session_state:
 
86
 
87
  st.write('2. Select sites to mask out and click "Confirm"')
88
  input_sent = tokenizer(sent_1).input_ids
89
+ decoded_sent = [tokenizer.decode([token]) for token in input_sent[1:-1]]
90
  char_nums = [len(word)+2 for word in decoded_sent]
91
+ st.write(char_nums)
92
  cols = st.columns(char_nums)
93
+ for word_id,(col,word) in enumerate(zip(cols,decoded_sent)):
 
 
 
 
94
  with col:
95
  if st.button(word,key=f'word_{word_id}'):
96
+ if word_id not in st.session_state['masked_pos_1']:
97
  st.session_state['masked_pos_1'].append(word_id)
98
+ st.write(f'Masked words: {", ".join([decoded_sent[word_id] for word_id in np.sort(st.session_state["masked_pos_1"])])}')
99
+
100
+
101
+ if st.session_state['page_status']=='analysis':
102
+ sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
103
+ sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
104
+ input_ids_1 = tokenizer(sent_1).input_ids
105
+ input_ids_2 = tokenizer(sent_2).input_ids
106
+ input_ids = torch.tensor([input_ids_1,input_ids_2])
107
+
108
+ outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions = {0:{'lay':[(8,1,[0,1])]}})
109
+ logprobs = F.log_softmax(outputs['logits'], dim = -1)
110
+ preds = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0]]
111
+ st.write([tokenizer.decode([token]) for token in preds])