taka-yamakoshi
commited on
Commit
•
2092dd1
1
Parent(s):
ddf537a
fix
Browse files
app.py
CHANGED
@@ -54,13 +54,13 @@ if __name__=='__main__':
|
|
54 |
tokenizer,model = load_model()
|
55 |
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
56 |
|
57 |
-
sent_1 = st.sidebar.text_input('Sentence 1',on_change=clear_data)
|
58 |
-
sent_2 = st.sidebar.text_input('Sentence 2',on_change=clear_data)
|
59 |
input_ids_1 = tokenizer(sent_1).input_ids
|
60 |
input_ids_2 = tokenizer(sent_2).input_ids
|
61 |
input_ids = np.array([input_ids_1,input_ids_2])
|
62 |
|
63 |
-
outputs = model(input_ids, interv_type='swap', interv_dict = {0
|
64 |
logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
|
65 |
preds = [np.random.choice(np.arange(len(probs)),p=np.exp(probs)/np.sum(np.exp(probs))) for probs in logprobs[0]]
|
66 |
st.write([tokenizer.decode([token]) for token in preds])
|
|
|
54 |
tokenizer,model = load_model()
|
55 |
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
56 |
|
57 |
+
sent_1 = st.sidebar.text_input('Sentence 1',value='It is better to play a prank on Samuel than Craig because he gets angry less often.',on_change=clear_data)
|
58 |
+
sent_2 = st.sidebar.text_input('Sentence 2',value='It is better to play a prank on Samuel than Craig because he gets angry more often.',on_change=clear_data)
|
59 |
input_ids_1 = tokenizer(sent_1).input_ids
|
60 |
input_ids_2 = tokenizer(sent_2).input_ids
|
61 |
input_ids = np.array([input_ids_1,input_ids_2])
|
62 |
|
63 |
+
outputs = model(input_ids, interv_type='swap', interv_dict = {0:{'lay':[(8,1,[0,1])]}})
|
64 |
logprobs = jax.nn.log_softmax(outputs.logits, axis = -1)
|
65 |
preds = [np.random.choice(np.arange(len(probs)),p=np.exp(probs)/np.sum(np.exp(probs))) for probs in logprobs[0]]
|
66 |
st.write([tokenizer.decode([token]) for token in preds])
|