JasonTPhillipsJr committed
Commit a74fa0d
1 Parent(s): 5914cea

Update app.py

Files changed (1): app.py (+43 -1)
app.py CHANGED
@@ -37,7 +37,7 @@ spaBERT_model.load_state_dict(pre_trained_model, strict=False)
 spaBERT_model.to(device)
 spaBERT_model.eval()
 
-# Load data using SpatialDataset
+#Load data using SpatialDataset
 spatialDataset = PbfMapDataset(data_file_path = data_file_path,
                                tokenizer = bert_tokenizer,
                                max_token_len = 256, #Originally 300
@@ -51,6 +51,48 @@ spatialDataset = PbfMapDataset(data_file_path = data_file_path,
 
 data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
 
+#Pre-acquire the SpaBERT embeddings for all geo-entities within our dataset
+def process_entity(batch, model, device):
+    input_ids = batch['masked_input'].to(device)
+    attention_mask = batch['attention_mask'].to(device)
+    position_list_x = batch['norm_lng_list'].to(device)
+    position_list_y = batch['norm_lat_list'].to(device)
+    sent_position_ids = batch['sent_position_ids'].to(device)
+    pseudo_sentence = batch['pseudo_sentence'].to(device)
+
+    # Convert tensor to a list of token IDs, and decode them into a readable sentence
+    pseudo_sentence_decoded = bert_tokenizer.decode(pseudo_sentence[0].tolist(), skip_special_tokens=False)
+
+    with torch.no_grad():
+        outputs = model(#input_ids=input_ids,
+                        input_ids=pseudo_sentence,
+                        attention_mask=attention_mask,
+                        sent_position_ids=sent_position_ids,
+                        position_list_x=position_list_x,
+                        position_list_y=position_list_y)
+    #NOTE: we are omitting the masked input_ids here and feeding the pseudo-sentence instead. Verify that this is correct.
+
+    embeddings = outputs.hidden_states[-1].to(device)
+
+    # Extract the [CLS] token embedding (first token)
+    embedding = embeddings[:, 0, :].detach()  # [batch_size, hidden_size]
+
+    #pivot_token_len = batch['pivot_token_len'].item()
+    #pivot_embeddings = embeddings[:, :pivot_token_len, :]
+
+    #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
+    return embedding.cpu().numpy(), input_ids.cpu().numpy()
+
+all_embeddings = []
+for batch in data_loader:
+    embeddings, input_ids = process_entity(batch, spaBERT_model, device)
+    all_embeddings.append(embeddings)
+
+st.write("SpaBERT Embedding shape:", all_embeddings[0].shape)
+st.write("SpaBERT Embedding:", all_embeddings[0])
+
+
+
 #Get BERT Embedding for review
 def get_bert_embedding(review_text):
     #tokenize review
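
A quick way to sanity-check the pre-acquired embeddings outside of Streamlit is to stack them and inspect pairwise similarities. This is a minimal sketch, not part of the commit; it assumes each entry of all_embeddings is the (1, hidden_size) NumPy array returned by process_entity above, and cosine_similarity_matrix is a hypothetical helper name.

import numpy as np

# Sketch only: assumes all_embeddings was filled by the loop in this commit.
def cosine_similarity_matrix(all_embeddings):
    matrix = np.vstack(all_embeddings)           # (num_entities, hidden_size)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    unit = matrix / np.clip(norms, 1e-12, None)  # L2-normalize each row
    return unit @ unit.T                         # pairwise cosine similarities

# Example: similarity between the first two geo-entities
# sims = cosine_similarity_matrix(all_embeddings)
# print(sims[0, 1])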