AhmedSSabir committed
Commit 5ea9bf1
1 Parent(s): 0b41e50
Update app.py

app.py CHANGED
@@ -20,8 +20,8 @@ from sentence_transformers import SentenceTransformer, util
 
 #model_sts = gr.Interface.load('huggingface/sentence-transformers/stsb-distilbert-base')
 
-
-model_sts = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')
+model_sts = SentenceTransformer('stsb-distilbert-base')
+#model_sts = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')
 #batch_size = 1
 #scorer = LMScorer.from_pretrained('gpt2' , device=device, batch_size=batch_size)
 
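The only functional change in this hunk swaps the sentence-similarity backbone from 'roberta-large-nli-stsb-mean-tokens' to the lighter 'stsb-distilbert-base', keeping the old choice as a comment. A minimal sketch of how model_sts is typically used together with the util helper imported above; the two example sentences are illustrative and do not appear in app.py:

from sentence_transformers import SentenceTransformer, util

# Same model as the new line above; weights are downloaded on first use.
model_sts = SentenceTransformer('stsb-distilbert-base')

# Encode two sentences and compare them with cosine similarity.
emb_a = model_sts.encode('a man is riding a horse', convert_to_tensor=True)
emb_b = model_sts.encode('a person rides a horse in a field', convert_to_tensor=True)
score = util.cos_sim(emb_a, emb_b)  # 1x1 tensor (util.pytorch_cos_sim in older releases)
print(float(score))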
@@ -72,11 +72,7 @@ def cloze_prob(text):
     text_list = text.split()
     stem = ' '.join(text_list[:-1])
     stem_encoding = tokenizer.encode(stem)
-    # cw_encoding is just the difference between whole_text_encoding and stem_encoding
-    # note: this might not correspond exactly to the word itself
     cw_encoding = whole_text_encoding[len(stem_encoding):]
-    # Run the entire sentence through the model. Then go "back in time" to look at what the model predicted for each token, starting at the stem.
-    # Put the whole text encoding into a tensor, and get the model's comprehensive output
     tokens_tensor = torch.tensor([whole_text_encoding])
 
     with torch.no_grad():
@@ -93,10 +89,7 @@ def cloze_prob(text):
 
         logprobs.append(np.log(softmax(raw_output)))
 
-
-    # [ [0.412, 0.001, ... ] ,[0.213, 0.004, ...], [0.002,0.001, 0.93 ...]]
-    # Then for the i'th token we want to find its associated probability
-    # this is just: raw_probabilities[i][token_index]
+
     conditional_probs = []
     for cw,prob in zip(cw_encoding,logprobs):
         conditional_probs.append(prob[cw])
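The remaining hunks only strip explanatory comments from cloze_prob, which scores how predictable the final word of a sentence is: it encodes the full text and its stem, treats the trailing token ids as the critical word, and collects the model's log-probabilities at exactly those positions. A rough, self-contained sketch of that idea, assuming a plain GPT-2 model and tokenizer (the actual model setup in app.py sits outside these hunks and may differ):

import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Assumed setup; app.py may configure the scorer differently.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

def cloze_prob(text):
    whole_text_encoding = tokenizer.encode(text)
    stem = ' '.join(text.split()[:-1])
    stem_encoding = tokenizer.encode(stem)
    # Token ids that belong to the final ("critical") word.
    cw_encoding = whole_text_encoding[len(stem_encoding):]

    tokens_tensor = torch.tensor([whole_text_encoding])
    with torch.no_grad():
        logits = model(tokens_tensor).logits[0]  # (seq_len, vocab)
    log_probs = torch.log_softmax(logits, dim=-1)

    # The logits at position i predict token i+1, so the token at
    # position len(stem_encoding)+k is scored by row len(stem_encoding)-1+k.
    conditional = [log_probs[len(stem_encoding) - 1 + k, tok].item()
                   for k, tok in enumerate(cw_encoding)]
    # Joint probability of the critical word = product of its token probabilities.
    return float(np.exp(np.sum(conditional)))

print(cloze_prob('the children went outside to play'))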