import html

import pandas as pd
import spacy
import streamlit as st
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoConfig, BertModel, BertTokenizer
from transformers.models.bert.modeling_bert import BertForMaskedLM

from models.spabert.datasets.osm_sample_loader import PbfMapDataset
from models.spabert.models.spatial_bert_model import SpatialBertConfig, SpatialBertForMaskedLM, SpatialBertModel
from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weights

# All inference in this app runs on the CPU.
device = torch.device('cpu')
#Spacy Initialization Section
# spaCy pipeline loaded from a local path; used for named-entity recognition of reviews.
nlp = spacy.load("./models/en_core_web_sm")

#BERT Initialization Section
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# NOTE(review): `bert_model` is used below and in the embedding helpers, but it is
# not constructed anywhere in the visible code (presumably a BertModel loaded from
# "bert-base-uncased") — confirm a line was not lost here.
#SpaBERT Initialization Section
data_file_path = 'models/spabert/datasets/SpaBERTPivots.json'    #Sample file otherwise this model will take too long on CPU.
config = SpatialBertConfig()
# Needed so the SpaBERT forward pass returns per-layer hidden states (process_entity reads hidden_states[-1]).
config.output_hidden_states = True
# NOTE(review): `pretrained_model_path` is never assigned in the visible code, and
# `pre_trained_model` is not used afterwards here; presumably it should be passed to
# load_spatial_bert_pretrained_weights(...) — confirm.
pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
# Copy BERT weights into the SpaBERT model; strict=False ignores keys that do not match.
# NOTE(review): `spaBERT_model` is not constructed in the visible code — confirm.
spaBERT_model.load_state_dict(bert_model.state_dict(), strict = False)
#Load data using SpatialDataset
spatialDataset = PbfMapDataset(data_file_path = data_file_path,
                                        tokenizer = bert_tokenizer,
                                        max_token_len = 256,                          #Originally 300
                                        #max_token_len = max_seq_length,              #Originally 300
                                        distance_norm_factor = 0.0001,
                                        spatial_dist_fill = 20,
                                        with_type = False,
                                        sep_between_neighbors = True,    
                                        label_encoder = None,             
# NOTE(review): the PbfMapDataset(...) call above is never closed — remaining
# argument lines and the closing parenthesis appear to be missing.
# Create a dictionary to map entity names to indices
# Iterates every item of the dataset; each item is expected to carry a
# 'pivot_name' key (TODO confirm against PbfMapDataset).
entity_index_dict = {entity['pivot_name']: i for i, entity in enumerate(spatialDataset)}

# Ensure names are stored in lowercase for case-insensitive matching
# If two pivot names differ only in case, the later index wins.
entity_index_dict = {name.lower(): index for name, index in entity_index_dict.items()}

#Pre-acquire the SpaBERT embeddings for all geo-entities within our dataset
def process_entity(batch, model, device):
    """Compute the SpaBERT [CLS] embedding for one DataLoader batch.

    Args:
        batch: a collated batch from the PbfMapDataset DataLoader. The keys read
            here are 'masked_input', 'attention_mask', 'norm_lng_list',
            'norm_lat_list' and 'sent_position_ids'.
        model: the SpaBERT model. NOTE(review): currently unused — the body calls
            the global `spaBERT_model` instead.
        device: torch device the batch tensors are moved to.

    NOTE(review): this function appears truncated or incomplete:
      * `pseudo_sentence` is read but never assigned (NameError). It is presumably
        batch['pseudo_sentence'] — confirm against the dataset.
      * the `spaBERT_model(` call is never closed, and its keyword arguments are missing.
      * nothing is returned, yet the caller unpacks three values
        (embedding, input_ids, pseudo_sentence).
    """
    input_ids = batch['masked_input'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    position_list_x = batch['norm_lng_list'].to(device)
    position_list_y = batch['norm_lat_list'].to(device)
    sent_position_ids = batch['sent_position_ids'].to(device)
    # Convert tensor to list of token IDs, and decode them into a readable sentence
    # Only the first sequence of the batch (index 0) is decoded.
    pseudo_sentence_decoded = bert_tokenizer.decode(pseudo_sentence[0].tolist(), skip_special_tokens=False)

    # Inference only: no autograd graph is recorded.
    with torch.no_grad():
        outputs = spaBERT_model(#input_ids=input_ids,
                        #NOTE: we are omitting the pseudo_sentence here. Verify that this is correct

    # Hidden states of the last layer (available because config.output_hidden_states = True).
    spaBERT_embedding = outputs.hidden_states[-1].to(device)

    # Extract the [CLS] token embedding (first token)
    spaBERT_embedding = spaBERT_embedding[:, 0, :].detach()  # [batch_size, hidden_size]

    #return pivot_embeddings.cpu().numpy(), input_ids.cpu().numpy()
# Pre-compute one SpaBERT embedding and one decoded pseudo-sentence per geo-entity.
# batch_size=1 and shuffle=False keep position i of these lists aligned with
# index i in entity_index_dict, which was built by enumerating spatialDataset in order.
data_loader = DataLoader(spatialDataset, batch_size=1, shuffle=False)
spaBERT_embeddings = []
pseudo_sentences = []
for batch in data_loader:
    spaBERT_embedding, input_ids, pseudo_sentence = process_entity(batch, spaBERT_model, device)
    spaBERT_embeddings.append(spaBERT_embedding)
    # Expected to be the decoded text; the display code calls .replace() on these.
    pseudo_sentences.append(pseudo_sentence)
#Get BERT Embedding for review
def get_bert_embedding(review_text):
    """Return the BERT [CLS] embedding of ``review_text``.

    Args:
        review_text: text to embed (a single string, or a list of strings for a batch).

    Returns:
        Tensor of shape ``(batch, hidden_size)``: the first ([CLS]) position of
        the last hidden state. ``batch`` is 1 for a single string.
    """
    inputs = bert_tokenizer(review_text, return_tensors='pt', padding=True, truncation=True).to(device)
    # Forward pass through the BERT model (inference only, no autograd graph).
    with torch.no_grad():
        outputs = bert_model(**inputs)
    # Extract the [CLS] token embedding from the last hidden state.
    return outputs.last_hidden_state[:, 0, :].detach()
#Get SpaBERT Embedding for geo-entity
def get_spaBert_embedding(entity,current_pseudo_sentences):
    """Return an embedding for a geo-entity, preferring the precomputed SpaBERT one.

    Args:
        entity: entity text as found by spaCy. It is matched case-insensitively
            against ``entity_index_dict``.
        current_pseudo_sentences: list that collects the pseudo-sentence of every
            entity resolved through SpaBERT. It is appended to in place; may be None.

    Returns:
        The precomputed SpaBERT embedding for a known entity. Otherwise, as a
        fallback, the BERT [CLS] embedding of the entity text.
    """
    entity_index = entity_index_dict.get(entity.lower())
    if entity_index is None:
        # Fallback in case SpaBERT could not resolve the entity. Rare cases only.
        if dev_mode:
            st.write("Got Bert embedding for: ", entity)
        return get_bert_embedding(entity)

    if dev_mode:
        st.write("Got SpaBert embedding for: ", entity)
    # NOTE(review): relies on spaBERT_embeddings / pseudo_sentences being filled
    # in dataset order with batch_size=1, so list position == entity_index — confirm.
    if current_pseudo_sentences is not None:
        current_pseudo_sentences.append(pseudo_sentences[entity_index])
    return spaBERT_embeddings[entity_index]
#Go through each review, identify all geo-entities, then extract their SpaBERT embedings
def processSpatialEntities(review, nlp):
    """Embed every geo-entity in a review and average the embeddings.

    Args:
        review: raw review text.
        nlp: loaded spaCy pipeline used for named-entity recognition.

    Returns:
        ``(processed_embedding, current_pseudo_sentences)``: the element-wise mean
        of the per-entity embeddings and the pseudo-sentences collected while
        resolving them. When no geo-entity yields an embedding, a warning is shown
        and a zero vector of size ``bert_model.config.hidden_size`` plus an empty
        list are returned.
    """
    # spaCy entity labels treated as geo-entities.
    geo_labels = {'FAC', 'ORG', 'LOC', 'GPE'}
    doc = nlp(review)
    token_embeddings = []
    current_pseudo_sentences = []
    for ent in doc.ents:
        if ent.label_ not in geo_labels:
            continue
        text = ent.text
        if dev_mode:
            st.write("Text found:", text)
        spaBert_emb = get_spaBert_embedding(text, current_pseudo_sentences)
        if spaBert_emb is None:
            # Defensive: skip entities for which no embedding could be produced.
            continue
        token_embeddings.append(spaBert_emb)
        if dev_mode:
            st.write("Geo-Entity Found in review: ", text)

    # Handle the case where no geo-entities are found
    if not token_embeddings:
        st.warning("No geo-entities found in the review. Please include one from the list.")
        # Return a zero vector as a fallback if no entities are found
        return torch.zeros(bert_model.config.hidden_size), []

    # All per-entity embeddings must share one shape to be stacked; the mean
    # keeps that per-entity shape.
    stacked = torch.stack(token_embeddings, dim=0)
    processed_embedding = stacked.mean(dim=0)
    return processed_embedding, current_pseudo_sentences
#Initialize discriminator module
class Discriminator(nn.Module):
    """Feed-forward discriminator head.

    Input dropout is followed by a stack of ``Linear -> LeakyReLU -> Dropout``
    blocks and a final linear layer producing ``num_labels + 1`` logits. The
    extra logit scores the sample as fake/real.

    Args:
        input_size: size of the incoming feature vector.
        hidden_sizes: widths of the hidden layers; ``None`` means ``[512]``.
        num_labels: number of real classes.
        dropout_rate: dropout probability applied to the input and after each
            hidden block.
    """

    def __init__(self, input_size=512, hidden_sizes=None, num_labels=2, dropout_rate=0.1):
        super().__init__()
        if hidden_sizes is None:
            hidden_sizes = [512]
        self.input_dropout = nn.Dropout(p=dropout_rate)
        sizes = [input_size] + list(hidden_sizes)
        layers = []
        # NOTE(review): layer recipe restored from the GAN-BERT reference
        # discriminator. Its order fixes the state_dict keys (layers.0, layers.3, ...)
        # and must match any checkpoint loaded into this module — confirm.
        for in_features, out_features in zip(sizes, sizes[1:]):
            layers.extend([
                nn.Linear(in_features, out_features),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(dropout_rate),
            ])
        self.layers = nn.Sequential(*layers)
        # +1 for the probability of this sample being fake/real.
        self.logit = nn.Linear(sizes[-1], num_labels + 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        """Return ``(last_rep, logits, probs)`` for ``input_rep``.

        ``last_rep`` is the output of the hidden stack, ``logits`` the raw
        ``num_labels + 1`` scores, and ``probs`` their softmax over the last axis.
        """
        input_rep = self.input_dropout(input_rep)
        last_rep = self.layers(input_rep)
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs
# Discriminator setup. Its input is the BERT review embedding concatenated with
# the SpaBERT geo-entity embedding, hence input_size = 2 * hidden_size.
dConfig = AutoConfig.from_pretrained("bert-base-uncased")
hidden_size = int(dConfig.hidden_size)
num_hidden_layers_d = 2
hidden_levels_d = [hidden_size] * num_hidden_layers_d
label_list = ["1", "0"]
# NOTE(review): `out_dropout_rate` is not defined anywhere else in the visible file.
# Dropout only matters during training; the module is switched to eval mode
# below, so this value does not affect predictions.
out_dropout_rate = 0.2
discriminator = Discriminator(input_size=hidden_size*2, hidden_sizes=hidden_levels_d, num_labels=len(label_list), dropout_rate=out_dropout_rate).to(device)
# NOTE(review): no trained discriminator weights are loaded in the visible code —
# confirm a load_state_dict(...) happens elsewhere.
# Inference only: disable dropout so predictions are deterministic.
discriminator.eval()
def get_prediction(embeddings):
    """Classify embeddings with the module-level ``discriminator``.

    Args:
        embeddings: tensor whose last dimension is the discriminator input size;
            either a single vector or a batch of vectors.

    Returns:
        numpy array of predicted class indices: the argmax over the real-class
        logits, with the extra fake/real logit ignored. It is a 0-d array for a
        single vector and shape ``(batch,)`` for a batch.
    """
    # Dropout must be off at inference, otherwise predictions are random.
    discriminator.eval()
    with torch.no_grad():
        # Forward pass through the discriminator to get the logits.
        _, logits, _ = discriminator(embeddings)
        # Drop the last (fake/real) logit; `...` also supports 1-D input.
        filtered_logits = logits[..., :-1]
        _, predicted_labels = torch.max(filtered_logits, dim=-1)
    return predicted_labels.cpu().numpy()
# Function to read reviews from a text file
def load_reviews_from_file(file_path):
    """Read one review per line from a UTF-8 text file.

    Args:
        file_path: path of the reviews file.

    Returns:
        Dict mapping ``"Review N"`` to the stripped line text, where N is the
        1-based line number. Blank lines are skipped but still consume a number,
        so they leave gaps. Returns an empty dict (after showing a Streamlit
        error) if the file does not exist.
    """
    reviews = {}
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            for i, line in enumerate(file):
                line = line.strip()
                if line:  # Ensure the line is not empty
                    reviews[f"Review {i + 1}"] = line
    except FileNotFoundError:
        st.error(f"File not found: {file_path}")
    return reviews
#Demo Section
st.title("SpaGAN Demo")
st.write("This demo lets you explore a curated list of sample reviews, containing a real and fake example.")
st.write("Upon selecting a review, any identified geo-entities will be color coded for easy visualization.")
# NOTE(review): "pseudo-sentnece" in the next user-facing string is a typo for "pseudo-sentence".
st.write("For each geo-entity found, the model will generate a contextual pseudo-sentnece, highlighting its closest neighbors from our dataset.")
# Define a color map and descriptions for different entity types
# NOTE(review): the `COLOR_MAP = {` opener, any further entries (e.g. 'GPE') and the
# closing `}` appear to be missing. Code below reads COLOR_MAP as
# label -> (color, description).
    'FAC': ('red', 'Facilities (e.g., buildings, airports)'),
    'ORG': ('blue', 'Organizations (e.g., companies, institutions)'),
    'LOC': ('purple', 'Locations (e.g., mountain ranges, water bodies)'),
# Display the color key
st.write("**Color Key:**")
# NOTE(review): this loop has no body in the visible code.
for label, (color, description) in COLOR_MAP.items():
# Sample reviews file; presumably passed to load_reviews_from_file — confirm.
review_file_path = "models/spabert/datasets/SampleReviews.txt"
# Define labels
# Label attached to each sample review in the dropdown options.
review_labels = {
    "Review 1": "Real",
# NOTE(review): review_labels is never closed, and `example_reviews` and
# `user_selection` used below are not defined in the visible code.
# Create options with labels for the dropdown
# Dropdown for selecting an example review
# Extract the original review key from the selected option
selected_key = user_selection.split(" (")[0]  # Remove the label part
selected_review = example_reviews[selected_key]
# Lower-cased copy of the dropdown selection.
lower_case_review = selected_review.lower()

#Optional textbox for interactivity
# Sample California place names, presumably shown to the user as suggestions
# (the st.info call below is truncated) — confirm.
california_entities = [
    "Los Angeles", "San Francisco", "San Diego", "Yosemite National Park", 
    "Lake Tahoe", "Napa Valley", "Big Sur", "Death Valley", "Monterey Bay",
    "Joshua Tree National Park", "Santa Monica", "Mammoth Lakes", "Palm Springs",
    "Sequoia National Park", "Redwood National Park", "San Jose", "Anaheim",
    "Berkeley", "Fresno", "Glendale", "Laguna Beach", "Pasadena",
# NOTE(review): the list above is never closed; remaining entries and the closing bracket appear to be missing.
user_input_review = st.text_area("Or type your own review here","")
# NOTE(review): this f-string st.info(...) call is truncated in the visible code.
st.info(f

# Typed text takes priority; otherwise fall back to the dropdown selection.
review_to_process = user_input_review if user_input_review.strip() else selected_review
#st.write("Selected Review: ", review_to_process)
# Process the text when the button is clicked
if st.button("Process Review"):
    # Validate and embed the review that is actually processed: typed text if
    # present, otherwise the dropdown selection.
    if review_to_process.strip():
        # Review side: BERT [CLS] embedding of the lower-cased review text.
        bert_embedding = get_bert_embedding(review_to_process.lower())
        # Geo side: averaged SpaBERT embedding of the review's geo-entities.
        spaBert_embedding, current_pseudo_sentences = processSpatialEntities(review_to_process, nlp)
        # An all-zero vector is processSpatialEntities' "no geo-entity" fallback.
        if spaBert_embedding is None or spaBert_embedding.sum() == 0:
            st.error("Unable to process the review. Please include at least one valid geo-entity.")
        else:
            # The discriminator takes [BERT ; SpaBERT] concatenated on the feature axis.
            combined_embedding = torch.cat((bert_embedding, spaBert_embedding), dim=-1)
            if dev_mode:
                st.write("Review Embedding Shape:", bert_embedding.shape)
                st.write("Geo-Entities embedding shape: ", spaBert_embedding.shape)
                st.write("Concatenated Embedding Shape:", combined_embedding.shape)
                st.write("Concatenated Embedding:", combined_embedding)
            prediction = get_prediction(combined_embedding)

            # Highlight geo-entities with different colors. The review may be
            # user-typed and is rendered with unsafe_allow_html, so every piece
            # of user text is HTML-escaped. Segments are built from the raw-text
            # character offsets before escaping.
            doc = nlp(review_to_process)
            pieces = []
            cursor = 0
            for ent in doc.ents:
                if ent.label_ in COLOR_MAP:
                    color = COLOR_MAP[ent.label_][0]
                    pieces.append(html.escape(review_to_process[cursor:ent.start_char]))
                    pieces.append(
                        f"<span style='color:{color}; font-weight:bold'>{html.escape(ent.text)}</span>"
                    )
                    cursor = ent.end_char
            pieces.append(html.escape(review_to_process[cursor:]))
            highlighted_text = "".join(pieces)
            # Display the highlighted text with HTML support
            st.markdown(highlighted_text, unsafe_allow_html=True)

            # Display the pseudo-sentences found, without padding tokens.
            for sentence in current_pseudo_sentences:
                clean_sentence = sentence.replace("[PAD]", "").strip()
                st.write("Pseudo-Sentence:", clean_sentence)

            # Display the model's prediction
            if prediction == 0:
                st.markdown("<h3 style='color:green;'>✅ Prediction: Not Spam</h3>", unsafe_allow_html=True)
            elif prediction == 1:
                st.markdown("<h3 style='color:red;'>❌ Prediction: Spam</h3>", unsafe_allow_html=True)
            else:
                st.markdown("<h3 style='color:orange;'>⚠️ Error during prediction</h3>", unsafe_allow_html=True)
    else:
        st.error("Please select a review.")