JasonTPhillipsJr committed on
Commit
7837592
1 Parent(s): 52ad8f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -116,14 +116,14 @@ def get_bert_embedding(review_text):
116
 
117
 
118
  #Get SpaBERT Embedding for geo-entity
119
- def get_spaBert_embedding(entity):
120
  entity_index = entity_index_dict.get(entity.lower(), None)
121
  if entity_index is None:
122
  if(dev_mode == True):
123
  st.write("Got Bert embedding for: ", entity)
124
  return get_bert_embedding(entity) #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.
125
  else:
126
- st.write("Pseudo Sentnece:",pseudo_sentences[entity_index])
127
  if(dev_mode == True):
128
  st.write("Got SpaBert embedding for: ", entity)
129
  return spaBERT_embeddings[entity_index]
@@ -134,11 +134,12 @@ def processSpatialEntities(review, nlp):
134
  doc = nlp(review)
135
  entity_spans = [(ent.start, ent.end, ent.text, ent.label_) for ent in doc.ents]
136
  token_embeddings = []
 
137
 
138
  # Iterate over each entity span and process only geo entities
139
  for start, end, text, label in entity_spans:
140
  if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
141
- spaBert_emb = get_spaBert_embedding(text)
142
  token_embeddings.append(spaBert_emb)
143
  if(dev_mode == True):
144
  st.write("Geo-Entity Found in review: ", text)
@@ -146,7 +147,7 @@ def processSpatialEntities(review, nlp):
146
  token_embeddings = torch.stack(token_embeddings, dim=0)
147
  processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
148
  #processed_embedding = processed_embedding.unsqueeze(0) # Shape: (1, 768)
149
- return processed_embedding
150
 
151
 
152
  #Initialize discriminator module
@@ -262,7 +263,7 @@ selected_review = example_reviews[selected_key]
262
  if st.button("Process Review"):
263
  if selected_review.strip():
264
  bert_embedding = get_bert_embedding(selected_review)
265
- spaBert_embedding = processSpatialEntities(selected_review,nlp)
266
  combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
267
 
268
  if(dev_mode == True):
@@ -290,6 +291,10 @@ if st.button("Process Review"):
290
  # Display the highlighted text with HTML support
291
  st.markdown(highlighted_text, unsafe_allow_html=True)
292
 
 
 
 
 
293
  #Display the models prediction
294
  if(prediction == 0):
295
  st.write("Prediction: Not Spam")
 
116
 
117
 
118
  #Get SpaBERT Embedding for geo-entity
119
def get_spaBert_embedding(entity, current_pseudo_sentences):
    """Return the embedding for a geo-entity, preferring the SpaBERT table.

    The entity name is lower-cased and looked up in ``entity_index_dict``.

    * On a hit, the matching entry of ``pseudo_sentences`` is appended to
      ``current_pseudo_sentences`` (the caller's list is mutated in place)
      and the corresponding entry of ``spaBERT_embeddings`` is returned.
    * On a miss, the result of ``get_bert_embedding(entity)`` is returned
      instead and ``current_pseudo_sentences`` is left untouched.

    When ``dev_mode`` is enabled, the path taken is reported via ``st.write``.
    """
    idx = entity_index_dict.get(entity.lower())

    # Guard clause: entity unknown to SpaBERT, so fall back to plain BERT.
    # (Rare cases only.)
    if idx is None:
        # Explicit comparison kept on purpose to preserve existing behaviour.
        if dev_mode == True:
            st.write("Got Bert embedding for: ", entity)
        return get_bert_embedding(entity)

    # Record the pseudo sentence for this entity before handing back
    # its precomputed SpaBERT embedding.
    current_pseudo_sentences.append(pseudo_sentences[idx])
    if dev_mode == True:
        st.write("Got SpaBert embedding for: ", entity)
    return spaBERT_embeddings[idx]
 
134
  doc = nlp(review)
135
  entity_spans = [(ent.start, ent.end, ent.text, ent.label_) for ent in doc.ents]
136
  token_embeddings = []
137
+ current_pseudo_sentences = []
138
 
139
  # Iterate over each entity span and process only geo entities
140
  for start, end, text, label in entity_spans:
141
  if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
142
+ spaBert_emb = get_spaBert_embedding(text,current_pseudo_sentences)
143
  token_embeddings.append(spaBert_emb)
144
  if(dev_mode == True):
145
  st.write("Geo-Entity Found in review: ", text)
 
147
  token_embeddings = torch.stack(token_embeddings, dim=0)
148
  processed_embedding = token_embeddings.mean(dim=0) # Shape: (768)
149
  #processed_embedding = processed_embedding.unsqueeze(0) # Shape: (1, 768)
150
+ return processed_embedding,current_pseudo_sentences
151
 
152
 
153
  #Initialize discriminator module
 
263
  if st.button("Process Review"):
264
  if selected_review.strip():
265
  bert_embedding = get_bert_embedding(selected_review)
266
+ spaBert_embedding, current_pseudo_sentences = processSpatialEntities(selected_review,nlp)
267
  combined_embedding = torch.cat((bert_embedding,spaBert_embedding),dim=-1)
268
 
269
  if(dev_mode == True):
 
291
  # Display the highlighted text with HTML support
292
  st.markdown(highlighted_text, unsafe_allow_html=True)
293
 
294
+ #Display pseudo sentences found
295
+ for sentence in current_pseudo_sentences:
296
+ st.write("Pseudo-Sentence: ", sentence)
297
+
298
  #Display the models prediction
299
  if(prediction == 0):
300
  st.write("Prediction: Not Spam")