import streamlit as st
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForTokenClassification
import pandas as pd
from pprint import pprint 


@st.cache_resource()
def load_trained_model():
    """Build the token-classification ("ner") pipeline for the fine-tuned BERT checkpoint.

    The function is wrapped in ``st.cache_resource``, so repeated calls reuse the
    pipeline object instead of reloading tokenizer and weights on every rerun.

    Returns:
        A transformers ``pipeline`` for the "ner" task built from the
        "LampOfSocrates/bert-cased-plodcw-sourav" model and tokenizer.
    """
    checkpoint = "LampOfSocrates/bert-cased-plodcw-sourav"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    classifier = AutoModelForTokenClassification.from_pretrained(checkpoint)

    # Print the checkpoint's id -> label mapping to stdout for debugging.
    labels = classifier.config.id2label
    print(f"Can recognise the following labels {labels}")

    return pipeline("ner", model=classifier, tokenizer=tokenizer)


@st.cache_data()
def load_plod_cw_dataset():
    """Fetch the "surrey-nlp/PLOD-CW" dataset from the Hugging Face Hub.

    Memoised with ``st.cache_data`` so the download/parse happens once rather
    than on every Streamlit rerun.

    Returns:
        Whatever ``datasets.load_dataset("surrey-nlp/PLOD-CW")`` returns; it is
        indexed by split name (e.g. ``'test'``) elsewhere in this module.
    """
    # `datasets` is imported locally; the module header does not import it.
    from datasets import load_dataset as hf_load_dataset

    return hf_load_dataset("surrey-nlp/PLOD-CW")

def load_random_examples(dataset_name, num_examples=1):
    """
    Sample random examples from the PLOD-CW test split.

    Args:
        dataset_name (str): Name of the requested dataset. Only
            "surrey-nlp/PLOD-CW" is served: rows always come from the cached
            ``load_plod_cw_dataset()``, so this argument is accepted for
            compatibility but does not select the data.
        num_examples (int): How many rows to sample. Clamped to
            ``[0, len(test split)]`` so ``DataFrame.sample`` never raises for
            requesting more rows than exist. Defaults to 1, preserving the
            single-example behaviour of calls that omit it.

    Returns:
        pd.DataFrame: Two rows labelled ``"tokens"`` and ``"ner_tags"`` and one
        column per sampled example (column label = the example's row index in
        the test split); each cell holds that example's token / tag sequence.
    """
    dat = load_plod_cw_dataset()

    # Convert the test split to a DataFrame so pandas can do the sampling.
    test_df = pd.DataFrame(dat['test'])

    # Previously this was hard-coded to 1 and num_examples was ignored.
    n = min(max(int(num_examples), 0), len(test_df))
    sampled = test_df.sample(n=n)

    # Rows are the two fields, columns are the sampled examples.
    return pd.DataFrame((sampled['tokens'], sampled['ner_tags']))


def render_entities(tokens, entities):
    """
    Draw a themed two-column table ("Token", "Entity") under a page title.

    Args:
        tokens: Values for the "Token" column.
        entities: Values for the "Entity" column; presumably the same length
            as ``tokens`` (both are handed straight to ``st.table`` as dict columns).
    """
    theme_css = """
        <style>
        body {
            font-family: 'Arial', sans-serif;
            background-color: #f0f0f5;
            color: #333333;
        }
        table {
            width: 100%;
            border-collapse: collapse;
        }
        th, td {
            padding: 12px;
            text-align: left;
            border-bottom: 1px solid #dddddd;
        }
        th {
            background-color: #4CAF50;
            color: white;
            width: 16.66%;
        }
        tr:hover {
            background-color: #f5f5f5;
        }
        td {
            width: 16.66%;
        }
        </style>
        """
    # Raw <style> injection needs unsafe_allow_html.
    st.markdown(theme_css, unsafe_allow_html=True)

    st.title("Model predicted Token vs Entities Table")
    st.write("This table shows the entity corresponding to each token in a cool and chilled theme.")

    st.table({"Token": tokens, "Entity": entities})

def render_random_examples():
    """
    Show random PLOD-CW examples in a themed table with a "resample" button.

    The current sample is stored in ``st.session_state['random_examples']``.
    It is only replaced when the button is pressed or when nothing has been
    stored yet.
    """
    theme_css = """
        <style>
        body {
            font-family: 'Arial', sans-serif;
            background-color: #f0f0f5;
            color: #333333;
        }
        table {
            width: 100%;
            border-collapse: collapse;
        }
        th, td {
            padding: 12px;
            text-align: left;
            border-bottom: 1px solid #dddddd;
        }
        th {
            background-color: #4CAF50;
            color: white;
            width: 16.66%;
        }
        tr:hover {
            background-color: #f5f5f5;
        }
        td {
            width: 16.66%;
        }
        </style>
        """
    st.markdown(theme_css, unsafe_allow_html=True)

    st.title("Random Examples from PLOD-CW")
    st.write("This table shows 1 random examples from the PLOD-CW dataset in a cool and chilled theme.")

    # Call st.button unconditionally so the button is drawn on every rerun.
    # Resample when it was clicked, or when no sample exists yet.
    resample_clicked = st.button('Show another set of random examples')
    if resample_clicked or 'random_examples' not in st.session_state:
        st.session_state['random_examples'] = load_random_examples("surrey-nlp/PLOD-CW")

    st.table(st.session_state['random_examples'])
def predict_using_trained(sentence):
    """Run the cached NER pipeline on *sentence* and return its raw predictions.

    Args:
        sentence: Text to tag. ``None`` or an empty string yields ``[]``
            without loading the model. ``None`` happens when the ``sentence``
            query parameter is missing in API mode, and previously crashed
            the pipeline call.

    Returns:
        list: The pipeline output as a list of dicts. The page code reads
        their ``'entity'``, ``'start'`` and ``'end'`` keys.
    """
    if not sentence:
        return []

    ner = load_trained_model()
    return ner(sentence)

# Background colour for each predicted label in the highlighted-sentence view;
# labels not listed fall back to 'lightgray'.
_LABEL_COLORS = {
    'B-LF': 'lightblue',
    'B-O': 'lightgreen',
    'B-AC': 'lightcoral',
    'I-LF': 'lightyellow',
}


def _get_entity_html(text, entities):
    """Return HTML for *text* with each predicted span rendered as a coloured block.

    Args:
        text: The user's raw input sentence.
        entities: Pipeline predictions, each a dict with 'start'/'end' character
            offsets into *text* and an 'entity' label.

    Returns:
        str: An HTML fragment. Text between predictions has its spaces turned
        into ``<br>``. Every piece of user text is HTML-escaped because the
        result is rendered with ``unsafe_allow_html=True``.
    """
    import html  # local import: the module header does not import `html`

    parts = ["<div>"]
    last_idx = 0
    for entity in entities:
        start = entity['start']
        end = entity['end']
        color = _LABEL_COLORS.get(entity['entity'], 'lightgray')

        # Text before this prediction. Escape first, then add line breaks,
        # so only the <br> we add is live markup.
        parts.append(html.escape(text[last_idx:start]).replace(" ", "<br>"))
        parts.append(
            f'<div style="background-color: {color}; padding: 5px; border-radius: 3px; '
            f'margin: 5px 0;">{html.escape(text[start:end])}</div>'
        )
        last_idx = end

    # Any text after the last prediction.
    parts.append(html.escape(text[last_idx:]).replace(" ", "<br>"))
    parts.append("</div>")
    return "".join(parts)


def prep_page():
    """Render the interactive NER demo page.

    The page:
    - loads the cached pipeline;
    - reads a sentence from a text area;
    - when one is given, shows the predictions both as highlighted HTML and
      as a token/entity table;
    - always appends the random PLOD-CW examples section.
    """
    model = load_trained_model()

    st.title("Named Entity Recognition with BERT on PLOD-CW")
    st.write("Enter a sentence to see the named entities recognized by the model.")

    text = st.text_area("Enter your sentence here:")

    if text:
        st.write("Entities recognized:")
        entities = model(text)

        # Debug dump of raw predictions to stdout.
        pprint(entities)

        st.markdown(_get_entity_html(text, entities), unsafe_allow_html=True)

        # One table row per prediction: its surface text and predicted label.
        # Previously the whole sentence and the raw dicts were passed here.
        render_entities(
            [text[e['start']:e['end']] for e in entities],
            [e['entity'] for e in entities],
        )

    render_random_examples()



if __name__ == '__main__':
    # An `api` query parameter switches from the interactive page to returning
    # the raw predictions for the `sentence` query parameter.
    params = st.query_params
    if 'api' not in params:
        prep_page()
    else:
        query_sentence = params.get('sentence')
        predictions = predict_using_trained(query_sentence)
        payload = {"sentence": query_sentence, "entities": predictions}
        pprint(payload)

        st.write(payload)