Spaces:
Sleeping
Sleeping
JasonTPhillipsJr
commited on
Commit
•
4c18d69
1
Parent(s):
b751a90
Update app.py
Browse files
app.py
CHANGED
@@ -51,6 +51,12 @@ spatialDataset = PbfMapDataset(data_file_path = data_file_path,
|
|
51 |
|
52 |
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
#Pre-aquire the SpaBERT embeddings for all geo-entities within our dataset
|
55 |
def process_entity(batch, model, device):
|
56 |
input_ids = batch['masked_input'].to(device)
|
@@ -90,9 +96,9 @@ for i, batch in enumerate(data_loader):
|
|
90 |
spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
|
91 |
spaBERT_embeddings.append(spaBERT_embedding)
|
92 |
|
93 |
-
st.write("SpaBERT Embedding shape:", spaBERT_embedding[0].shape)
|
94 |
-
st.write("SpaBERT Embedding:", spaBERT_embedding[0])
|
95 |
-
|
96 |
|
97 |
|
98 |
#Get BERT Embedding for review
|
@@ -108,14 +114,34 @@ def get_bert_embedding(review_text):
|
|
108 |
bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
|
109 |
return bert_embedding
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
|
121 |
st.title("SpaGAN Demo")
|
@@ -138,10 +164,12 @@ for label, (color, description) in COLOR_MAP.items():
|
|
138 |
#user_input = st.text_area("Input Text", height=200)
|
139 |
|
140 |
# Define example reviews for testing
|
141 |
-
example_reviews = {
|
142 |
-
"Review 1": "
|
143 |
-
"Review 2": "Google, headquartered in Mountain View, is a leading tech company in the United States.",
|
144 |
-
}
|
|
|
|
|
145 |
|
146 |
# Dropdown for selecting an example review
|
147 |
user_input = st.selectbox("Select an example review", options=list(example_reviews.keys()))
|
|
|
51 |
|
52 |
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) #issue needs to be fixed with num_workers not stopping after finished
|
53 |
|
54 |
+
# Create a dictionary to map entity names to indices
|
55 |
+
entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}
|
56 |
+
|
57 |
+
# Ensure names are stored in lowercase for case-insensitive matching
|
58 |
+
entity_index_dict = {name.lower(): index for name, index in entity_index_dict.items()}
|
59 |
+
|
60 |
#Pre-aquire the SpaBERT embeddings for all geo-entities within our dataset
|
61 |
def process_entity(batch, model, device):
|
62 |
input_ids = batch['masked_input'].to(device)
|
|
|
96 |
spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
|
97 |
spaBERT_embeddings.append(spaBERT_embedding)
|
98 |
|
99 |
+
#st.write("SpaBERT Embedding shape:", spaBERT_embedding[0].shape)
|
100 |
+
#st.write("SpaBERT Embedding:", spaBERT_embedding[0])
|
101 |
+
embedding_cache = {}
|
102 |
|
103 |
|
104 |
#Get BERT Embedding for review
|
|
|
114 |
bert_embedding = outputs.last_hidden_state[:, 0, :].detach() #CLS Token
|
115 |
return bert_embedding
|
116 |
|
117 |
+
def get_spaBert_embedding(entity):
|
118 |
+
entity_index = entity_index_dict.get(entity.lower(), None)
|
119 |
+
return spaBERT_embeddings[entity_index]
|
120 |
+
|
121 |
+
def processSpatialEntities(review, nlp):
|
122 |
+
doc = nlp(review)
|
123 |
+
entity_spans = [(ent.start, ent.end, ent.text, ent.label_) for ent in doc.ents]
|
124 |
+
token_embeddings = []
|
125 |
+
|
126 |
+
# Iterate over each entity span and process only geo entities
|
127 |
+
for start, end, text, label in entity_spans:
|
128 |
+
if label in ['FAC', 'ORG', 'LOC', 'GPE']: # Filter to geo-entities
|
129 |
+
spaBert_emb = get_spaBert_embedding(text)
|
130 |
+
token_embeddings.append((text, spaBert_emb))
|
131 |
+
return token_embeddings
|
132 |
+
|
133 |
+
# Function to read reviews from a text file
|
134 |
+
def load_reviews_from_file(file_path):
|
135 |
+
reviews = {}
|
136 |
+
try:
|
137 |
+
with open(file_path, 'r', encoding='utf-8') as file:
|
138 |
+
for i, line in enumerate(file):
|
139 |
+
line = line.strip()
|
140 |
+
if line: # Ensure the line is not empty
|
141 |
+
reviews[f"Review {i + 1}"] = line
|
142 |
+
except FileNotFoundError:
|
143 |
+
st.error(f"File not found: {file_path}")
|
144 |
+
return reviews
|
145 |
|
146 |
|
147 |
st.title("SpaGAN Demo")
|
|
|
164 |
#user_input = st.text_area("Input Text", height=200)
|
165 |
|
166 |
# Define example reviews for testing
|
167 |
+
#example_reviews = {
|
168 |
+
# "Review 1": "Meh. My brother lives near the Italian Market in South Philly. I went for a visit. Luckily for me, my brother and his girlfriend are foodies. I was able to taste many different cuisines in Philly. Coming from San Francisco, there are places I don't go due to the tourist trap aura and the non-authenticity of it all (Fisherman’s Wharf, Chinatown, etc.). But when I was in Philly, I had to have a cheesesteak... and I had to go to the two most famous places, which of course are right across the street from one another, in a big rivalry, and featured on the Food Network! How cheesy, but essential. We split two, both "wit whiz"? (cheese whiz) one from Geno's and one from Pat's. Pat's was much tastier than Geno's. The meat was seasoned, and the bun and cheese had much more flavor... better of the two... it seems.",
|
169 |
+
# "Review 2": "Google, headquartered in Mountain View, is a leading tech company in the United States.",
|
170 |
+
#}
|
171 |
+
review_file_path = "models/spabert/datasets/SampleReviews.txt"
|
172 |
+
example_reviews = load_reviews_from_file(review_file_path)
|
173 |
|
174 |
# Dropdown for selecting an example review
|
175 |
user_input = st.selectbox("Select an example review", options=list(example_reviews.keys()))
|