Spaces:
Sleeping
Sleeping
JasonTPhillipsJr
committed on
Commit
•
7837592
1
Parent(s):
52ad8f8
Update app.py
Browse files
app.py
CHANGED
@@ -116,14 +116,14 @@ def get_bert_embedding(review_text):
|
|
116 |
|
117 |
|
118 |
#Get SpaBERT Embedding for geo-entity
|
119 |
-
def get_spaBert_embedding(entity):
|
120 |
entity_index = entity_index_dict.get(entity.lower(), None)
|
121 |
if entity_index is None:
|
122 |
if(dev_mode == True):
|
123 |
st.write("Got Bert embedding for: ", entity)
|
124 |
return get_bert_embedding(entity) #Fallback in case SpaBERT could not resolve the entity to retrieve an embedding. Rare cases only.
|
125 |
else:
|
126 |
-
|
127 |
if(dev_mode == True):
|
128 |
st.write("Got SpaBert embedding for: ", entity)
|
129 |
return spaBERT_embeddings[entity_index]
|
@@ -134,11 +134,12 @@ def processSpatialEntities(review, nlp):
|
|
134 |
doc = nlp(review)
|
135 |
entity_spans = [(ent.start, ent.end, ent.text, ent.label_) for ent in doc.ents]
|
136 |
token_embeddings = []
|
|
|
137 |
|
138 |
# Iterate over each entity span and process only geo entities
|
139 |
for start, end, text, label in entity_spans:
|
140 |
if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
|
141 |
-
spaBert_emb = get_spaBert_embedding(text)
|
142 |
token_embeddings.append(spaBert_emb)
|
143 |
if(dev_mode == True):
|
144 |
st.write("Geo-Entity Found in review: ", text)
|
@@ -146,7 +147,7 @@ def processSpatialEntities(review, nlp):
|
|
146 |
token_embeddings = torch.stack(token_embeddings, dim=0)
|
147 |
processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
|
148 |
#processed_embedding = processed_embedding.unsqueeze(0) # Shape: (1, 768)
|
149 |
-
return processed_embedding
|
150 |
|
151 |
|
152 |
#Initialize discriminator module
|
@@ -262,7 +263,7 @@ selected_review = example_reviews[selected_key]
|
|
262 |
if st.button("Process Review"):
|
263 |
if selected_review.strip():
|
264 |
bert_embedding = get_bert_embedding(selected_review)
|
265 |
-
spaBert_embedding = processSpatialEntities(selected_review,nlp)
|
266 |
combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
|
267 |
|
268 |
if(dev_mode == True):
|
@@ -290,6 +291,10 @@ if st.button("Process Review"):
|
|
290 |
# Display the highlighted text with HTML support
|
291 |
st.markdown(highlighted_text, unsafe_allow_html=True)
|
292 |
|
|
|
|
|
|
|
|
|
293 |
#Display the model's prediction
|
294 |
if(prediction == 0):
|
295 |
st.write("Prediction: Not Spam")
|
|
|
116 |
|
117 |
|
118 |
#Get SpaBERT Embedding for geo-entity
|
119 |
+
def get_spaBert_embedding(entity,current_pseudo_sentences):
|
120 |
entity_index = entity_index_dict.get(entity.lower(), None)
|
121 |
if entity_index is None:
|
122 |
if(dev_mode == True):
|
123 |
st.write("Got Bert embedding for: ", entity)
|
124 |
return get_bert_embedding(entity) #Fallback in case SpaBERT could not resolve the entity to retrieve an embedding. Rare cases only.
|
125 |
else:
|
126 |
+
current_pseudo_sentences.append(pseudo_sentences[entity_index])
|
127 |
if(dev_mode == True):
|
128 |
st.write("Got SpaBert embedding for: ", entity)
|
129 |
return spaBERT_embeddings[entity_index]
|
|
|
134 |
doc = nlp(review)
|
135 |
entity_spans = [(ent.start, ent.end, ent.text, ent.label_) for ent in doc.ents]
|
136 |
token_embeddings = []
|
137 |
+
current_pseudo_sentences = []
|
138 |
|
139 |
# Iterate over each entity span and process only geo entities
|
140 |
for start, end, text, label in entity_spans:
|
141 |
if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
|
142 |
+
spaBert_emb = get_spaBert_embedding(text,current_pseudo_sentences)
|
143 |
token_embeddings.append(spaBert_emb)
|
144 |
if(dev_mode == True):
|
145 |
st.write("Geo-Entity Found in review: ", text)
|
|
|
147 |
token_embeddings = torch.stack(token_embeddings, dim=0)
|
148 |
processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
|
149 |
#processed_embedding = processed_embedding.unsqueeze(0) # Shape: (1, 768)
|
150 |
+
return processed_embedding,current_pseudo_sentences
|
151 |
|
152 |
|
153 |
#Initialize discriminator module
|
|
|
263 |
if st.button("Process Review"):
|
264 |
if selected_review.strip():
|
265 |
bert_embedding = get_bert_embedding(selected_review)
|
266 |
+
spaBert_embedding, current_pseudo_sentences = processSpatialEntities(selected_review,nlp)
|
267 |
combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
|
268 |
|
269 |
if(dev_mode == True):
|
|
|
291 |
# Display the highlighted text with HTML support
|
292 |
st.markdown(highlighted_text, unsafe_allow_html=True)
|
293 |
|
294 |
+
#Display pseudo sentences found
|
295 |
+
for sentence in current_pseudo_sentences:
|
296 |
+
st.write("Pseudo-Sentence: ", sentence)
|
297 |
+
|
298 |
#Display the model's prediction
|
299 |
if(prediction == 0):
|
300 |
st.write("Prediction: Not Spam")
|