taka-yamakoshi
commited on
Commit
•
6800334
1
Parent(s):
7c56f41
bring back run
Browse files
app.py
CHANGED
@@ -142,7 +142,7 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
|
|
142 |
# note annotations are shifted by 1 because special tokens were omitted
|
143 |
return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
|
144 |
|
145 |
-
|
146 |
def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
|
147 |
probs = []
|
148 |
for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
|
@@ -159,7 +159,6 @@ def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_i
|
|
159 |
probs = np.array(probs)
|
160 |
assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
|
161 |
return probs
|
162 |
-
'''
|
163 |
|
164 |
if __name__=='__main__':
|
165 |
wide_setup()
|
|
|
142 |
# note annotations are shifted by 1 because special tokens were omitted
|
143 |
return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
|
144 |
|
145 |
+
|
146 |
def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
|
147 |
probs = []
|
148 |
for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
|
|
|
159 |
probs = np.array(probs)
|
160 |
assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
|
161 |
return probs
|
|
|
162 |
|
163 |
if __name__=='__main__':
|
164 |
wide_setup()
|