JasonTPhillipsJr committed on
Commit
4c18d69
1 Parent(s): b751a90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -15
app.py CHANGED
@@ -51,6 +51,12 @@ spatialDataset = PbfMapDataset(data_file_path = data_file_path,
51
 
52
  data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
53
 
 
 
 
 
 
 
54
  #Pre-acquire the SpaBERT embeddings for all geo-entities within our dataset
55
  def process_entity(batch, model, device):
56
  input_ids = batch['masked_input'].to(device)
@@ -90,9 +96,9 @@ for i, batch in enumerate(data_loader):
90
  spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
91
  spaBERT_embeddings.append(spaBERT_embedding)
92
 
93
- st.write("SpaBERT Embedding shape:", spaBERT_embedding[0].shape)
94
- st.write("SpaBERT Embedding:", spaBERT_embedding[0])
95
-
96
 
97
 
98
  #Get BERT Embedding for review
@@ -108,14 +114,34 @@ def get_bert_embedding(review_text):
108
  bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
109
  return bert_embedding
110
 
111
-
112
-
113
-
114
-
115
-
116
-
117
-
118
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
 
121
  st.title("SpaGAN Demo")
@@ -138,10 +164,12 @@ for label, (color, description) in COLOR_MAP.items():
138
  #user_input = st.text_area("Input Text", height=200)
139
 
140
  # Define example reviews for testing
141
- example_reviews = {
142
- "Review 1": "I visited the Empire State Building in New York last summer, and it was amazing!",
143
- "Review 2": "Google, headquartered in Mountain View, is a leading tech company in the United States.",
144
- }
 
 
145
 
146
  # Dropdown for selecting an example review
147
  user_input = st.selectbox("Select an example review", options=list(example_reviews.keys()))
 
51
 
52
  data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
53
 
54
+ # Create a dictionary to map entity names to indices
55
+ entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}
56
+
57
+ # Ensure names are stored in lowercase for case-insensitive matching
58
+ entity_index_dict = {name.lower(): index for name, index in entity_index_dict.items()}
59
+
60
  #Pre-acquire the SpaBERT embeddings for all geo-entities within our dataset
61
  def process_entity(batch, model, device):
62
  input_ids = batch['masked_input'].to(device)
 
96
  spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
97
  spaBERT_embeddings.append(spaBERT_embedding)
98
 
99
+ #st.write("SpaBERT Embedding shape:", spaBERT_embedding[0].shape)
100
+ #st.write("SpaBERT Embedding:", spaBERT_embedding[0])
101
+ embedding_cache = {}
102
 
103
 
104
  #Get BERT Embedding for review
 
114
  bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
115
  return bert_embedding
116
 
117
def get_spaBert_embedding(entity):
    """Look up the pre-computed SpaBERT embedding for a named entity.

    The name is lower-cased before it is looked up in ``entity_index_dict``;
    the index found there selects an item from ``spaBERT_embeddings``.

    Args:
        entity: Entity name as a string (e.g. the text of an NER span).

    Returns:
        The matching item of ``spaBERT_embeddings``, or ``None`` when the
        name has no entry in ``entity_index_dict``.
    """
    entity_index = entity_index_dict.get(entity.lower())
    if entity_index is None:
        # Unknown entity: indexing the list with None would raise TypeError.
        return None
    return spaBERT_embeddings[entity_index]
120
+
121
def processSpatialEntities(review, nlp):
    """Pair each geo-entity found in a review with its SpaBERT embedding.

    ``nlp`` is called on ``review`` and the ``ents`` of the result are
    scanned. Entities labelled FAC, ORG, LOC or GPE are kept; each one is
    passed by text to ``get_spaBert_embedding``.

    Args:
        review: Review text to analyse.
        nlp: Callable returning an object with an ``ents`` iterable whose
            items expose ``text`` and ``label_``.

    Returns:
        A list of ``(entity_text, embedding)`` tuples, in the order the
        entities appear in ``ents``.
    """
    geo_labels = ('FAC', 'ORG', 'LOC', 'GPE')  # labels treated as geo-entities
    parsed = nlp(review)
    return [
        (ent.text, get_spaBert_embedding(ent.text))
        for ent in parsed.ents
        if ent.label_ in geo_labels
    ]
132
+
133
+ # Function to read reviews from a text file
134
def load_reviews_from_file(file_path):
    """Read example reviews from a UTF-8 text file, one review per line.

    Lines are stripped of surrounding whitespace and empty lines are
    skipped. Reviews are labelled consecutively ("Review 1", "Review 2",
    ...) in file order, so blank lines in the file do not leave gaps in
    the numbering.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A dict mapping "Review N" labels to review text. If the file does
        not exist, an error is reported through ``st.error`` and an empty
        dict is returned.
    """
    reviews = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            stripped_lines = (line.strip() for line in file)
            # filter(None, ...) drops the empty strings left by blank lines
            # before numbering, keeping the labels contiguous.
            for number, text in enumerate(filter(None, stripped_lines), start=1):
                reviews[f"Review {number}"] = text
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
    return reviews
145
 
146
 
147
  st.title("SpaGAN Demo")
 
164
  #user_input = st.text_area("Input Text", height=200)
165
 
166
  # Define example reviews for testing
167
+ #example_reviews = {
168
+ # "Review 1": "Meh. My brother lives near the Italian Market in South Philly. I went for a visit. Luckily for me, my brother and his girlfriend are foodies. I was able to taste many different cuisines in Philly. Coming from San Francisco, there are places I don't go due to the tourist trap aura and the non-authenticity of it all (Fisherman’s Wharf, Chinatown, etc.). But when I was in Philly, I had to have a cheesesteak... and I had to go to the two most famous places, which of course are right across the street from one another, in a big rivalry, and featured on the Food Network! How cheesy, but essential. We split two, both "wit whiz"? (cheese whiz) one from Geno's and one from Pat's. Pat's was much tastier than Geno's. The meat was seasoned, and the bun and cheese had much more flavor... better of the two... it seems.",
169
+ # "Review 2": "Google, headquartered in Mountain View, is a leading tech company in the United States.",
170
+ #}
171
+ review_file_path = "models/spabert/datasets/SampleReviews.txt"
172
+ example_reviews = load_reviews_from_file(review_file_path)
173
 
174
  # Dropdown for selecting an example review
175
  user_input = st.selectbox("Select an example review", options=list(example_reviews.keys()))