JasonTPhillipsJr committed a74fa0d (1 parent: 5914cea): Update app.py
app.py CHANGED
@@ -37,7 +37,7 @@ spaBERT_model.load_state_dict(pre_trained_model, strict=False)
 spaBERT_model.to(device)
 spaBERT_model.eval()
 
-#
+#Load data using SpatialDataset
 spatialDataset = PbfMapDataset(data_file_path = data_file_path,
                                tokenizer = bert_tokenizer,
                                max_token_len = 256, #Originally 300
@@ -51,6 +51,48 @@ spatialDataset = PbfMapDataset(data_file_path = data_file_path,
 
 data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
 
+#Pre-aquire the SpaBERT embeddings for all geo-entities within our dataset
+def process_entity(batch, model, device):
+    input_ids = batch['masked_input'].to(device)
+    attention_mask = batch['attention_mask'].to(device)
+    position_list_x = batch['norm_lng_list'].to(device)
+    position_list_y = batch['norm_lat_list'].to(device)
+    sent_position_ids = batch['sent_position_ids'].to(device)
+    pseudo_sentence = batch['pseudo_sentence'].to(device)
+
+    # Convert tensor to list of token IDs, and decode them into a readable sentence
+    pseudo_sentence_decoded = tokenizer.decode(pseudo_sentence[0].tolist(), skip_special_tokens=False)
+
+    with torch.no_grad():
+        outputs = spaBERT_model(#input_ids=input_ids,
+                                input_ids=pseudo_sentence,
+                                attention_mask=attention_mask,
+                                sent_position_ids=sent_position_ids,
+                                position_list_x=position_list_x,
+                                position_list_y=position_list_y)
+    #NOTE: we are ommitting the pseudo_sentence here. Verify that this is correct
+
+    embeddings = outputs.hidden_states[-1].to(device)
+
+    # Extract the [CLS] token embedding (first token)
+    embedding = embeddings[:, 0, :].detach() # [batch_size, hidden_size]
+
+    #pivot_token_len = batch['pivot_token_len'].item()
+    #pivot_embeddings = embeddings[:, :pivot_token_len, :]
+
+    #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
+    return embedding.cpu().numpy(), input_ids.cpu().numpy()
+
+all_embeddings = []
+for batch in (data_loader):
+    embeddings, input_ids = process_entity(batch, model, device)
+    all_embeddings.append(embeddings)
+
+st.write("SpaBERT Embedding shape:", all_embeddings[0].shape)
+st.write("SpaBERT Embedding:", all_embeddings[0])
+
+
+
 #Get BERT Embedding for review
 def get_bert_embedding(review_text):
     #tokenize review
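
The added process_entity block has a few rough edges worth flagging: it takes a model argument but calls the spaBERT_model global instead, it decodes with a name tokenizer where the surrounding code defines bert_tokenizer, the decoded pseudo_sentence_decoded is never used, and the driver loop passes a variable model that this diff never defines. Below is a minimal cleanup sketch, not part of the commit, assuming only the batch keys and the SpaBERT forward signature shown in the diff; all other names reuse objects already defined earlier in app.py.

import torch

def process_entity(batch, model, device):
    # Move the tensors the forward pass actually consumes onto the device.
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    with torch.no_grad():
        # Use the model passed in rather than reaching for the global.
        outputs = model(input_ids=pseudo_sentence,
                        attention_mask=attention_mask,
                        sent_position_ids=sent_position_ids,
                        position_list_x=position_list_x,
                        position_list_y=position_list_y)

    # [CLS] token embedding from the last hidden layer: [batch_size, hidden_size]
    embedding = outputs.hidden_states[-1][:, 0, :]
    return embedding.cpu().numpy(), pseudo_sentence.cpu().numpy()

all_embeddings = []
for batch in data_loader:
    embedding, _ = process_entity(batch, spaBERT_model, device)
    all_embeddings.append(embedding)

Passing spaBERT_model explicitly keeps the function reusable if a second checkpoint is loaded later; if the readable pseudo-sentence is wanted for debugging, bert_tokenizer.decode(batch['pseudo_sentence'][0].tolist(), skip_special_tokens=False) can be logged inside the loop instead of computing and discarding it on every call.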