import json
import nltk
import pandas as pd
import re

nltk.download("punkt")  # Punkt sentence tokenizer model used by sent_tokenize below

from dataclasses import asdict, dataclass
from nltk.tokenize import sent_tokenize
from typing import Dict, List, Mapping, Sequence
from utils_report_parser import get_section_from_report

from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)


@dataclass
class Report:
    patient_id: str | int
    text: str
    date: str
    summary: str | None = None

    def dict(self):
        return {k: str(v) for k, v in asdict(self).items()}


def clean(s: str) -> str:
    s = s.replace("\n", " ")  # Join lines into a single string
    s = s.replace("_", "")  # Remove underscores (e.g. "____" divider lines)
    s = re.sub(r"\[.*?\]", "", s)  # Remove bracketed content
    s = re.sub(r"\(.*?\)", "", s)  # Remove parenthesized content
    s = " ".join(s.split())  # Collapse repeated whitespace
    return s
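
# Illustrative example (hypothetical input):
#   clean("No acute findings.\n(see prior study) ____")
#   -> "No acute findings."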


def split_into_sentences(text: str) -> list[str]:
    # Thin wrapper around NLTK's Punkt sentence tokenizer
    return sent_tokenize(text)


def remove_keyword(text: str, keyword: str) -> str:
    # Replace the first case-insensitive occurrence of `keyword` with a space
    start_idx = text.lower().find(keyword.lower())
    if start_idx > -1:
        text = text[:start_idx] + " " + text[start_idx + len(keyword):]
    return text
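
# Illustrative example (hypothetical input):
#   remove_keyword("IMPRESSION:Stable exam.", "impression:")
#   -> " Stable exam."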


def format_casemaker_data(
    df: pd.DataFrame, patient_id_column: str, text_column: str, date_column: str
) -> dict[str, Sequence[Report]]:
    """Take in a pandas dataframe where each row corresponds to one report for a patient,
    and output a dictionary mapping each patient ID to a date-sorted list of that
    patient's reports as Report objects

    Args:
        df (pd.DataFrame): Input dataframe on the report level
        patient_id_column (str): Column holding the patient ID
        text_column (str): Column holding the report text
        date_column (str): Column holding the report date (used for sorting)

    Returns:
        dict[str, Sequence[Report]]: Reports grouped by patient ID and sorted by date
    """
    df = df.rename(
        columns={
            patient_id_column: "patient_id",
            text_column: "text",
            date_column: "date",
        }
    )
    # One entry per patient: a date-sorted list of {"patient_id", "text", "date"} records
    grouped = (
        df.sort_values(by=["patient_id", "date"])
        .groupby("patient_id")
        .apply(lambda g: g[["patient_id", "text", "date"]].to_dict("records"))
    )
    reports_by_patient = dict[str, Sequence[Report]]()
    for patient_id, report_list in grouped.items():
        reports_by_patient[str(patient_id)] = [Report(**record) for record in report_list]
    return reports_by_patient
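
# Illustrative usage (the column names below are hypothetical):
#   raw = pd.DataFrame({
#       "mrn": [1, 1],
#       "report": ["CT chest ...", "MRI brain ..."],
#       "dos": ["2020-01-01", "2020-02-01"],
#   })
#   format_casemaker_data(raw, patient_id_column="mrn", text_column="report", date_column="dos")
#   -> {"1": [Report(patient_id=1, text="CT chest ...", date="2020-01-01"), ...]}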


class CaseMaker:
    def __init__(self, organ_keywords_dict_path: str = "../assets/terms.json"):
        # Mapping of organ name -> list of keywords/synonyms associated with that organ
        with open(organ_keywords_dict_path, "r") as f:
            self.organ_keyword_dict: Mapping[str, list[str]] = json.load(f)

        self.ner_pipe = pipeline(
            "ner",
            model=AutoModelForTokenClassification.from_pretrained(
                "d4data/biomedical-ner-all"
            ),
            tokenizer=AutoTokenizer.from_pretrained("d4data/biomedical-ner-all"),
            aggregation_strategy="simple",
            device_map="auto",
        )
        # The summarization pipeline used by summarize_report is created lazily
        # inside that method, so the large starmpcc/Asclepius-7B model is only
        # loaded when a summary is actually requested.

    def filter_out_irrelevant_sentences(self, lst: list[str]) -> list[str]:
        # Drop administrative boilerplate sentences that start with these phrases
        bad_keywords = [
            "date of procedure", "physicians", "report initiated by",
            "reported by", "reported and signed by",
        ]
        return [s for s in lst if not any(s.lower().startswith(k) for k in bad_keywords)]

    def remove_header_names(self, s: str):
        headers = [
            "IMPRESSION", "FINDINGS", "RECOMMENDATION",
            "COMPARISON", "INDICATION", "TECHNIQUE", "STUDY", 
            "MEDICATIONS", "TECHNIQUE AND FINDINGS"
            ]
        for header in headers:
            s = remove_keyword(s, f"{header}:")
        return s

    def pick_organ_by_keyword(self, s: str) -> str | None:
        # Return the first organ whose name or any of its keywords appears in the sentence
        words = s.lower()
        for organ, keywords in self.organ_keyword_dict.items():
            if any(kw.lower() in words for kw in [organ, *keywords]):
                return organ
        return None
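
    # Illustrative shape of the organ keyword dictionary loaded from terms.json
    # (the entries below are hypothetical):
    #   {"liver": ["hepatic", "hepatomegaly"], "lung": ["pulmonary", "lobe"]}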

    def parse_report_by_organ(self, report: str) -> dict[str, str]:
        """Take in a text report and output a dictionary mapping each body organ
        to the concatenation of all sentences that reference that organ

        Args:
            report (str): Input report

        Returns:
            dict[str, str]: Organ name -> concatenated sentences for that organ
        """
        report_string_by_organ = dict[str, str]()

        # Split the report into a list of sentences
        sentences = split_into_sentences(report)
        # Filter out irrelevant sentences using rules
        sentences = self.filter_out_irrelevant_sentences(sentences)
        # Collect a list of paragraphs related to each organ
        previous_sentence_organ = "Other"
        for s in sentences:
            # Figure out which organ is being referenced
            selected_organ = self.pick_organ_by_keyword(s)
            if selected_organ is None:
                selected_organ = previous_sentence_organ
            else:
                previous_sentence_organ = selected_organ
            # Concatenate the report to its corresponding organ
            if selected_organ not in report_string_by_organ:
                report_string_by_organ[selected_organ] = s
            else:
                report_string_by_organ[selected_organ] += f" {s}"

        return report_string_by_organ
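
    # Illustrative example (hypothetical sentences, assuming "liver" and "lung" are
    # keys in the organ keyword dictionary):
    #   parse_report_by_organ("Liver is unremarkable. No lesion seen. Lungs are clear.")
    #   -> {"liver": "Liver is unremarkable. No lesion seen.", "lung": "Lungs are clear."}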

    def trim_to_relevant_portion(self, report: str) -> str:
        # Only keep sentences in which the NER model finds a symptom or disease mention
        relevant_sentences = list[str]()
        for sentence in sent_tokenize(report):
            if any(
                ent["entity_group"] in ["Sign_symptom", "Disease_disorder"]
                for ent in self.ner_pipe(sentence)
            ):
                relevant_sentences.append(str(sentence))
        return "\n".join(relevant_sentences)

    def summarize_report(self, text: str) -> str:
        """Format text into prompt and summarize clinical text

        Args:
            text (str): Input report

        Returns:
            str: Output
        """

        question = (
            "Can you provide a succinct summary of the key clinical findings "
            "and treatment recommendations outlined in this discharge summary?"
        )

        prompt = """
        You are an intelligent clinical language model.
        Below is a snippet of a patient's discharge summary and a following instruction from a healthcare professional.
        Write a response that appropriately completes the instruction.
        The response should provide the accurate answer to the instruction, while being concise.

        [Discharge Summary Begin]
        {note}
        [Discharge Summary End]

        [Instruction Begin]
        {question}
        [Instruction End]
        """.format(
            question=question, note=text
        )

        output = self.summ_pipe(prompt, max_new_tokens=len(text.split()) // 2)[0][
            "generated_text"
        ]
        answer = output.split("[Instruction End]")[-1]
        answer = clean(answer)
        return answer

    def parse_records(
        self,
        reports: Sequence[Report],
    ) -> dict[str, Sequence[Report]]:
        """Given a list of Report objects, split each of them by body part using
        parse_report_by_organ, then compile all the text for the same organ across
        different reports (i.e. for each body part, build a list of organ-level
        Reports holding the relevant text and date). Only reports with relevant
        findings are kept, and the trimmed text is stored on each Report's summary.

        Args:
            reports (Sequence[Report]): List of Report objects; each must have
            "patient_id", "text", and "date" populated
        """

        # Split the reports by organ
        reports_by_organ = dict[str, list[Report]]()  # organ -> list of organ-level reports
        for report in reports:
            # Cut the report to the findings
            report_findings = get_section_from_report(report.text, "findings")
            
            # Remove any other keywords
            report_findings = self.remove_header_names(report_findings)
            
            # For each organ, collect a list of relevant records containing the text and date
            report_by_organ = self.parse_report_by_organ(report_findings)
            for organ, report_text in report_by_organ.items():
                organ_level_record = Report(
                    text=report_text, date=report.date, patient_id=report.patient_id
                )
                if organ in reports_by_organ:
                    reports_by_organ[organ].append(organ_level_record)
                else:
                    reports_by_organ[organ] = [organ_level_record]

        # For each organ, then filter only to the relevant reports and summarize them
        summarized_reports_by_organ = dict[str, Sequence[Report]]()
        for organ in reports_by_organ.keys():
            cleaned_reports = list[Report]()
            for report in reports_by_organ[organ]:
                # Trim the report
                report_text = self.trim_to_relevant_portion(report.text)
                if report_text:
                    report.summary = report_text
                    cleaned_reports.append(report)
            summarized_reports_by_organ[organ] = cleaned_reports

        return summarized_reports_by_organ

    def format_reports(self, all_reports: Dict[str, List[Report]]) -> Dict[str, str]:
        # Render each organ's reports as a single string with one dated block per report
        new_reports = {}
        for organ, organ_reports in all_reports.items():
            new_reports[organ] = "\n\n".join(
                [f"**Report {str(r.date)}**\n\n{str(r.summary)}" for r in organ_reports]
            )
        return new_reports
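

# Minimal usage sketch (illustrative only: the CSV path and its column names are
# hypothetical, and the Hugging Face models are downloaded on first use):
if __name__ == "__main__":
    raw_df = pd.read_csv("reports.csv")  # hypothetical input file
    reports_by_patient = format_casemaker_data(
        raw_df, patient_id_column="patient_id", text_column="text", date_column="date"
    )
    casemaker = CaseMaker()
    for patient_id, patient_reports in reports_by_patient.items():
        reports_by_organ = casemaker.parse_records(patient_reports)
        formatted = casemaker.format_reports(reports_by_organ)
        print(patient_id, json.dumps(formatted, indent=2))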