File size: 13,655 Bytes
ac736ed
5c91758
bc50d7d
f7a8863
a2d8109
01b8a52
ebf50a4
 
b7111b8
 
3d8cd48
 
83e90d7
ac736ed
d914cbe
dbbbac5
d914cbe
60a8335
d914cbe
1aa7dda
3d81019
60a8335
dc9ff0b
3d81019
 
d914cbe
3d81019
 
60a8335
dc9ff0b
fa29176
d914cbe
3be15aa
fbce538
 
 
3d81019
3be15aa
d914cbe
 
80744c0
 
 
3d81019
60a8335
a74fa0d
3d8cd48
5914cea
18634d6
5914cea
3d8cd48
 
 
bf52bfd
 
18634d6
3d8cd48
fa29176
3d81019
4c18d69
 
 
 
 
 
60a8335
a74fa0d
 
 
 
 
 
 
 
 
 
169e7aa
a74fa0d
 
 
 
 
 
 
 
 
 
857dba3
a74fa0d
 
0cea6d5
a74fa0d
 
4b2b2d4
a74fa0d
857dba3
4b2b2d4
fa29176
4b2b2d4
0cea6d5
4b2b2d4
 
 
a74fa0d
f82dac8
1091141
f82dac8
 
 
 
 
 
 
 
857dba3
 
3d81019
60a8335
cbcad17
7837592
4c18d69
b8aeb00
94e1f37
 
60a8335
b8aeb00
7837592
94e1f37
 
b8aeb00
4c18d69
60a8335
cbcad17
4c18d69
 
 
 
7837592
4c18d69
 
 
 
340e947
a3ad8d6
7837592
cb96047
dfc41ea
94e1f37
b14ffe5
 
bec805a
b6c17e6
7837592
60a8335
4c18d69
f7a8863
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fea8f2e
f7a8863
 
 
b81fd11
f7a8863
d2568a6
b28bcdc
 
9edc447
b28bcdc
9edc447
 
 
 
 
 
b28bcdc
9edc447
b28bcdc
 
 
d2568a6
4c18d69
 
 
 
 
 
 
 
 
 
 
 
18634d6
b28bcdc
 
 
dbbbac5
27b4c27
 
 
 
83e90d7
9822204
3577a57
9822204
 
 
 
3577a57
 
d914cbe
3577a57
9822204
 
3577a57
4c18d69
 
b23060e
4aae01c
 
 
 
 
 
 
 
 
b23060e
4aae01c
b23060e
4aae01c
 
 
0e51129
83e90d7
1f1627a
 
 
 
 
 
5c91758
4b2b2d4
0e51129
 
1f1627a
5f6c4ef
94e1f37
 
 
 
 
 
5fb8341
b28bcdc
4b2b2d4
5c91758
53c7e98
e72b522
b4303dc
1f1627a
b4303dc
 
 
 
 
 
 
 
83e90d7
3577a57
 
4b2b2d4
7837592
 
ffa1ad8
 
7837592
4b2b2d4
aafb64e
 
 
 
4b2b2d4
aafb64e
4b2b2d4
5c91758
b23060e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import html

import streamlit as st
import spacy
import torch
import torch.nn as nn
import pandas as pd
from PIL import Image
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel, AutoConfig
from transformers.models.bert.modeling_bert import BertForMaskedLM

from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
from models.spabert.datasets.osm_sample_loader import PbfMapDataset

# All inference in this demo runs on the CPU.
device = torch.device('cpu')
# When True, the functions below emit extra diagnostics through st.write.
dev_mode = False

# NOTE(review): Streamlit re-executes the whole script on every widget
# interaction, so every model/dataset load below repeats each time;
# consider st.cache_resource for these.

#Spacy Initialization Section
# Local spaCy pipeline; its named-entity recognizer (doc.ents) is what finds
# candidate geo-entities in a review.
nlp = spacy.load("./models/en_core_web_sm")


#BERT Initialization Section
# Plain uncased BERT: embeds the review text and also serves as the tokenizer
# for the SpaBERT dataset below.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model.to(device)
bert_model.eval()


#SpaBERT Initialization Section
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json'    #Sample file otherwise this model will take too long on CPU.
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

config = SpatialBertConfig()
# Required: the SpaBERT embedding is read from outputs.hidden_states[-1].
config.output_hidden_states = True
spaBERT_model = SpatialBertForMaskedLM(config)

# torch.load unpickles the checkpoint -- only load files from a trusted source.
pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
# Both loads use strict=False, so any key that does not match the SpaBERT model
# is silently skipped. The fine-tuned checkpoint is loaded second and therefore
# overrides any parameter that the BERT weights also supplied.
# NOTE(review): BertModel's state-dict keys may lack the prefix that
# SpatialBertForMaskedLM uses, which would make the first load a no-op; the
# imported (unused) load_spatial_bert_pretrained_weights may be the intended
# helper -- confirm.
spaBERT_model.load_state_dict(bert_model.state_dict(), strict = False)
spaBERT_model.load_state_dict(pre_trained_model, strict=False)

spaBERT_model.to(device)
spaBERT_model.eval()


#Load data using SpatialDataset
spatialDataset = PbfMapDataset(data_file_path = data_file_path,
                                        tokenizer = bert_tokenizer,
                                        max_token_len = 256,                          #Originally 300
                                        #max_token_len = max_seq_length,              #Originally 300
                                        distance_norm_factor = 0.0001,
                                        spatial_dist_fill = 20,
                                        with_type = False,
                                        sep_between_neighbors = True,    
                                        label_encoder = None,             
                                        mode = None)                                  #If set to None it will use the full dataset for mlm

# batch_size=1 with shuffle=False: the i-th batch is the i-th dataset sample, so
# lists built by iterating this loader stay aligned with entity_index_dict.
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=False, drop_last=False) 

# Create a dictionary to map entity names to indices.
# If a pivot name occurs more than once, the last occurrence's index wins.
# NOTE(review): iterating the dataset presumably builds every full sample
# (tokenization etc.) just to read 'pivot_name' -- confirm the cost.
entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}

# Ensure names are stored in lowercase for case-insensitive matching
# (names that collide after lowercasing also keep the last index).
entity_index_dict = {name.lower(): index for name, index in entity_index_dict.items()}


# Embed one dataset sample with SpaBERT (used to pre-compute embeddings for every geo-entity).
def process_entity(batch, model, device):
    """Run one DataLoader batch through ``model`` and pool the first position.

    Args:
        batch: One batch from a DataLoader over ``PbfMapDataset``. The keys
            read here are 'masked_input', 'attention_mask', 'norm_lng_list',
            'norm_lat_list', 'sent_position_ids' and 'pseudo_sentence'.
        model: SpaBERT model to run. It must return ``hidden_states``, i.e. be
            built from a config with ``output_hidden_states = True``.
        device: Torch device the batch tensors are moved to.

    Returns:
        ``(embedding, input_ids, pseudo_sentence_decoded)`` where ``embedding``
        is the last-layer hidden state at position 0 ([CLS]) with shape
        ``(batch_size, hidden_size)``, ``input_ids`` is ``batch['masked_input']``
        on ``device``, and ``pseudo_sentence_decoded`` is the first sample's
        pseudo-sentence decoded to text with special tokens kept.
    """
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    # Decode the first sample's token ids into a human-readable pseudo-sentence.
    pseudo_sentence_decoded = bert_tokenizer.decode(pseudo_sentence[0].tolist(), skip_special_tokens=False)

    with torch.no_grad():
        # The pseudo-sentence ids (not batch['masked_input']) are fed as
        # input_ids. NOTE(review): presumably these are the unmasked tokens --
        # confirm against PbfMapDataset.
        outputs = model(
            input_ids=pseudo_sentence,
            attention_mask=attention_mask,
            sent_position_ids=sent_position_ids,
            position_list_x=position_list_x,
            position_list_y=position_list_y,
        )

    # Last layer, first token ([CLS]) -> shape (batch_size, hidden_size).
    spaBERT_embedding = outputs.hidden_states[-1][:, 0, :]
    return spaBERT_embedding, input_ids, pseudo_sentence_decoded

# Pre-compute, once at script start, a SpaBERT embedding and a decoded
# pseudo-sentence for every dataset sample. Because data_loader uses
# batch_size=1 and shuffle=False, position i in both lists matches the
# indices stored in entity_index_dict.
spaBERT_embeddings = []
pseudo_sentences = []
for batch in (data_loader):
    spaBERT_embedding, input_ids, pseudo_sentence = process_entity(batch, spaBERT_model, device)
    spaBERT_embeddings.append(spaBERT_embedding)
    pseudo_sentences.append(pseudo_sentence)
    
# NOTE(review): embedding_cache is never read or written elsewhere in this file.
embedding_cache = {}

# Plain-BERT sentence embedding (the [CLS] vector) for arbitrary text.
def get_bert_embedding(review_text):
    """Return the [CLS] embedding of ``review_text`` from the module-level BERT model.

    The text is tokenized with ``bert_tokenizer`` (padding and truncation
    enabled), run through ``bert_model`` without gradient tracking, and the
    last-layer hidden state at position 0 is returned with shape
    ``(batch, hidden_size)``.
    """
    encoded = bert_tokenizer(
        review_text,
        return_tensors='pt',
        padding=True,
        truncation=True,
    ).to(device)

    with torch.no_grad():
        last_hidden = bert_model(**encoded).last_hidden_state

    cls_vector = last_hidden[:, 0, :]
    return cls_vector.detach()


# Spatial embedding lookup for a single geo-entity name.
def get_spaBert_embedding(entity, current_pseudo_sentences):
    """Return the pre-computed SpaBERT embedding for ``entity``.

    The name is matched case-insensitively against ``entity_index_dict``. On a
    hit, the entity's pseudo-sentence is appended to
    ``current_pseudo_sentences`` (mutated in place) and the cached SpaBERT
    embedding is returned. On a miss, nothing is appended and the plain BERT
    embedding of the entity text is returned instead.
    """
    entity_index = entity_index_dict.get(entity.lower())

    if entity_index is not None:
        current_pseudo_sentences.append(pseudo_sentences[entity_index])
        if dev_mode:
            st.write("Got SpaBert embedding for: ", entity)
        return spaBERT_embeddings[entity_index]

    # Fallback: SpaBERT has no entry for this name, so embed it with plain BERT.
    if dev_mode:
        st.write("Got Bert embedding for: ", entity)
    return get_bert_embedding(entity)

        
# Go through a review, identify all geo-entities, and average their SpaBERT embeddings.
def processSpatialEntities(review, nlp):
    """Return one spatial embedding for ``review`` and the pseudo-sentences used.

    spaCy entities labelled FAC, ORG, LOC or GPE are treated as geo-entities.
    Each one is embedded with ``get_spaBert_embedding``, which falls back to
    BERT for names not in ``entity_index_dict``. The resulting vectors are
    then averaged.

    Args:
        review: Raw review text.
        nlp: Loaded spaCy pipeline used for named-entity recognition.

    Returns:
        ``(processed_embedding, current_pseudo_sentences)``. Per-entity
        embeddings are ``(1, hidden_size)`` here (the dataset loader uses
        batch_size=1 and BERT is fed one string), so the mean is also
        ``(1, hidden_size)``. If the review contains no geo-entity, a zero
        tensor of shape ``(1, bert_model.config.hidden_size)`` is returned
        instead of crashing, and the list is empty.
    """
    geo_labels = ('FAC', 'ORG', 'LOC', 'GPE')
    doc = nlp(review)
    token_embeddings = []
    current_pseudo_sentences = []

    for ent in doc.ents:
        if ent.label_ not in geo_labels:
            continue
        text = ent.text
        if dev_mode:
            st.write("Text found:", text)
        token_embeddings.append(get_spaBert_embedding(text, current_pseudo_sentences))
        if dev_mode:
            st.write("Geo-Entity Found in review: ", text)

    if not token_embeddings:
        # torch.stack([]) raises. A zero vector keeps the later concatenation
        # with the BERT embedding at the discriminator's expected width.
        # NOTE(review): the discriminator may not have been trained on
        # all-zero spatial features; its prediction here is less reliable.
        processed_embedding = torch.zeros(1, bert_model.config.hidden_size, device=device)
        return processed_embedding, current_pseudo_sentences

    # (n_entities, 1, hidden) -> mean over entities -> (1, hidden)
    processed_embedding = torch.stack(token_embeddings, dim=0).mean(dim=0)
    return processed_embedding, current_pseudo_sentences


#Initialize discriminator module
class Discriminator(nn.Module):
    """GAN-BERT style discriminator head.

    Layout: input dropout -> for each hidden size (Linear, LeakyReLU(0.2),
    Dropout) -> Linear producing ``num_labels + 1`` logits (the extra logit
    scores the sample as real vs. fake) -> softmax.

    Attribute names (``input_dropout``, ``layers``, ``logit``, ``softmax``) and
    layer order determine the state-dict keys, so they must not change or
    saved checkpoints will no longer load.
    """

    def __init__(self, input_size=512, hidden_sizes=(512,), num_labels=2, dropout_rate=0.1):
        """Build the head.

        Args:
            input_size: Width of the input representation.
            hidden_sizes: Iterable of hidden-layer widths (a list or tuple).
                The default is an immutable tuple to avoid the shared
                mutable-default pitfall.
            num_labels: Number of real class labels; one extra logit is added.
            dropout_rate: Dropout probability used on the input and after
                every hidden layer.
        """
        super().__init__()
        self.input_dropout = nn.Dropout(p=dropout_rate)
        sizes = [input_size, *hidden_sizes]
        layers = []
        for in_features, out_features in zip(sizes, sizes[1:]):
            layers.extend([
                nn.Linear(in_features, out_features),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(dropout_rate),
            ])

        self.layers = nn.Sequential(*layers)
        self.logit = nn.Linear(sizes[-1], num_labels + 1)  # +1 for the probability of this sample being fake/real.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        """Return ``(last_rep, logits, probs)`` for ``input_rep``.

        ``last_rep`` is the output of the hidden stack, ``logits`` the raw
        ``num_labels + 1`` scores, and ``probs`` their softmax over the last
        dimension.
        """
        input_rep = self.input_dropout(input_rep)
        last_rep = self.layers(input_rep)
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs
        
# Build the discriminator with the same shape as the saved checkpoint, then load it.
dConfig = AutoConfig.from_pretrained("bert-base-uncased")
hidden_size = int(dConfig.hidden_size)
num_hidden_layers_d = 2
hidden_levels_d = [hidden_size] * num_hidden_layers_d
# The order of label_list fixes which logit index corresponds to which label.
# NOTE(review): the demo treats index 0 as "Not Spam", but index 0 is the label
# "1" -- confirm that "1" means genuine in the training data.
label_list = ["1", "0", 'UNL']
out_dropout_rate = 0.5

# Input width is 2 * hidden_size: the review's BERT embedding concatenated with
# the averaged spatial embedding.
discriminator = Discriminator(input_size=hidden_size*2, hidden_sizes=hidden_levels_d, num_labels=len(label_list), dropout_rate=out_dropout_rate).to(device)
discriminator_weights = 'data/datasets/discriminator_weights.pth'
# torch.load unpickles the file -- only load trusted checkpoints.
discriminator.load_state_dict(torch.load(discriminator_weights, map_location=device))
discriminator.eval()

def get_prediction(embeddings, model=None):
    """Classify ``embeddings`` and return the predicted label index per row.

    Args:
        embeddings: Input batch for the discriminator (rows are samples).
        model: Callable returning ``(last_rep, logits, probs)`` like
            ``Discriminator``. Defaults to the module-level ``discriminator``.

    Returns:
        NumPy array with one index per row: the argmax over every logit except
        the last one, which is the extra real/fake logit. Indices follow
        ``label_list`` order, so the 'UNL' index can also be returned.
    """
    if model is None:
        model = discriminator

    with torch.no_grad():
        _, logits, _ = model(embeddings)

        # Drop only the final (real/fake) logit; all label_list classes remain.
        class_logits = logits[:, :-1]
        _, predicted_labels = torch.max(class_logits, dim=-1)

    return predicted_labels.cpu().numpy()


# Read the sample reviews shown in the demo dropdown.
def load_reviews_from_file(file_path):
    """Return ``{"Review <n>": text}`` for every non-blank line of ``file_path``.

    ``n`` is the 1-based physical line number, so blank lines leave gaps in
    the numbering. Each line is stripped of surrounding whitespace. A missing
    file is reported with ``st.error`` and yields an empty dict.
    """
    collected = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            for line_number, raw_line in enumerate(handle, start=1):
                text = raw_line.strip()
                if not text:
                    continue
                collected[f"Review {line_number}"] = text
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
    return collected



#Demo Section
st.title("SpaGAN Demo")
st.write("This demo lets you explore a curated list of sample reviews, containing a real and fake example.")
st.write("Upon selecting a review, any identified geo-entities will be color coded for easy visualization.")
st.write("For each geo-entity found, the model will generate a contextual pseudo-sentnece, highlighting its closest neighbors from our dataset.")
st.write("Finally, the entire review is embedded and enriched with spatial embeddings, enabling the model to determine whether the review is authentic or fake.")

# Define a color map and descriptions for different entity types
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)')
}

# Display the color key
st.write("**Color Key:**")
for label, (color, description) in COLOR_MAP.items():
    st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}", unsafe_allow_html=True)


def _highlight_entities(text, doc, color_map):
    """Return ``text`` as HTML with each entity whose label is in ``color_map`` colored.

    ``doc`` must be the spaCy Doc produced from exactly ``text`` so the entity
    character offsets line up. All review text is HTML-escaped because the
    result is rendered with ``unsafe_allow_html=True`` and may be user input.
    """
    pieces = []
    cursor = 0
    for ent in doc.ents:  # spaCy entities are non-overlapping and in text order
        if ent.label_ not in color_map:
            continue
        ent_color = color_map[ent.label_][0]
        pieces.append(html.escape(text[cursor:ent.start_char]))
        pieces.append(f"<span style='color:{ent_color}; font-weight:bold'>{html.escape(ent.text)}</span>")
        cursor = ent.end_char
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


review_file_path = "models/spabert/datasets/SampleReviews.txt"
example_reviews = load_reviews_from_file(review_file_path)

# Define labels
review_labels = {
    "Review 1": "Real",
    "Review 2": "Spam",
}

# Create options with labels for the dropdown
dropdown_options = [f"{key} ({review_labels.get(key, 'Unknown')})" for key in example_reviews.keys()]

# Dropdown for selecting an example review
user_selection = st.selectbox("Select an example review", options=dropdown_options)

# selectbox yields None when there are no options (e.g. the sample file is missing).
if user_selection is None:
    selected_key = None
    selected_review = ""
else:
    selected_key = user_selection.split(" (")[0]  # Remove the label part
    selected_review = example_reviews[selected_key]

#Optional textbox for interactivity
user_input_review = st.text_area("Or type your own review here","")

# Typed text takes precedence over the dropdown selection.
review_to_process = user_input_review if user_input_review.strip() else selected_review
lower_case_review = review_to_process.lower()

# Process the text when the button is clicked
if st.button("Process Review"):
    if lower_case_review.strip():
        bert_embedding = get_bert_embedding(lower_case_review)
        spaBert_embedding, current_pseudo_sentences = processSpatialEntities(review_to_process, nlp)
        combined_embedding = torch.cat((bert_embedding, spaBert_embedding), dim=-1)

        if dev_mode:
            st.write("Review Embedding Shape:", bert_embedding.shape)
            st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
            st.write("Concatenated Embedding Shape:", combined_embedding.shape)
            st.write("Concatenated Embedding:", combined_embedding)

        prediction = get_prediction(combined_embedding)

        # Run NER on the text actually being processed so the entity offsets
        # match the string being highlighted.
        doc = nlp(review_to_process)
        highlighted_text = _highlight_entities(review_to_process, doc, COLOR_MAP)

        # Display the highlighted text with HTML support
        st.markdown(highlighted_text, unsafe_allow_html=True)

        #Display pseudo sentences found
        for sentence in current_pseudo_sentences:
            clean_sentence = sentence.replace("[PAD]", "").strip()
            st.write("Pseudo-Sentence:", clean_sentence)

        # One review -> one row -> one predicted index.
        # NOTE(review): index 0 is label_list[0] == "1"; confirm that "1" means
        # genuine before trusting the Not Spam / Spam mapping below.
        predicted_index = int(prediction[0])

        #Display the models prediction
        if predicted_index == 0:
            st.markdown("<h3 style='color:green;'>✅ Prediction: Not Spam</h3>", unsafe_allow_html=True)
        elif predicted_index == 1:
            st.markdown("<h3 style='color:red;'>❌ Prediction: Spam</h3>", unsafe_allow_html=True)
        else:
            st.markdown("<h3 style='color:orange;'>⚠️ Error during prediction</h3>", unsafe_allow_html=True)

    else:
        st.error("Please select a review.")