import streamlit as st
import time
import base64
from annotated_text import annotated_text
from io import StringIO
from transformers import AutoTokenizer, AutoModelForTokenClassification
from text_extractor import *
from text_annotatator import *
from claim_details import *
import os
from streamlit_text_annotation import text_annotation

os.environ['KMP_DUPLICATE_LIB_OK']='True'

import plotly.express as px
from streamlit_option_menu import option_menu

from transformers import pipeline
import pandas as pd

# Use the wide page layout for the whole app. This call is kept ahead of
# every other st.* command executed by the script.
st.set_page_config(layout="wide")

@st.cache(allow_output_mutation=True)
def init_text_summarization_model():
    """Load the abstractive summarization pipeline.

    Wrapped in ``st.cache`` so script reruns reuse the already-loaded model
    instead of building it again.

    Returns:
        A transformers ``pipeline`` for the "summarization" task backed by
        the ``facebook/bart-large-cnn`` checkpoint.
    """
    return pipeline("summarization", model='facebook/bart-large-cnn')

@st.cache(allow_output_mutation=True)
def init_zsl_topic_classification():
    """Load the zero-shot topic classification pipeline and its prompt template.

    Defined once. The file previously contained two identical copies of this
    function, and the second silently shadowed the first.

    Returns:
        A ``(pipe, template)`` tuple. ``pipe`` is a transformers
        "zero-shot-classification" pipeline backed by
        ``facebook/bart-large-mnli``. ``template`` is the hypothesis format
        string, with ``{}`` as the placeholder for a candidate label.
    """
    pipe = pipeline("zero-shot-classification", model='facebook/bart-large-mnli')
    template = "This text is about {}."
    return pipe, template

@st.cache(allow_output_mutation=True)
def init_ner_pipeline():
    """Load the biomedical named-entity-recognition pipeline.

    The tokenizer and the token-classification model both come from the
    ``d4data/biomedical-ner-all`` checkpoint.

    Returns:
        A transformers "ner" pipeline using ``aggregation_strategy="simple"``.
        With that strategy, each result carries ``entity_group``, ``start`` and
        ``end`` fields, which ``get_formatted_text_for_annotation`` consumes.
    """
    checkpoint = "d4data/biomedical-ner-all"
    ner_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
    # Runs on CPU; pass device=0 to use the first GPU instead.
    return pipeline(
        "ner",
        model=ner_model,
        tokenizer=ner_tokenizer,
        aggregation_strategy="simple",
    )

@st.cache(allow_output_mutation=True)
def init_qa_pipeline():
    """Load the extractive question-answering pipeline.

    Returns:
        A transformers "question-answering" pipeline backed by
        ``deepset/roberta-base-squad2``. It is cached by ``st.cache`` across
        reruns.
    """
    qa_checkpoint = 'deepset/roberta-base-squad2'
    return pipeline("question-answering", model=qa_checkpoint)

def get_formatted_text_for_annotation(output, text=None):
    """Convert NER pipeline output into positional arguments for ``annotated_text``.

    The source text is split into plain-text strings and
    ``(entity_text, label, colour)`` tuples. Concatenating the text parts of
    the result reproduces ``text`` exactly: no characters are dropped or
    duplicated.

    Args:
        output: Entities from the NER pipeline. Each item is a mapping with
            ``start``/``end`` character offsets into ``text`` and an
            ``entity_group`` label. Entities are assumed to be ordered by
            ``start`` (TODO confirm the pipeline always returns them in text
            order).
        text: The string the entities were extracted from. When omitted,
            the function falls back to the module-level ``text`` variable,
            which existing single-argument call sites rely on.

    Returns:
        A tuple of ``str`` and ``(str, str, str)`` items, ready for
        ``annotated_text(*result)``.
    """
    colour_map = {
        'Coreference': '#29D93B',
        'Severity': '#FCF3CF',
        'Sex': '#E9F7EF',
        'Sign_symptom': '#EAF2F8',
        'Detailed_description': '#078E8B',
        'Date': '#F5EEF8',
        'History': '#FDEDEC',
        'Medication': '#F4F6F6',
        'Therapeutic_procedure': '#A3E4D7',
        'Age': '#85C1E9',
        'Subject': '#D7BDE2',
        'Biological_structure': '#AF7AC5',
        'Activity': '#B2BABB',
        'Lab_value': '#E6B0AA',
        'Family_history': '#2471A3',
        'Diagnostic_procedure': '#CCD1D1',
        'Other_event': '#239B56',
        'Occupation': '#B3B6B7',
    }
    # Any label missing from colour_map (the model may emit others) gets a
    # neutral grey instead of raising KeyError.
    default_colour = '#D5D8DC'

    if text is None:
        # Backward compatibility: the original API read the global `text`.
        text = globals().get("text", "")

    annotated_texts = []
    cursor = 0  # index of the first character of `text` not yet emitted
    for entity in output:
        start, end = entity['start'], entity['end']
        if start > cursor:
            # Plain text between the previous entity (or the beginning) and
            # this one, including the separating whitespace.
            annotated_texts.append(text[cursor:start])
        label = entity['entity_group']
        annotated_texts.append(
            (text[start:end], label, colour_map.get(label, default_colour))
        )
        # max() keeps the cursor monotonic if entities ever overlap.
        cursor = max(cursor, end)

    if cursor < len(text):
        # Trailing plain text after the last entity (or the whole text when
        # no entities were found).
        annotated_texts.append(text[cursor:])

    return tuple(annotated_texts)

def displayPDF(file):
    """Embed the PDF stored at path ``file`` in the Streamlit page.

    The file bytes are base64-encoded into a ``data:application/pdf`` URI.
    The URI is shown in a 700x1000 px ``<iframe>`` rendered through
    ``st.markdown`` with raw HTML allowed.
    """
    with open(file, "rb") as pdf_handle:
        raw_bytes = pdf_handle.read()
    encoded = base64.b64encode(raw_bytes).decode('utf-8')

    iframe_html = (
        f'<iframe src="data:application/pdf;base64,{encoded}" '
        'width="700" height="1000" type="application/pdf"></iframe>'
    )
    st.markdown(iframe_html, unsafe_allow_html=True)


# Model initialization. Each init_* helper is wrapped in st.cache, so the
# Hugging Face models are loaded once and reused on later reruns.
pipeline_summarization = init_text_summarization_model()
# NOTE(review): pipeline_zsl / template are not used below. The topics shown
# under "Summarize Document" are hard-coded, so loading bart-large-mnli here
# only costs memory and start-up time until it is wired in.
pipeline_zsl, template = init_zsl_topic_classification()
pipeline_ner = init_ner_pipeline()
pipeline_qa = init_qa_pipeline()

# QA answers scored below this threshold are reported as "Not Found".
BARRIER_MIN_SCORE = 0.3

st.header("Intelligent Document Automation")

with st.sidebar:
    selected_menu = option_menu(
        "Select Option",
        [
            "Upload Document",
            "Extract Text",
            "Summarize Document",
            "Extract Entities",
            "Detected Barriers",
            "Get Answers",
            "Annotation Tool",
            "Claim Status Report",
        ],
        menu_icon="cast",
        default_index=0,
    )

if selected_menu == "Upload Document":
    uploaded_file = st.file_uploader("Choose a file")
    if uploaded_file is not None:
        upload_dir = os.path.join(os.getcwd(), "uploaded_files")
        os.makedirs(upload_dir, mode=0o777, exist_ok=True)
        # The file name comes from the client. Keep only its final component
        # so a name such as "../../x" cannot write outside upload_dir.
        safe_name = os.path.basename(uploaded_file.name)
        file_path = os.path.join(upload_dir, safe_name)

        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        displayPDF(file_path)

elif selected_menu == "Extract Text":
    with st.spinner("Extracting Text..."):
        # NOTE(review): fixed delay, apparently simulating OCR latency for the demo.
        time.sleep(6)
        st.write(get_text_from_ocr_engine())

elif selected_menu == "Summarize Document":
    paragraphs = get_paragraphs_for_summaries()

    with st.spinner("Finding Topics..."):
        # NOTE(review): topics are hard-coded; no classifier is consulted.
        tags_found = ["Injury Details", "Past Medical Conditions", "Injury Management Plan", "GP Correspondence"]
        time.sleep(5)
        st.write("This document is about:")
        # Bold the tag list. The closing "**" must directly follow a
        # non-space character for Markdown to treat it as emphasis rather
        # than literal asterisks.
        st.markdown("**" + "; ".join("#" + tag for tag in tags_found) + "**")
        st.markdown("""---""")

    with st.spinner("Summarizing Document..."):
        for paragraph in paragraphs:
            summary = pipeline_summarization(paragraph, max_length=130, min_length=30, do_sample=False)
            st.write(summary[0]['summary_text'])
            st.markdown("""---""")

elif selected_menu == "Extract Entities":
    paragraphs = get_paragraphs_for_entities()

    with st.spinner("Extracting Entities..."):
        # The loop variable must keep the module-level name `text`:
        # get_formatted_text_for_annotation slices the global `text` when
        # called with only the NER output.
        for text in paragraphs:
            output = pipeline_ner(text)
            entities_text = get_formatted_text_for_annotation(output)
            annotated_text(*entities_text)
            st.markdown("""---""")

elif selected_menu == "Detected Barriers":
    barriers_to_detect = {
        "Chronic Pain": "Is the patient experiencing chronic pain?",
        "Mental Health Issues": "Does he have any mental issues?",
        "Prior History": "What is prior medical history?",
        "Smoking": "Does he smoke?",
        "Drinking": "Does he drink?",
        "Comorbidities": "Does he have any comorbidities?",
    }

    with st.spinner("Detecting Barriers..."):
        # get_text_from_ocr_engine() takes no arguments. Fetch the context
        # once instead of once per question (assumes it returns the same
        # document on every call).
        context = get_text_from_ocr_engine()
        for barrier, question_text in barriers_to_detect.items():
            result = pipeline_qa(question=question_text, context=context)
            st.subheader(barrier)
            if result['score'] < BARRIER_MIN_SCORE:
                st.text("Not Found")
            else:
                st.text(result['answer'])

elif selected_menu == "Get Answers":
    st.subheader('Question')
    question_text = st.text_input("Type your question")

    if question_text:
        with st.spinner("Finding Answer(s)..."):
            # Only fetch the document once there is something to ask.
            context = get_text_from_ocr_engine()
            result = pipeline_qa(question=question_text, context=context)
            st.subheader('Answer')
            st.text(result['answer'])

elif selected_menu == "Annotation Tool":
    display_only_data = get_display_only_data()
    editable_data = get_editable_data()

    st.subheader("Display Mode:")
    left, right = st.columns(2)
    with left:
        st.text("Vertical labels:")
        text_annotation(display_only_data)
    with right:
        st.text("Horizontal labels:")
        # Shallow copy instead of mutating the returned dict. If
        # get_display_only_data() hands back a shared object, in-place
        # mutation would leak "horizontal" into the left column on the next
        # rerun.
        text_annotation({**display_only_data, "labelOrientation": "horizontal"})

    st.subheader("Edit Mode:")
    data = text_annotation(editable_data)
    if data:
        # Explicit equivalent of the Streamlit "magic" tuple expression.
        st.write("Returned data:", data)

elif selected_menu == "Claim Status Report":
    claim_number = st.text_input("Enter the Claim Number")

    if claim_number:
        # NOTE(review): claim_number is not passed to any getter below; they
        # appear to return the same data for every claim. Confirm.
        st.subheader("Claim Attributes:")
        claim_attributes = get_claim_details()
        for label, value in claim_attributes.items():
            st.metric(label, value, delta=None, delta_color="normal")

        # Each report section: its heading and the zero-argument getter whose
        # result is written underneath it, in display order.
        report_sections = [
            ("Injury Details:", get_injury_details),
            ("Injury Severity:", get_injury_severity),
            ("Preexisting Conditions:", get_preexisting_conditions),
            ("Work Capacity:", get_work_capacity),
            ("Injury Management Plan:", get_injury_management_plan),
        ]
        for heading, fetch_section in report_sections:
            st.subheader(heading)
            st.write(fetch_section())