Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -23,7 +23,7 @@ bert_model.to(device)
|
|
23 |
bert_model.eval()
|
24 |
|
25 |
#SpaBERT Initialization Section
|
26 |
-
data_file_path = 'models/spabert/datasets/
|
27 |
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
|
28 |
|
29 |
config = SpatialBertConfig()
|
@@ -49,7 +49,7 @@ spatialDataset = PbfMapDataset(data_file_path = data_file_path,
|
|
49 |
label_encoder = None,
|
50 |
mode = None) #If set to None it will use the full dataset for mlm
|
51 |
|
52 |
-
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False)
|
53 |
|
54 |
# Create a dictionary to map entity names to indices
|
55 |
entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}
|
@@ -87,9 +87,7 @@ def process_entity(batch, model, device):
|
|
87 |
return spaBERT_embedding, input_ids
|
88 |
|
89 |
spaBERT_embeddings = []
|
90 |
-
for
|
91 |
-
if i >= 2: # Stop after processing 3 batches
|
92 |
-
break
|
93 |
spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
|
94 |
spaBERT_embeddings.append(spaBERT_embedding)
|
95 |
|
|
|
23 |
bert_model.eval()
|
24 |
|
25 |
#SpaBERT Initialization Section
|
26 |
+
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json' #Sample file otherwise this model will take too long on CPU.
|
27 |
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
|
28 |
|
29 |
config = SpatialBertConfig()
|
|
|
49 |
label_encoder = None,
|
50 |
mode = None) #If set to None it will use the full dataset for mlm
|
51 |
|
52 |
+
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False)
|
53 |
|
54 |
# Create a dictionary to map entity names to indices
|
55 |
entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}
|
|
|
87 |
return spaBERT_embedding, input_ids
|
88 |
|
89 |
spaBERT_embeddings = []
|
90 |
+
for batch in (data_loader):
|
|
|
|
|
91 |
spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
|
92 |
spaBERT_embeddings.append(spaBERT_embedding)
|
93 |
|