JasonTPhillipsJr committed on
Commit
cbcad17
·
verified ·
1 Parent(s): 4c18d69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -83,9 +83,6 @@ def process_entity(batch, model, device):
83
  # Extract the [CLS] token embedding (first token)
84
  spaBERT_embedding = spaBERT_embedding[:, 0, :].detach() # [batch_size, hidden_size]
85
 
86
- #pivot_token_len = batch['pivot_token_len'].item()
87
- #pivot_embeddings = embeddings[:, :pivot_token_len, :]
88
-
89
  #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
90
  return spaBERT_embedding, input_ids
91
 
@@ -114,10 +111,12 @@ def get_bert_embedding(review_text):
114
  bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
115
  return bert_embedding
116
 
 
117
  def get_spaBert_embedding(entity):
118
  entity_index = entity_index_dict.get(entity.lower(), None)
119
  return spaBERT_embeddings[entity_index]
120
 
 
121
  def processSpatialEntities(review, nlp):
122
  doc = nlp(review)
123
  entity_spans = [(ent.start, ent.end, ent.text, ent.label_) for ent in doc.ents]
 
83
  # Extract the [CLS] token embedding (first token)
84
  spaBERT_embedding = spaBERT_embedding[:, 0, :].detach() # [batch_size, hidden_size]
85
 
 
 
 
86
  #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
87
  return spaBERT_embedding, input_ids
88
 
 
111
  bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
112
  return bert_embedding
113
 
114
+ #Get SpaBERT Embedding for geo-entity
115
  def get_spaBert_embedding(entity):
116
  entity_index = entity_index_dict.get(entity.lower(), None)
117
  return spaBERT_embeddings[entity_index]
118
 
119
+ #Go through each review, identify all geo-entities, then extract their SpaBERT embeddings
120
  def processSpatialEntities(review, nlp):
121
  doc = nlp(review)
122
  entity_spans = [(ent.start, ent.end, ent.text, ent.label_) for ent in doc.ents]