"""Streamlit demo for SpaGAN: label a sample review as spam / not spam.

The script embeds the selected review with BERT, embeds the geo-entities
spaCy finds in it with SpaBERT (falling back to BERT for entities that are
not in the pivot dataset), concatenates both embeddings and feeds the result
to a pre-trained discriminator.
"""

import html

import spacy
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel, AutoConfig
from transformers.models.bert.modeling_bert import BertForMaskedLM

from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights
from models.spabert.datasets.osm_sample_loader import PbfMapDataset

device = torch.device('cpu')
dev_mode = False

# NOTE(review): everything below runs at module level; Streamlit presumably
# re-executes the script on each interaction, so the models may be reloaded
# every time. Consider caching the loaders if the installed Streamlit version
# supports it.

#Spacy Initialization Section
nlp = spacy.load("./models/en_core_web_sm")

#BERT Initialization Section
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model.to(device)
bert_model.eval()

#SpaBERT Initialization Section
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json'  #Sample file otherwise this model will take too long on CPU.
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

config = SpatialBertConfig()
config.output_hidden_states = True
spaBERT_model = SpatialBertForMaskedLM(config)

# torch.load unpickles the file: only point this at trusted local checkpoints.
pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
# Seed with the plain BERT weights first, then overlay the fine-tuned SpaBERT
# checkpoint; strict=False tolerates keys that exist in only one of the two.
spaBERT_model.load_state_dict(bert_model.state_dict(), strict=False)
spaBERT_model.load_state_dict(pre_trained_model, strict=False)

spaBERT_model.to(device)
spaBERT_model.eval()

#Load data using SpatialDataset
spatialDataset = PbfMapDataset(
    data_file_path=data_file_path,
    tokenizer=bert_tokenizer,
    max_token_len=256,  #Originally 300
    #max_token_len = max_seq_length, #Originally 300
    distance_norm_factor=0.0001,
    spatial_dist_fill=20,
    with_type=False,
    sep_between_neighbors=True,
    label_encoder=None,
    mode=None,  #If set to None it will use the full dataset for mlm
)

data_loader = DataLoader(
    spatialDataset,
    batch_size=1,
    num_workers=0,
    shuffle=False,
    pin_memory=False,
    drop_last=False,
)

# Create a dictionary to map entity names to indices
entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}

# Ensure names are stored in lowercase for case-insensitive matching
entity_index_dict = {name.lower(): index for name, index in entity_index_dict.items()}


#Pre-acquire the SpaBERT embeddings for all geo-entities within our dataset
def process_entity(batch, model, device):
    """Run one data-loader batch through ``model`` and return its [CLS] embedding.

    Args:
        batch: Mapping produced by ``data_loader``; the keys read here are
            ``masked_input``, ``attention_mask``, ``norm_lng_list``,
            ``norm_lat_list``, ``sent_position_ids`` and ``pseudo_sentence``.
        model: SpaBERT model to evaluate (called with
            ``output_hidden_states`` enabled on its config).
        device: Torch device the batch tensors are moved to.

    Returns:
        A tuple ``(embedding, input_ids, pseudo_sentence_decoded)`` where
        ``embedding`` is the first-token slice of the last hidden state,
        ``input_ids`` is ``batch['masked_input']`` on ``device`` and
        ``pseudo_sentence_decoded`` is the decoded text of the first
        pseudo-sentence in the batch (special tokens kept).
    """
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    # Convert tensor to list of token IDs, and decode them into a readable sentence
    pseudo_sentence_decoded = bert_tokenizer.decode(pseudo_sentence[0].tolist(), skip_special_tokens=False)

    with torch.no_grad():
        # Use the model that was passed in rather than reaching for the global.
        outputs = model(
            #input_ids=input_ids,
            input_ids=pseudo_sentence,
            attention_mask=attention_mask,
            sent_position_ids=sent_position_ids,
            position_list_x=position_list_x,
            position_list_y=position_list_y,
        )
    #NOTE: batch['pseudo_sentence'] is fed as input_ids in place of
    #batch['masked_input']. Verify that this is correct
    spaBERT_embedding = outputs.hidden_states[-1].to(device)

    # Extract the [CLS] token embedding (first token)
    spaBERT_embedding = spaBERT_embedding[:, 0, :].detach()  # [batch_size, hidden_size]

    #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
    return spaBERT_embedding, input_ids, pseudo_sentence_decoded


# Index i of both lists corresponds to index i of spatialDataset
# (the loader uses batch_size=1 and shuffle=False).
spaBERT_embeddings = []
pseudo_sentences = []
for batch in data_loader:
    spaBERT_embedding, input_ids, pseudo_sentence = process_entity(batch, spaBERT_model, device)
    spaBERT_embeddings.append(spaBERT_embedding)
    pseudo_sentences.append(pseudo_sentence)

# Currently unused in this file; kept for compatibility.
embedding_cache = {}


#Get BERT Embedding for review
def get_bert_embedding(review_text):
    """Return the BERT [CLS] embedding of ``review_text``.

    The text is tokenized with padding and truncation enabled and run through
    ``bert_model`` without gradient tracking; the first-token slice of the
    last hidden state is returned.
    """
    #tokenize review
    inputs = bert_tokenizer(review_text, return_tensors='pt', padding=True, truncation=True).to(device)

    # Forward pass through the BERT model
    with torch.no_grad():
        outputs = bert_model(**inputs)

    # Extract embeddings from the last hidden state
    bert_embedding = outputs.last_hidden_state[:, 0, :].detach()  #CLS Token
    return bert_embedding


#Get SpaBERT Embedding for geo-entity
def get_spaBert_embedding(entity, current_pseudo_sentences):
    """Look up the pre-computed SpaBERT embedding for ``entity``.

    Args:
        entity: Entity text; matched case-insensitively against the pivot
            names of the dataset.
        current_pseudo_sentences: List that is appended to in place with the
            entity's decoded pseudo-sentence when the entity is found.

    Returns:
        The cached SpaBERT embedding when the entity is in the dataset,
        otherwise the BERT embedding of the entity text.
    """
    entity_index = entity_index_dict.get(entity.lower(), None)

    if entity_index is None:
        if dev_mode:
            st.write("Got Bert embedding for: ", entity)
        #Fallback in-case SpaBERT could not resolve entity to retrieve embedding. Rare-cases only.
        return get_bert_embedding(entity)

    current_pseudo_sentences.append(pseudo_sentences[entity_index])
    if dev_mode:
        st.write("Got SpaBert embedding for: ", entity)
    return spaBERT_embeddings[entity_index]


#Go through each review, identify all geo-entities, then extract their SpaBERT embeddings
def processSpatialEntities(review, nlp):
    """Average the embeddings of every geo-entity spaCy finds in ``review``.

    Only entities labelled FAC, ORG, LOC or GPE are considered.

    Args:
        review: Review text.
        nlp: spaCy pipeline used to extract named entities.

    Returns:
        A tuple ``(processed_embedding, current_pseudo_sentences)``: the mean
        of the per-entity embeddings, and the pseudo-sentences of the entities
        that were resolved in the SpaBERT dataset. When the review contains no
        geo-entity an all-zero embedding of shape ``(1, hidden_size)`` and an
        empty list are returned instead of failing on an empty ``torch.stack``.
    """
    doc = nlp(review)
    token_embeddings = []
    current_pseudo_sentences = []

    # Iterate over each entity and process only geo entities
    for ent in doc.ents:
        if ent.label_ not in ('FAC', 'ORG', 'LOC', 'GPE'):  # Filter to geo-entities
            continue
        spaBert_emb = get_spaBert_embedding(ent.text, current_pseudo_sentences)
        token_embeddings.append(spaBert_emb)
        if dev_mode:
            st.write("Geo-Entity Found in review: ", ent.text)

    if not token_embeddings:
        # NOTE(review): zeros are used as a neutral placeholder so the demo
        # does not crash; confirm this matches how entity-free reviews were
        # handled when the discriminator was trained.
        no_entity_embedding = torch.zeros((1, bert_model.config.hidden_size), device=device)
        return no_entity_embedding, current_pseudo_sentences

    token_embeddings = torch.stack(token_embeddings, dim=0)
    processed_embedding = token_embeddings.mean(dim=0)  # Shape: (768)
    #processed_embedding = processed_embedding.unsqueeze(0) # Shape: (1, 768)
    return processed_embedding, current_pseudo_sentences


#Initialize discriminator module
class Discriminator(nn.Module):
    """Feed-forward classifier head applied to the concatenated embeddings.

    Args:
        input_size: Width of the input representation.
        hidden_sizes: Widths of the hidden layers; defaults to ``[512]``.
        num_labels: Number of class labels; the output layer has one extra
            unit for the fake/real probability.
        dropout_rate: Dropout probability used on the input and after every
            hidden layer.
    """

    def __init__(self, input_size=512, hidden_sizes=None, num_labels=2, dropout_rate=0.1):
        super().__init__()
        self.input_dropout = nn.Dropout(p=dropout_rate)

        # A None sentinel avoids sharing one mutable default list across instances.
        if hidden_sizes is None:
            hidden_sizes = [512]
        layer_sizes = [input_size] + list(hidden_sizes)

        layers = []
        for in_features, out_features in zip(layer_sizes, layer_sizes[1:]):
            layers.extend([
                nn.Linear(in_features, out_features),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(dropout_rate),
            ])
        self.layers = nn.Sequential(*layers)  # for flattening
        self.logit = nn.Linear(layer_sizes[-1], num_labels + 1)  # +1 for the probability of this sample being fake/real.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        """Return ``(last_rep, logits, probs)`` for ``input_rep``.

        ``last_rep`` is the output of the hidden stack, ``logits`` the output
        of the final linear layer and ``probs`` its softmax over the last
        dimension.
        """
        input_rep = self.input_dropout(input_rep)
        last_rep = self.layers(input_rep)
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs


dConfig = AutoConfig.from_pretrained("bert-base-uncased")
hidden_size = int(dConfig.hidden_size)
num_hidden_layers_d = 2
hidden_levels_d = [hidden_size for i in range(0, num_hidden_layers_d)]
label_list = ["1", "0"]
label_list.append('UNL')
out_dropout_rate = 0.5

# The discriminator consumes the review and spatial embeddings concatenated,
# hence twice the BERT hidden size.
discriminator = Discriminator(
    input_size=hidden_size * 2,
    hidden_sizes=hidden_levels_d,
    num_labels=len(label_list),
    dropout_rate=out_dropout_rate,
).to(device)

discriminator_weights = ('data/datasets/discriminator_weights.pth')
# torch.load unpickles the file: only point this at trusted local checkpoints.
discriminator.load_state_dict(torch.load(discriminator_weights, map_location=torch.device('cpu')))
discriminator.eval()


def get_prediction(embeddings):
    """Return the predicted label index for each row of ``embeddings``.

    The last logit (the extra fake/real unit of the discriminator) is dropped
    and the argmax over the remaining logits is returned as a NumPy array.
    """
    with torch.no_grad():
        # Forward pass through the discriminator to get the logits and probabilities
        last_rep, logits, probs = discriminator(embeddings)

        # Ignore the last logit; every other logit takes part in the argmax
        filtered_logits = logits[:, 0:-1]

        # Get the predicted labels using the filtered logits
        _, predicted_labels = torch.max(filtered_logits, dim=-1)

    # Convert to numpy array
    predicted_labels = predicted_labels.cpu().numpy()
    return predicted_labels


# Function to read reviews from a text file
def load_reviews_from_file(file_path):
    """Read one review per line from ``file_path``.

    Returns:
        Dict mapping ``"Review <n>"`` to the stripped line text, where ``n``
        is the 1-based line number in the file; blank lines are skipped (their
        numbers are not reused). When the file does not exist an error is
        shown via ``st.error`` and an empty dict is returned.
    """
    reviews = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for i, line in enumerate(file):
                line = line.strip()
                if line:  # Ensure the line is not empty
                    reviews[f"Review {i + 1}"] = line
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
    return reviews


def _highlight_geo_entities(text, ents, color_map):
    """Return ``text`` as HTML with every entity in ``color_map`` coloured.

    ``ents`` must be the entities of the spaCy doc built from ``text``. All
    text segments are HTML-escaped so that only the generated ``<span>``
    elements are interpreted as markup.
    """
    pieces = []
    cursor = 0
    for ent in ents:
        if ent.label_ not in color_map:
            continue
        color = color_map[ent.label_][0]
        pieces.append(html.escape(text[cursor:ent.start_char], quote=False))
        pieces.append(f'<span style="color:{color}">{html.escape(ent.text, quote=False)}</span>')
        cursor = ent.end_char
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)


#Demo Section
st.title("SpaGAN Demo: IN PROGRESS")
st.write("This demo allows you to select from a list of sample reviews that contrain both real and fake reviews.")
st.write("Once selected, the identified geo-entities within the review will be color coded and displayed along with the review.")
st.write("The model will construct a pseudo-sentence for each entity within the review, contextualizing each geo-entities closest neighbors from our dataset.")
st.write("Finally, the entire review will be embedded and combined with the spatial embeddings and the model will determine whether the review is real or fake.")

# Define a color map and descriptions for different entity types
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)'),
}

# Display the color key
st.write("**Color Key:**")
for label, (color, description) in COLOR_MAP.items():
    st.markdown(
        f'- **{label}**: <span style="color:{color}">{color}</span> - {description}',
        unsafe_allow_html=True,
    )

review_file_path = "models/spabert/datasets/SampleReviews.txt"
example_reviews = load_reviews_from_file(review_file_path)

# Without any review the dropdown would return None and the code below would
# fail on it, so end this run here.
if not example_reviews:
    st.warning("No sample reviews are available.")
    st.stop()

# Define labels
review_labels = {
    "Review 1": "Real",
    "Review 2": "Spam",
}

# Create options with labels for the dropdown
dropdown_options = [f"{key} ({review_labels.get(key, 'Unknown')})" for key in example_reviews]

# Dropdown for selecting an example review
user_selection = st.selectbox("Select an example review", options=dropdown_options)

# Extract the original review key from the selected option
selected_key = user_selection.split(" (")[0]  # Remove the label part
selected_review = example_reviews[selected_key]
lower_case_review = selected_review.lower()

# Process the text when the button is clicked
if st.button("Process Review"):
    if lower_case_review.strip():
        bert_embedding = get_bert_embedding(lower_case_review)
        spaBert_embedding, current_pseudo_sentences = processSpatialEntities(selected_review, nlp)
        combined_embedding = torch.cat((bert_embedding, spaBert_embedding), dim=-1)

        if dev_mode:
            st.write("Review Embedding Shape:", bert_embedding.shape)
            st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
            st.write("Concatenated Embedding Shape:", combined_embedding.shape)
            st.write("Concatenated Embedding:", combined_embedding)

        prediction = get_prediction(combined_embedding)

        # Process the text using spaCy
        doc = nlp(selected_review)

        # Highlight geo-entities with different colors
        highlighted_text = _highlight_geo_entities(selected_review, doc.ents, COLOR_MAP)

        # Display the highlighted text with HTML support
        st.markdown(highlighted_text, unsafe_allow_html=True)

        #Display pseudo sentences found
        for sentence in current_pseudo_sentences:
            clean_sentence = sentence.replace("[PAD]", "").strip()
            st.write("Pseudo-Sentence:", clean_sentence)

        #Display the models prediction
        # get_prediction returns one label per row; a single review is scored here.
        # NOTE(review): index 0 is shown as "Not Spam" and index 1 as "Spam",
        # while label_list is ["1", "0", "UNL"] - confirm the index-to-label mapping.
        predicted_index = int(prediction[0])
        if predicted_index == 0:
            st.write("Prediction: Not Spam")
        elif predicted_index == 1:
            st.write("Prediction: Spam")
        else:
            st.write("error during prediction")
    else:
        st.error("Please select a review.")