JasonTPhillipsJr committed on
Commit
fa29176
·
verified ·
1 Parent(s): dba4dc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -23,7 +23,7 @@ bert_model.to(device)
23
  bert_model.eval()
24
 
25
  #SpaBERT Initialization Section
26
- data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json' #Make a new json file with only the geo entities needed, or it takes too long to run.
27
  pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
28
 
29
  config = SpatialBertConfig()
@@ -49,7 +49,7 @@ spatialDataset = PbfMapDataset(data_file_path = data_file_path,
49
  label_encoder = None,
50
  mode = None) #If set to None it will use the full dataset for mlm
51
 
52
- data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
53
 
54
  # Create a dictionary to map entity names to indices
55
  entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}
@@ -87,9 +87,7 @@ def process_entity(batch, model, device):
87
  return spaBERT_embedding, input_ids
88
 
89
  spaBERT_embeddings = []
90
- for i, batch in enumerate(data_loader):
91
- if i >= 2: # Stop after processing 3 batches
92
- break
93
  spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
94
  spaBERT_embeddings.append(spaBERT_embedding)
95
 
 
23
  bert_model.eval()
24
 
25
  #SpaBERT Initialization Section
26
+ data_file_path = 'models/spabert/datasets/SpaBERTPivots.json' #Sample file otherwise this model will take too long on CPU.
27
  pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
28
 
29
  config = SpatialBertConfig()
 
49
  label_encoder = None,
50
  mode = None) #If set to None it will use the full dataset for mlm
51
 
52
+ data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False)
53
 
54
  # Create a dictionary to map entity names to indices
55
  entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}
 
87
  return spaBERT_embedding, input_ids
88
 
89
  spaBERT_embeddings = []
90
+ for batch in (data_loader):
 
 
91
  spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
92
  spaBERT_embeddings.append(spaBERT_embedding)
93