Spaces:
Sleeping
Sleeping
import streamlit as st | |
import spacy | |
import torch | |
import torch.nn as nn | |
import pandas as pd | |
from transformers import BertTokenizer, BertModel, AutoConfig | |
from transformers.models.bert.modeling_bert import BertForMaskedLM | |
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel | |
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights | |
from models.spabert.datasets.osm_sample_loader import PbfMapDataset | |
from torch.utils.data import DataLoader | |
from PIL import Image | |
# Everything in this demo runs on the CPU.
device = torch.device('cpu')
# When True, intermediate results (entities found, embedding shapes, ...) are echoed to the page.
dev_mode = True

# spaCy pipeline (loaded from a local directory) used for named-entity recognition.
nlp = spacy.load("./models/en_core_web_sm")

# BERT encoder: embeds the review text and serves as the fallback for entities SpaBERT cannot resolve.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)
bert_model.eval()
# --- SpaBERT initialization ---
# A small pivot sample is used; the full dataset would take too long to embed on CPU.
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json'
pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'

config = SpatialBertConfig()
# Hidden states are required because embeddings are read from the last hidden layer.
config.output_hidden_states = True
spaBERT_model = SpatialBertForMaskedLM(config)

pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
# Seed the model with the bert-base-uncased weights, then overlay the fine-tuned SpaBERT checkpoint.
# NOTE(review): both loads use strict=False, so missing/unexpected keys are skipped silently;
# confirm the keys actually line up (load_state_dict returns them).
for _weights in (bert_model.state_dict(), pre_trained_model):
    spaBERT_model.load_state_dict(_weights, strict=False)
spaBERT_model.to(device)
spaBERT_model.eval()
# Load the pivot dataset that provides SpaBERT inputs (pseudo-sentences with coordinates).
spatialDataset = PbfMapDataset(
    data_file_path=data_file_path,
    tokenizer=bert_tokenizer,
    max_token_len=256,  # originally 300
    distance_norm_factor=0.0001,
    spatial_dist_fill=20,
    with_type=False,
    sep_between_neighbors=True,
    label_encoder=None,
    mode=None,  # None uses the full dataset for MLM
)

# One sample per batch, in dataset order (shuffle=False), so the i-th batch is spatialDataset[i].
data_loader = DataLoader(
    spatialDataset,
    batch_size=1,
    num_workers=0,
    shuffle=False,
    pin_memory=False,
    drop_last=False,
)

# Case-insensitive lookup from pivot name to dataset index. Built in two stages (exact names
# first, then lower-cased); when names differ only by case, the surviving index depends on the
# order in which the exact spellings were first seen.
entity_index_dict = {
    name.lower(): index
    for name, index in {
        sample['pivot_name']: position for position, sample in enumerate(spatialDataset)
    }.items()
}
# Helper used to pre-acquire the SpaBERT embeddings for all geo-entities within our dataset.
def process_entity(batch, model, device):
    """Run one DataLoader batch through a SpaBERT model and return its first-token embedding.

    Args:
        batch: A batch from ``data_loader``; must provide the tensors 'masked_input',
            'attention_mask', 'norm_lng_list', 'norm_lat_list', 'sent_position_ids'
            and 'pseudo_sentence'.
        model: The SpaBERT model to evaluate. Its config must have
            ``output_hidden_states=True`` because the last hidden layer is read.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        ``(embedding, input_ids, pseudo_sentence_decoded)``:
        ``embedding`` is the last-layer hidden state at position 0 for every batch item,
        shape ``(batch_size, hidden_size)``. ``input_ids`` is the batch's 'masked_input'
        tensor. ``pseudo_sentence_decoded`` is the first item's pseudo-sentence decoded
        with ``bert_tokenizer``, keeping special tokens such as [PAD].
    """
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    pseudo_sentence = batch['pseudo_sentence'].to(device)

    # Human-readable pseudo-sentence; special tokens are kept here (the UI strips [PAD] later).
    pseudo_sentence_decoded = bert_tokenizer.decode(pseudo_sentence[0].tolist(), skip_special_tokens=False)

    with torch.no_grad():
        # NOTE(review): the unmasked pseudo_sentence is fed as input_ids and the masked
        # 'masked_input' tensor is NOT used for the forward pass. That suits embedding
        # extraction, but confirm it matches how the model was fine-tuned.
        outputs = model(
            input_ids=pseudo_sentence,
            attention_mask=attention_mask,
            sent_position_ids=sent_position_ids,
            position_list_x=position_list_x,
            position_list_y=position_list_y,
        )

    # Last hidden layer at position 0 (conventionally the [CLS] token): [batch_size, hidden_size].
    spaBERT_embedding = outputs.hidden_states[-1][:, 0, :].detach()
    return spaBERT_embedding, input_ids, pseudo_sentence_decoded
# Streamlit re-executes this whole script on every widget interaction. Cache the expensive
# pass over the dataset so it runs once per server process instead of on every rerun.
@st.cache_resource
def _precompute_spabert_embeddings():
    """Embed every dataset sample with SpaBERT.

    Returns:
        ``(embeddings, sentences)``: parallel lists in DataLoader order (shuffle=False,
        batch_size=1). Position ``i`` therefore corresponds to dataset index ``i``, which is
        the value stored in ``entity_index_dict``.
    """
    embeddings = []
    sentences = []
    for batch in data_loader:
        embedding, _input_ids, sentence = process_entity(batch, spaBERT_model, device)
        embeddings.append(embedding)
        sentences.append(sentence)
    return embeddings, sentences


spaBERT_embeddings, pseudo_sentences = _precompute_spabert_embeddings()
# NOTE(review): not referenced anywhere else in the code shown here; kept for compatibility.
embedding_cache = {}
def get_bert_embedding(review_text):
    """Embed text with BERT and return the first-token ([CLS]) vector.

    Args:
        review_text: A string (or list of strings) to encode; long inputs are truncated.

    Returns:
        Tensor of shape ``(batch, hidden_size)`` taken from the last hidden state at position 0.
    """
    encoded = bert_tokenizer(review_text, return_tensors='pt', padding=True, truncation=True)
    encoded = encoded.to(device)
    with torch.no_grad():
        last_hidden = bert_model(**encoded).last_hidden_state
    return last_hidden[:, 0, :].detach()
def get_spaBert_embedding(entity, current_pseudo_sentences):
    """Return the embedding for a geo-entity name.

    The name is matched case-insensitively against the precomputed SpaBERT dataset. On a
    hit, that entity's pseudo-sentence is appended to ``current_pseudo_sentences`` (mutated
    in place) and its precomputed SpaBERT embedding is returned. On a miss, the entity text
    is embedded with plain BERT instead; this fallback should only be needed rarely.
    """
    index = entity_index_dict.get(entity.lower())
    if index is not None:
        current_pseudo_sentences.append(pseudo_sentences[index])
        if dev_mode:
            st.write("Got SpaBert embedding for: ", entity)
        return spaBERT_embeddings[index]

    if dev_mode:
        st.write("Got Bert embedding for: ", entity)
    return get_bert_embedding(entity)
def processSpatialEntities(review, nlp):
    """Average the spatial embeddings of every geo-entity spaCy finds in ``review``.

    Entities labelled FAC, ORG, LOC or GPE are looked up with ``get_spaBert_embedding``.
    That function falls back to BERT for names missing from the SpaBERT dataset.

    Args:
        review: Raw review text.
        nlp: A loaded spaCy pipeline with an NER component.

    Returns:
        ``(processed_embedding, current_pseudo_sentences)``:
        ``processed_embedding`` is the element-wise mean of the per-entity embeddings. Each
        lookup returns a single-row tensor, so its shape is ``(1, hidden_size)``.
        ``current_pseudo_sentences`` holds the pseudo-sentences of the entities SpaBERT
        resolved. When the review has no geo-entity, a zero tensor of shape
        ``(1, hidden_size)`` is returned, so callers can still concatenate it with the
        review embedding.
    """
    geo_labels = ('FAC', 'ORG', 'LOC', 'GPE')
    doc = nlp(review)
    token_embeddings = []
    current_pseudo_sentences = []
    for ent in doc.ents:
        if ent.label_ not in geo_labels:
            continue
        text = ent.text
        if dev_mode:
            st.write("Text found:", text)
        token_embeddings.append(get_spaBert_embedding(text, current_pseudo_sentences))
        if dev_mode:
            st.write("Geo-Entity Found in review: ", text)

    if not token_embeddings:
        # torch.stack raises on an empty list, which crashed the demo for reviews without any
        # geo-entity. Use an all-zero spatial embedding of the same width instead.
        # NOTE(review): confirm a zero vector is an acceptable "no spatial signal" input for
        # the discriminator.
        if dev_mode:
            st.write("No geo-entities found in review; using a zero spatial embedding.")
        empty = torch.zeros(1, bert_model.config.hidden_size, device=device)
        return empty, current_pseudo_sentences

    # (num_entities, 1, hidden_size) -> mean over entities -> (1, hidden_size)
    processed_embedding = torch.stack(token_embeddings, dim=0).mean(dim=0)
    return processed_embedding, current_pseudo_sentences
class Discriminator(nn.Module):
    """MLP discriminator head over a fixed-size input representation.

    The output layer has ``num_labels + 1`` logits. The extra, final logit scores whether the
    sample is real or fake.

    Args:
        input_size: Width of the input feature vector.
        hidden_sizes: Widths of the hidden layers, as any iterable of ints; defaults to ``[512]``.
            Each hidden layer is Linear -> LeakyReLU(0.2) -> Dropout.
        num_labels: Number of class labels, excluding the extra real/fake logit.
        dropout_rate: Dropout probability applied to the input and after each hidden layer.
    """

    def __init__(self, input_size=512, hidden_sizes=None, num_labels=2, dropout_rate=0.1):
        super().__init__()
        # Avoid a shared mutable default; copying also accepts tuples and never aliases the caller's list.
        hidden_sizes = [512] if hidden_sizes is None else list(hidden_sizes)
        self.input_dropout = nn.Dropout(p=dropout_rate)
        widths = [input_size] + hidden_sizes
        layers = []
        for in_features, out_features in zip(widths, widths[1:]):
            layers.extend([
                nn.Linear(in_features, out_features),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(dropout_rate),
            ])
        # The module order (and thus state_dict keys such as "layers.0.weight") must stay the
        # same so previously saved discriminator weights still load.
        self.layers = nn.Sequential(*layers)
        self.logit = nn.Linear(widths[-1], num_labels + 1)  # +1 for the real/fake output
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        """Return ``(last_rep, logits, probs)``.

        ``last_rep`` is the output of the hidden layers. ``logits`` has ``num_labels + 1``
        columns. ``probs`` is the softmax over all of those logits.
        """
        input_rep = self.input_dropout(input_rep)
        last_rep = self.layers(input_rep)
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs
# --- Discriminator set-up ---
# Layer widths come from the bert-base-uncased config. The discriminator input is the BERT
# review embedding concatenated with the geo-entity embedding, hence hidden_size * 2.
dConfig = AutoConfig.from_pretrained("bert-base-uncased")
hidden_size = int(dConfig.hidden_size)
num_hidden_layers_d = 2
hidden_levels_d = [hidden_size] * num_hidden_layers_d
# Class labels plus the 'UNL' entry; the discriminator adds one further real/fake logit.
label_list = ["1", "0", "UNL"]
out_dropout_rate = 0.5
discriminator_weights = 'data/datasets/discriminator_weights.pth'

discriminator = Discriminator(
    input_size=hidden_size * 2,
    hidden_sizes=hidden_levels_d,
    num_labels=len(label_list),
    dropout_rate=out_dropout_rate,
).to(device)
discriminator.load_state_dict(torch.load(discriminator_weights, map_location=torch.device('cpu')))
discriminator.eval()
def get_prediction(embeddings):
    """Classify ``embeddings`` with the global discriminator.

    The final logit (the real/fake output) is dropped. The remaining ``len(label_list)``
    columns are arg-maxed, giving one label index per row as a NumPy array. These indices
    presumably follow ``label_list`` order (0 -> "1", 1 -> "0", 2 -> "UNL"); confirm this
    against the label map used in training.
    """
    with torch.no_grad():
        logits = discriminator(embeddings)[1]
        class_logits = logits[:, :-1]
        best = torch.max(class_logits, dim=-1).indices
        return best.cpu().numpy()
def load_reviews_from_file(file_path):
    """Read one review per line from a UTF-8 text file.

    Returns a dict that maps ``"Review <n>"`` to the stripped line text. ``n`` is the 1-based
    line number, so blank lines are skipped but still count toward numbering. If the file
    does not exist, an error is shown with ``st.error`` and an empty dict is returned.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            numbered = ((number, text.strip()) for number, text in enumerate(handle, start=1))
            return {f"Review {number}": text for number, text in numbered if text}
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
        return {}
# --- Demo section: page title and explanation of the pipeline ---
st.title("SpaGAN Demo: IN PROGRESS")
st.write("This demo allows you to select from a list of sample reviews that contain both real and fake reviews.")
st.write("Once selected, the identified geo-entities within the review will be color coded and displayed along with the review.")
st.write("The model will construct a pseudo-sentence for each entity within the review, contextualizing each geo-entity with its closest neighbors from our dataset.")
st.write("Finally, the entire review will be embedded and combined with the spatial embeddings and the model will determine whether the review is real or fake.")
# Colour and description for each spaCy entity label highlighted in the review.
COLOR_MAP = {
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
    'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)')
}

# Legend explaining the colour coding.
st.write("**Color Key:**")
for label in COLOR_MAP:
    color, description = COLOR_MAP[label]
    st.markdown(
        f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}",
        unsafe_allow_html=True,
    )
# Load the sample reviews and let the user pick one from a dropdown.
review_file_path = "models/spabert/datasets/SampleReviews.txt"
example_reviews = load_reviews_from_file(review_file_path)

# Labels displayed next to the sample reviews in the dropdown.
review_labels = {
    "Review 1": "Real",
    "Review 2": "Spam",
}

if not example_reviews:
    # With no options, st.selectbox returns None and the .split below would raise
    # AttributeError. End this script run with a clear message instead.
    st.error("No example reviews are available to select.")
    st.stop()

# Entries look like "Review 1 (Real)"; reviews without a label show "(Unknown)".
dropdown_options = [f"{key} ({review_labels.get(key, 'Unknown')})" for key in example_reviews.keys()]
user_selection = st.selectbox("Select an example review", options=dropdown_options)
# Recover the dictionary key by dropping the " (label)" suffix.
selected_key = user_selection.split(" (")[0]
selected_review = example_reviews[selected_key]
lower_case_review = selected_review.lower()
def _highlight_geo_entities(text, doc):
    """Wrap every entity whose label is in COLOR_MAP in a bold, coloured HTML <span>.

    Entities are rewritten from last to first, so the character offsets spaCy reported for
    earlier entities stay valid while the string grows.
    """
    for entity in reversed(doc.ents):
        if entity.label_ not in COLOR_MAP:
            continue
        colour = COLOR_MAP[entity.label_][0]
        before = text[:entity.start_char]
        after = text[entity.end_char:]
        text = before + f"<span style='color:{colour}; font-weight:bold'>{entity.text}</span>" + after
    return text


# Run the whole pipeline when the button is clicked.
if st.button("Process Review"):
    if not lower_case_review.strip():
        st.error("Please select a review.")
    else:
        # Review-level BERT embedding (lower-cased text) plus the averaged geo-entity embedding
        # (spaCy runs on the original casing), concatenated for the discriminator.
        bert_embedding = get_bert_embedding(lower_case_review)
        spaBert_embedding, current_pseudo_sentences = processSpatialEntities(selected_review, nlp)
        combined_embedding = torch.cat((bert_embedding, spaBert_embedding), dim=-1)
        if dev_mode:
            st.write("Review Embedding Shape:", bert_embedding.shape)
            st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
            st.write("Concatenated Embedding Shape:", combined_embedding.shape)
            st.write("Concatenated Embedding:", combined_embedding)
        prediction = get_prediction(combined_embedding)

        # Parse the review again with spaCy to colour-code its entities.
        doc = nlp(selected_review)
        highlighted_text = _highlight_geo_entities(selected_review, doc)
        st.markdown(highlighted_text, unsafe_allow_html=True)

        # Pseudo-sentences of the entities SpaBERT resolved, with padding tokens removed.
        for sentence in current_pseudo_sentences:
            st.write("Pseudo-Sentence:", sentence.replace("[PAD]", "").strip())

        # Index 0 is shown as "Not Spam" and 1 as "Spam"; anything else is reported as an error.
        if prediction == 0:
            st.write("Prediction: Not Spam")
        elif prediction == 1:
            st.write("Prediction: Spam")
        else:
            st.write("error during prediction")