import json
import nltk
import pandas as pd
import re

# Fetch the NLTK "punkt" sentence-tokenizer data used by sent_tokenize below.
# This runs at import time, so importing this module may trigger a download.
# NOTE(review): newer NLTK releases look up "punkt_tab" for sent_tokenize — confirm the
# installed NLTK version works with "punkt" alone.
nltk.download('punkt') 

from dataclasses import dataclass
from nltk.tokenize import sent_tokenize
from typing import Dict, List, Sequence
from utils_report_parser import get_section_from_report

from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

@dataclass
class Report:
    """One clinical report (or an organ-specific slice of one) for a single patient.

    ``summary`` defaults to None; in this file it is filled in by
    ``CaseMaker.parse_records`` with the trimmed, clinically relevant text.
    """

    patient_id: str|int  # patient identifier as read from the input data
    text: str  # report text
    # Report date; format_casemaker_data sorts by it.
    # NOTE(review): may hold a non-str (e.g. a pandas Timestamp) depending on the source column dtype.
    date: str
    summary: str|None = None  # relevant-sentence extract; None until populated

def clean(s: str) -> str:
    """Normalize a chunk of report text into a single tidy line.

    Newlines become spaces and underscores are dropped. Any ``[...]`` span and
    then any ``(...)`` span is removed together with its contents (non-greedy).
    Finally, whitespace runs are collapsed to single spaces and the ends trimmed.
    """
    text = s.replace("\n", " ").replace("_", "")
    # Order matters: bracketed spans are stripped before parenthesized ones.
    for pattern in (r"\[.*?\]", r"\(.*?\)"):
        text = re.sub(pattern, "", text)
    return " ".join(text.split())


def split_paragraphs(text: str) -> List[str]:
    """Split report text on blank lines and keep only substantial paragraphs.

    Every chunk separated by a double newline is normalized with ``clean``;
    chunks that end up with 10 or fewer whitespace-separated words are dropped.
    """
    cleaned_chunks = (clean(chunk) for chunk in text.split("\n\n"))
    return [paragraph for paragraph in cleaned_chunks if len(paragraph.split()) > 10]


def format_casemaker_data(
    df: pd.DataFrame, patient_id_column: str, text_column: str, date_column: str
) -> "Dict[str, List[Report]]":
    """Group a report-level dataframe into per-patient, date-sorted lists of Report.

    Each row of ``df`` is one report. The result maps every patient ID (converted
    to ``str``) to a list of that patient's reports sorted by date. The
    ``patient_id`` stored inside each Report keeps its original type.

    Args:
        df (pd.DataFrame): Input dataframe on report level; it is not modified.
        patient_id_column (str): Name of the column holding the patient ID
        text_column (str): Name of the column holding the report text
        date_column (str): Name of the column holding the date (used to sort)

    Returns:
        Dict[str, List[Report]]: Patient ID (as str) -> reports sorted by date,
        with patients in sorted ID order.
    """
    df = df.rename(
        columns={
            patient_id_column: "patient_id",
            text_column: "text",
            date_column: "date",
        }
    )
    # Keep only the fields Report accepts, then sort so every group comes out in
    # date order. NOTE: sorting follows the column dtype, so string dates sort
    # lexicographically.
    df = df[["patient_id", "text", "date"]].sort_values(by=["patient_id", "date"])

    # Iterate the groups directly instead of groupby(...).apply(...). Using apply on a
    # frame that includes the grouping column is deprecated since pandas 2.2, and a
    # future version will exclude that column, which would drop patient_id from the records.
    reports_by_patient: Dict[str, List[Report]] = {}
    for patient_id, group in df.groupby("patient_id"):
        reports_by_patient[str(patient_id)] = [
            Report(**record) for record in group.to_dict("records")
        ]
    return reports_by_patient


class CaseMaker:
    """Split clinical reports by organ and keep only their clinically relevant sentences.

    Organ assignment is keyword-based, driven by ``organ_keyword_dict`` (canonical
    organ name -> list of keywords/synonyms). Sentence relevance is decided by a
    biomedical NER pipeline.
    """

    # Hugging Face model used for both the NER model and its tokenizer.
    NER_MODEL_NAME = "d4data/biomedical-ner-all"

    def __init__(self, organ_keywords_dict_path: str = "../assets/terms.json"):
        """Load the organ keyword map and build the NER pipeline.

        Args:
            organ_keywords_dict_path (str): Path to a JSON file mapping each canonical
                organ name to a list of keywords/synonyms.
        """
        # Context manager so the file handle is closed deterministically.
        with open(organ_keywords_dict_path, "r", encoding="utf-8") as f:
            self.organ_keyword_dict = json.load(f)

        self.ner_pipe = pipeline(
            "ner",
            model=AutoModelForTokenClassification.from_pretrained(self.NER_MODEL_NAME),
            tokenizer=AutoTokenizer.from_pretrained(self.NER_MODEL_NAME),
            aggregation_strategy="simple",
            device_map="auto",
        )
        # The summarization pipeline is disabled; summarize_report() requires it.
        # self.summ_pipe = pipeline(
        #     "text2text-generation", model="starmpcc/Asclepius-7B", device_map="auto"
        # )

    def standardize_organ(self, organ_entity: Dict) -> Dict:
        """Map an entity's name onto a canonical organ from organ_keyword_dict, else "Other".

        Matching is case-insensitive against both the canonical names and their
        synonyms. The dictionary is modified in place and also returned.

        Args:
            organ_entity (Dict): Entity dictionary; must contain a "word" key.

        Returns:
            Dict: The same dictionary. On a match, "word" is replaced by the canonical
            organ name. Otherwise "word" is set to "Other" and "score" to 0.0.
        """
        word = organ_entity["word"].lower()
        for key, synonyms in self.organ_keyword_dict.items():
            # Lowercase both sides, consistent with pick_organ_by_keyword.
            if word == key.lower() or word in (synonym.lower() for synonym in synonyms):
                organ_entity["word"] = key
                return organ_entity
        # No match: treat it as a bad match, zero the score, and label it "Other".
        organ_entity["word"] = "Other"
        organ_entity["score"] = 0.0
        return organ_entity

    def pick_organ_by_keyword(self, s: str) -> str:
        """Return the first organ whose name or any keyword occurs in ``s``.

        The check is a case-insensitive substring test, so a short keyword can also
        match inside a longer word. Organs are tried in organ_keyword_dict order.
        If nothing matches, ``"other"`` is returned.
        """
        text = s.lower()
        for organ, keywords in self.organ_keyword_dict.items():
            if any(keyword.lower() in text for keyword in [organ, *keywords]):
                return organ
        # NOTE(review): lowercase "other" here vs "Other" in standardize_organ; kept
        # because callers may key on this exact string.
        return "other"

    def parse_report_by_organ(self, report: str) -> Dict[str, str]:
        """Group a report's paragraphs by the organ they mention.

        Args:
            report (str): Input report

        Returns:
            Dict[str, str]: Organ name -> space-joined text of every paragraph
            assigned to it, in original order. Paragraphs come from split_paragraphs.
        """
        report_string_by_organ: Dict[str, str] = {}
        for paragraph in split_paragraphs(report):
            organ = self.pick_organ_by_keyword(paragraph)
            if organ in report_string_by_organ:
                # clean() trims paragraph edges, so joining without a separator would
                # glue words across paragraphs (e.g. "normal.The") and confuse
                # sentence tokenization later.
                report_string_by_organ[organ] += " " + paragraph
            else:
                report_string_by_organ[organ] = paragraph
        return report_string_by_organ

    def trim_to_relevant_portion(self, report: str) -> str:
        """Keep only the "findings" sentences that mention a symptom or disease.

        Args:
            report (str): Report text

        Returns:
            str: Newline-joined sentences from the findings section in which the NER
            pipeline found a Sign_symptom or Disease_disorder entity ("" if none).
        """
        findings = get_section_from_report(report, "findings")
        # NOTE(review): guards against an empty/None result when there is no findings
        # section — confirm get_section_from_report's contract.
        if not findings:
            return ""

        relevant_groups = {"Sign_symptom", "Disease_disorder"}
        relevant_sentences = [
            sentence
            for sentence in sent_tokenize(findings)
            if any(
                ent["entity_group"] in relevant_groups for ent in self.ner_pipe(sentence)
            )
        ]
        return "\n".join(relevant_sentences)

    def summarize_report(self, text: str) -> str:
        """Format text into a prompt and summarize it with ``self.summ_pipe``.

        Args:
            text (str): Input report

        Returns:
            str: Cleaned text generated after the "[Instruction End]" marker.

        Raises:
            AttributeError: If no summarization pipeline has been set on the instance
                (its construction is commented out in __init__).
        """
        if getattr(self, "summ_pipe", None) is None:
            raise AttributeError(
                "summarize_report requires self.summ_pipe, which is not configured "
                "(its construction is commented out in CaseMaker.__init__)"
            )

        question = (
            "Can you provide a succinct summary of the key clinical findings "
            "and treatment recommendations outlined in this discharge summary?"
        )

        prompt = """
        You are an intelligent clinical languge model.
        Below is a snippet of patient's discharge summary and a following instruction from healthcare professional.
        Write a response that appropriately completes the instruction.
        The response should provide the accurate answer to the instruction, while being concise.

        [Discharge Summary Begin]
        {note}
        [Discharge Summary End]

        [Instruction Begin]
        {question}
        [Instruction End]
        """.format(
            question=question, note=text
        )

        # Budget about half the input's word count, but never zero new tokens.
        max_new_tokens = max(1, len(text.split()) // 2)
        output = self.summ_pipe(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
        # The generated text may echo the prompt; keep only what follows the instruction.
        answer = output.split("[Instruction End]")[-1]
        return clean(answer)

    def parse_records(
        self,
        reports: "Sequence[Report]",
    ) -> "Dict[str, List[Report]]":
        """Split every report by organ, then keep each organ's relevant text.

        Each report is split with parse_report_by_organ into organ-level Report
        objects, which carry the original date and patient_id. Each of those is
        trimmed with trim_to_relevant_portion. Organ-level reports whose trimmed text
        is empty are dropped; the others get the trimmed text as ``summary``.

        Args:
            reports (Sequence[Report]): Reports to process (``text``, ``date`` and
                ``patient_id`` are read).

        Returns:
            Dict[str, List[Report]]: Organ -> new organ-level Report objects (possibly
            an empty list). The input reports are not modified.
        """
        # Collect, for each organ, one organ-level Report per source report.
        reports_by_organ: Dict[str, List[Report]] = {}
        for report in reports:
            for organ, organ_text in self.parse_report_by_organ(report.text).items():
                reports_by_organ.setdefault(organ, []).append(
                    Report(text=organ_text, date=report.date, patient_id=report.patient_id)
                )

        # Keep only organ-level reports that still contain relevant sentences.
        summarized_reports_by_organ: Dict[str, List[Report]] = {}
        for organ, organ_reports in reports_by_organ.items():
            cleaned_reports: List[Report] = []
            for organ_report in organ_reports:
                trimmed = self.trim_to_relevant_portion(organ_report.text)
                if trimmed:
                    organ_report.summary = trimmed
                    cleaned_reports.append(organ_report)
            summarized_reports_by_organ[organ] = cleaned_reports

        return summarized_reports_by_organ

    def format_reports(self, all_reports: "Dict[str, Sequence[Report]]") -> Dict[str, str]:
        """Render each organ's reports as a single Markdown-style string.

        Args:
            all_reports (Dict[str, Sequence[Report]]): Organ -> reports; each item
                must expose ``date`` and ``summary`` attributes.

        Returns:
            Dict[str, str]: Organ -> blocks of the form "**Report <date>**\\n\\n<summary>",
            separated by blank lines.
        """
        return {
            organ: "\n\n".join(
                f"**Report {str(r.date)}**\n\n{str(r.summary)}" for r in organ_reports
            )
            for organ, organ_reports in all_reports.items()
        }