JasonTPhillipsJr commited on
Commit
18634d6
1 Parent(s): bf52bfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -40,14 +40,14 @@ spaBERT_model.eval()
40
  #Load data using SpatialDataset
41
  spatialDataset = PbfMapDataset(data_file_path = data_file_path,
42
  tokenizer = bert_tokenizer,
43
- max_token_len = 256, #Originally 300
44
  #max_token_len = max_seq_length, #Originally 300
45
  distance_norm_factor = 0.0001,
46
  spatial_dist_fill = 20,
47
  with_type = False,
48
  sep_between_neighbors = True,
49
  label_encoder = None,
50
- mode = None) #If set to None it will use the full dataset for mlm
51
 
52
  data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
53
 
@@ -81,7 +81,7 @@ def process_entity(batch, model, device):
81
  #pivot_embeddings = embeddings[:, :pivot_token_len, :]
82
 
83
  #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
84
- return embedding.cpu().numpy(), input_ids.cpu().numpy()
85
 
86
  all_embeddings = []
87
  for i, batch in enumerate(data_loader):
@@ -112,6 +112,12 @@ def get_bert_embedding(review_text):
112
 
113
 
114
 
 
 
 
 
 
 
115
  st.title("SpaGAN Demo")
116
  st.write("Enter a text, and the system will highlight the geo-entities within it.")
117
 
@@ -153,9 +159,11 @@ if st.button("Highlight Geo-Entities"):
153
  # Debug: Print the embeddings themselves (optional)
154
  st.write("Embeddings:", bert_embedding)
155
 
 
156
  combined_embedding = torch.cat((bert_embedding,all_embeddings[0]),dim=-1)
157
  st.write("Concatenated Embedding Shape:", concatenated_embedding.shape)
158
  st.write("Concatenated Embedding:", concatenated_embedding)
 
159
  # Process the text using spaCy
160
  doc = nlp(selected_review)
161
 
 
40
  #Load data using SpatialDataset
41
  spatialDataset = PbfMapDataset(data_file_path = data_file_path,
42
  tokenizer = bert_tokenizer,
43
+ max_token_len = 256, #Originally 300
44
  #max_token_len = max_seq_length, #Originally 300
45
  distance_norm_factor = 0.0001,
46
  spatial_dist_fill = 20,
47
  with_type = False,
48
  sep_between_neighbors = True,
49
  label_encoder = None,
50
+ mode = None) #If set to None it will use the full dataset for mlm
51
 
52
  data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
53
 
 
81
  #pivot_embeddings = embeddings[:, :pivot_token_len, :]
82
 
83
  #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
84
+ return embedding, input_ids
85
 
86
  all_embeddings = []
87
  for i, batch in enumerate(data_loader):
 
112
 
113
 
114
 
115
+
116
+
117
+
118
+
119
+
120
+
121
  st.title("SpaGAN Demo")
122
  st.write("Enter a text, and the system will highlight the geo-entities within it.")
123
 
 
159
  # Debug: Print the embeddings themselves (optional)
160
  st.write("Embeddings:", bert_embedding)
161
 
162
+ #combine the embeddings (NOTE: come back and update after testing)
163
  combined_embedding = torch.cat((bert_embedding,all_embeddings[0]),dim=-1)
164
  st.write("Concatenated Embedding Shape:", concatenated_embedding.shape)
165
  st.write("Concatenated Embedding:", concatenated_embedding)
166
+
167
  # Process the text using spaCy
168
  doc = nlp(selected_review)
169