"""SpaGAN Streamlit demo: classifies reviews as authentic or fake using a BERT
review embedding enriched with SpaBERT embeddings of the geo-entities it mentions."""

from PIL import Image
import pandas as pd
import spacy
import streamlit as st
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoConfig, BertModel, BertTokenizer
from transformers.models.bert.modeling_bert import BertForMaskedLM

from models.spabert.datasets.osm_sample_loader import PbfMapDataset
from models.spabert.models.spatial_bert_model import (
    SpatialBertConfig,
    SpatialBertForMaskedLM,
    SpatialBertModel,
)
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights

# All models and tensors in this demo live on the CPU.
device = torch.device('cpu')
# Flip to True to echo intermediate values (entities found, embedding shapes) into the page.
dev_mode = False

# --- spaCy: named-entity recogniser used to spot geo-entities in a review ---
nlp = spacy.load("./models/en_core_web_sm")

# --- BERT: encoder for the review text itself (weights are also seeded into SpaBERT) ---
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()

# --- SpaBERT: geo-entity pivot data ---
# A small sample file is used on purpose; the full dataset would take too long to embed on CPU.
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json'
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the model loading and embedding pre-computation below are
# repeated on each rerun; consider moving them behind st.cache_resource.

# ---------------------------------------------------------------------------
# SpaBERT initialisation
# ---------------------------------------------------------------------------
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

config = SpatialBertConfig()
config.output_hidden_states = True  # process_entity reads outputs.hidden_states
spaBERT_model = SpatialBertForMaskedLM(config)

# torch.load unpickles the checkpoint: only point this at trusted files.
pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
# NOTE(review): strict=False silently skips every key that does not match.
# BertModel's keys may not carry the prefix SpatialBertForMaskedLM expects;
# verify that this first load actually copies any weights.
spaBERT_model.load_state_dict(bert_model.state_dict(), strict=False)
spaBERT_model.load_state_dict(pre_trained_model, strict=False)
spaBERT_model.to(device)
spaBERT_model.eval()

# Load the geo-entity pivots as a SpaBERT dataset.
spatialDataset = PbfMapDataset(
    data_file_path=data_file_path,
    tokenizer=bert_tokenizer,
    max_token_len=256,  # Originally 300
    distance_norm_factor=0.0001,
    spatial_dist_fill=20,
    with_type=False,
    sep_between_neighbors=True,
    label_encoder=None,
    mode=None,  # If set to None it will use the full dataset for mlm
)
# Unshuffled with batch_size=1, so the i-th batch is dataset item i.
data_loader = DataLoader(spatialDataset, batch_size=1, num_workers=0, shuffle=False,
                         pin_memory=False, drop_last=False)

# Map pivot names to dataset indices, lower-cased for case-insensitive lookup.
entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}
entity_index_dict = {name.lower(): index for name, index in entity_index_dict.items()}


def process_entity(batch, model, device):
    """Run one SpaBERT batch through *model* and return its [CLS] embedding.

    Args:
        batch: a collated item from ``data_loader`` (keys as read below).
        model: the SpaBERT model to run.
        device: device the batch tensors are moved to.

    Returns:
        ``(embedding, input_ids, pseudo_sentence_decoded)``. ``embedding`` is the
        last-layer hidden state at position 0, shape (batch_size, hidden_size).
        ``input_ids`` is the batch's ``masked_input``. ``pseudo_sentence_decoded``
        is the first pseudo-sentence decoded to text, with special tokens kept.
    """
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    pseudo_sentence_decoded = bert_tokenizer.decode(pseudo_sentence[0].tolist(),
                                                    skip_special_tokens=False)

    with torch.no_grad():
        # The pseudo-sentence tokens (not batch['masked_input']) are fed as
        # input_ids; the masked input is only returned to the caller.
        outputs = model(
            input_ids=pseudo_sentence,
            attention_mask=attention_mask,
            sent_position_ids=sent_position_ids,
            position_list_x=position_list_x,
            position_list_y=position_list_y,
        )

    # [CLS] token (position 0) of the last layer: (batch_size, hidden_size).
    spaBERT_embedding = outputs.hidden_states[-1][:, 0, :].detach()
    return spaBERT_embedding, input_ids, pseudo_sentence_decoded


# Pre-acquire the SpaBERT embedding and pseudo-sentence of every geo-entity in
# the dataset; list position i corresponds to entity_index_dict value i.
spaBERT_embeddings = []
pseudo_sentences = []
for batch in data_loader:
    spaBERT_embedding, input_ids, pseudo_sentence = process_entity(batch, spaBERT_model, device)
    spaBERT_embeddings.append(spaBERT_embedding)
    pseudo_sentences.append(pseudo_sentence)

embedding_cache = {}  # NOTE(review): appears unused in this script


def get_bert_embedding(review_text):
    """Return BERT's last-layer [CLS] embedding of *review_text*.

    For a single string the result has shape (1, hidden_size).
    """
    inputs = bert_tokenizer(review_text, return_tensors='pt', padding=True,
                            truncation=True).to(device)
    with torch.no_grad():
        outputs = bert_model(**inputs)
    return outputs.last_hidden_state[:, 0, :].detach()


def get_spaBert_embedding(entity, current_pseudo_sentences):
    """Return the pre-computed SpaBERT embedding for *entity* (case-insensitive).

    On a dataset hit the entity's pseudo-sentence is appended to
    *current_pseudo_sentences*. Entities not in the dataset fall back to their
    plain BERT [CLS] embedding, and nothing is appended (rare case).
    """
    entity_index = entity_index_dict.get(entity.lower())
    if entity_index is None:
        if dev_mode:
            st.write("Got Bert embedding for: ", entity)
        return get_bert_embedding(entity)

    current_pseudo_sentences.append(pseudo_sentences[entity_index])
    if dev_mode:
        st.write("Got SpaBert embedding for: ", entity)
    return spaBERT_embeddings[entity_index]


# spaCy entity labels treated as geo-entities.
GEO_ENTITY_LABELS = ('FAC', 'ORG', 'LOC', 'GPE')


class NoGeoEntitiesError(ValueError):
    """Raised when a review contains no entity with a GEO_ENTITY_LABELS label."""


def processSpatialEntities(review, nlp):
    """Average the SpaBERT embeddings of every geo-entity found in *review*.

    Returns:
        ``(processed_embedding, current_pseudo_sentences)``. The embedding is the
        mean of the per-entity embeddings, shape (1, hidden_size). The list holds
        the pseudo-sentences of the entities resolved in the SpaBERT dataset.

    Raises:
        NoGeoEntitiesError: if spaCy finds no geo-entity in the review
            (previously this surfaced as an opaque torch.stack error).
    """
    doc = nlp(review)
    token_embeddings = []
    current_pseudo_sentences = []

    for ent in doc.ents:
        if ent.label_ not in GEO_ENTITY_LABELS:
            continue
        if dev_mode:
            st.write("Text found:", ent.text)
        token_embeddings.append(get_spaBert_embedding(ent.text, current_pseudo_sentences))
        if dev_mode:
            st.write("Geo-Entity Found in review: ", ent.text)

    if not token_embeddings:
        raise NoGeoEntitiesError("No geo-entities found in review.")

    # Each entry is (1, hidden_size); stack -> (k, 1, hidden_size), mean -> (1, hidden_size).
    processed_embedding = torch.stack(token_embeddings, dim=0).mean(dim=0)
    return processed_embedding, current_pseudo_sentences


class Discriminator(nn.Module):
    """MLP classification head: dropout, then hidden Linear/LeakyReLU/Dropout
    layers, then a linear layer with ``num_labels + 1`` outputs.

    Attribute names and layer order are kept stable so saved state_dicts load.
    """

    def __init__(self, input_size=512, hidden_sizes=None, num_labels=2, dropout_rate=0.1):
        super().__init__()
        if hidden_sizes is None:  # avoid a shared mutable default
            hidden_sizes = [512]
        self.input_dropout = nn.Dropout(p=dropout_rate)

        layer_sizes = [input_size] + list(hidden_sizes)
        layers = []
        for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
            layers.extend([nn.Linear(in_size, out_size),
                           nn.LeakyReLU(0.2, inplace=True),
                           nn.Dropout(dropout_rate)])
        self.layers = nn.Sequential(*layers)
        # +1 for the probability of this sample being fake/real.
        self.logit = nn.Linear(layer_sizes[-1], num_labels + 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        """Return ``(last hidden representation, logits, softmax probabilities)``."""
        input_rep = self.input_dropout(input_rep)
        last_rep = self.layers(input_rep)
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs


dConfig = AutoConfig.from_pretrained("bert-base-uncased")
hidden_size = int(dConfig.hidden_size)
num_hidden_layers_d = 2
hidden_levels_d = [hidden_size for _ in range(num_hidden_layers_d)]
label_list = ["1", "0", "UNL"]
out_dropout_rate = 0.5
# Input is the BERT review embedding concatenated with the averaged geo-entity
# embedding, hence hidden_size * 2.
discriminator = Discriminator(input_size=hidden_size * 2, hidden_sizes=hidden_levels_d,
                              num_labels=len(label_list),
                              dropout_rate=out_dropout_rate).to(device)
discriminator_weights = 'data/datasets/discriminator_weights.pth'
discriminator.load_state_dict(torch.load(discriminator_weights, map_location=torch.device('cpu')))
discriminator.eval()


def get_prediction(embeddings):
    """Return the predicted class index for each row of *embeddings* as a numpy array.

    The discriminator's last logit (the extra fake/real output) is dropped
    before the argmax, so indices range over ``label_list``.
    """
    with torch.no_grad():
        _, logits, _ = discriminator(embeddings)
    class_logits = logits[:, :-1]
    _, predicted_labels = torch.max(class_logits, dim=-1)
    return predicted_labels.cpu().numpy()


def load_reviews_from_file(file_path):
    """Read one review per non-empty line of *file_path*.

    Returns ``{"Review <n>": text}`` where n is the 1-based line number (blank
    lines are counted, so numbers may skip). A missing file is reported via
    ``st.error`` and yields an empty dict.
    """
    reviews = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for line_number, line in enumerate(file, start=1):
                line = line.strip()
                if line:  # Ensure the line is not empty
                    reviews[f"Review {line_number}"] = line
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
    return reviews


def _highlight_geo_entities(text, doc):
    """Return *text* as HTML with each COLOR_MAP entity in *doc* wrapped in a coloured span.

    *doc* must be ``nlp(text)`` so entity character offsets index into *text*.
    All text is HTML-escaped because the result is rendered with
    ``unsafe_allow_html=True`` and the review may be typed by the user.
    """
    import html

    pieces = []
    cursor = 0
    for ent in doc.ents:  # spaCy entities are non-overlapping and in document order
        if ent.label_ not in COLOR_MAP:
            continue
        color = COLOR_MAP[ent.label_][0]
        pieces.append(html.escape(text[cursor:ent.start_char]))
        pieces.append(f"<span style='color:{color}; font-weight:bold'>{html.escape(ent.text)}</span>")
        cursor = ent.end_char
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


# ---------------------------------------------------------------------------
# Demo section
# ---------------------------------------------------------------------------
st.title("SpaGAN Demo")
st.write("This demo lets you explore a curated list of sample reviews, containing a real and fake example.")
st.write("Upon selecting a review, any identified geo-entities will be color coded for easy visualization.")
st.write("For each geo-entity found, the model will generate a contextual pseudo-sentence, highlighting its closest neighbors from our dataset.")
st.write("Finally, the entire review is embedded and enriched with spatial embeddings, enabling the model to determine whether the review is authentic or fake.")

# Colour and description for each highlighted entity type.
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)'),
}

st.write("**Color Key:**")
for label, (color, description) in COLOR_MAP.items():
    st.markdown(f"- **{label}**: {color} - {description}", unsafe_allow_html=True)

review_file_path = "models/spabert/datasets/SampleReviews.txt"
example_reviews = load_reviews_from_file(review_file_path)

review_labels = {
    "Review 1": "Real",
    "Review 2": "Spam",
}

# Dropdown entries look like "Review 1 (Real)".
dropdown_options = [f"{key} ({review_labels.get(key, 'Unknown')})" for key in example_reviews.keys()]
user_selection = st.selectbox("Select an example review", options=dropdown_options)

# selectbox returns None when there are no options (e.g. the review file is missing).
if user_selection is None:
    selected_key = None
    selected_review = ""
else:
    selected_key = user_selection.split(" (")[0]  # Remove the label part
    selected_review = example_reviews[selected_key]

# Optional textbox for interactivity; typed text takes precedence over the dropdown.
user_input_review = st.text_area("Or type your own review here", "")
review_to_process = user_input_review if user_input_review.strip() else selected_review
lower_case_review = review_to_process.lower()

if st.button("Process Review"):
    if lower_case_review.strip():
        bert_embedding = get_bert_embedding(lower_case_review)
        try:
            spaBert_embedding, current_pseudo_sentences = processSpatialEntities(review_to_process, nlp)
        except NoGeoEntitiesError:
            spaBert_embedding, current_pseudo_sentences = None, []

        prediction = None
        if spaBert_embedding is not None:
            combined_embedding = torch.cat((bert_embedding, spaBert_embedding), dim=-1)
            if dev_mode:
                st.write("Review Embedding Shape:", bert_embedding.shape)
                st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
                st.write("Concatenated Embedding Shape:", combined_embedding.shape)
                st.write("Concatenated Embedding:", combined_embedding)
            prediction = int(get_prediction(combined_embedding)[0])

        # Entity offsets must come from the same text that is being highlighted.
        doc = nlp(review_to_process)
        highlighted_text = _highlight_geo_entities(review_to_process, doc)
        st.markdown(highlighted_text, unsafe_allow_html=True)

        # Display pseudo sentences found
        for sentence in current_pseudo_sentences:
            clean_sentence = sentence.replace("[PAD]", "").strip()
            st.write("Pseudo-Sentence:", clean_sentence)

        # Display the model's prediction
        if prediction is None:
            st.warning("No geo-entities (FAC, ORG, LOC, GPE) were found in this review, "
                       "so the model cannot make a prediction.")
        elif prediction == 0:
            st.markdown("✅ Prediction: Not Spam", unsafe_allow_html=True)
        elif prediction == 1:
            st.markdown("❌ Prediction: Spam", unsafe_allow_html=True)
        else:
            st.markdown("⚠️ Error during prediction", unsafe_allow_html=True)
    else:
        st.error("Please select a review.")