import json
import logging
from typing import List
import copy

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer

from util.process_data import Sample, Entity, EntityType, EntityTypeSet, SampleList, Token, Relation
from util.configuration import InferenceConfiguration

valid_relations = { # head : [tail, ...]
    "StatedKeyFigure": ["StatedKeyFigure", "Condition", "StatedExpression", "DeclarativeExpression"],
    "DeclarativeKeyFigure": ["DeclarativeKeyFigure", "Condition", "StatedExpression", "DeclarativeExpression"],
    "StatedExpression": ["Unit", "Factor", "Range", "Condition"],
    "DeclarativeExpression": ["DeclarativeExpression", "Unit", "Factor", "Range", "Condition"],
    "Condition": ["Condition", "StatedExpression", "DeclarativeExpression"],
    "Range": ["Range"]
}
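# Example: a relation with a "StatedExpression" head and a "Unit" tail is allowed,
# whereas "Unit" never appears as a key above and therefore can never act as a head.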

class TokenClassificationDataset(Dataset):
    """ PyTorch dataset wrapping tokenizer encodings and per-token label lists """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


class TransformersInference:

    def __init__(self, config: InferenceConfiguration):
        self.__logger = logging.getLogger(self.__class__.__name__)
        self.__logger.info(f"Load Configuration: {config.dict()}")

        with open(f"classification.json", mode='r', encoding="utf-8") as f:
            self.__entity_type_set = EntityTypeSet.parse_obj(json.load(f))
        self.__entity_type_label_to_id_mapping = {x.label: x.idx for x in self.__entity_type_set.all_types()}
        self.__entity_type_id_to_label_mapping = {x.idx: x.label for x in self.__entity_type_set.all_types()}

        self.__logger.info("Load Model: " + config.model_path_keyfigure)
        self.__tokenizer = AutoTokenizer.from_pretrained(config.transformer_model,
                padding="max_length", max_length=512, truncation=True) 
        
        self.__model = AutoModelForTokenClassification.from_pretrained(config.model_path_keyfigure, num_labels=(
            len(self.__entity_type_set)))

        self.__trainer = Trainer(model=self.__model)
        self.__merge_entities = config.merge_entities
        self.__split_len = config.split_len
        self.__extract_relations = config.extract_relations

        # add special tokens for relation extraction, entity groups and entity boundaries
        entity_groups = self.__entity_type_set.groups

        lst_special_tokens = ["[REL]", "[SUB]", "[/SUB]", "[OBJ]", "[/OBJ]"]
        for grp_idx, grp in enumerate(entity_groups):
            lst_special_tokens.append(f"[GRP-{grp_idx:02d}]")
            lst_special_tokens.extend([f"[ENT-{ent:02d}]" for ent in grp if ent != self.__entity_type_set.id_of_non_entity])
            lst_special_tokens.extend([f"[/ENT-{ent:02d}]" for ent in grp if ent != self.__entity_type_set.id_of_non_entity])

        lst_special_tokens = sorted(set(lst_special_tokens))
        special_tokens_dict = {'additional_special_tokens': lst_special_tokens }
        num_added_toks = self.__tokenizer.add_special_tokens(special_tokens_dict)
        self.__logger.info(f"Added {num_added_toks} new special tokens. All special tokens: '{self.__tokenizer.all_special_tokens}'")

        self.__logger.info("Initialization completed.")



    def run_inference(self, sample_list: SampleList):
        group_predictions = []
        group_entity_ids = []
        self.__logger.info("Predict Entities ...")
        token_lists = [[x.text for x in sample.tokens] for sample in sample_list.samples]
        for grp_idx, _ in enumerate(self.__entity_type_set.groups):
            predictions = self.__get_predictions(token_lists, f"[GRP-{grp_idx:02d}]")
            group_entity_ids_ = []
            for sample, prediction_per_tokens in zip(sample_list.samples, predictions):
                group_entity_ids_.append(self.generate_response_entities(sample, prediction_per_tokens, grp_idx))
            group_predictions.append(predictions)
            group_entity_ids.append(group_entity_ids_)

        if self.__extract_relations:
            self.__logger.info("Predict Relations ...")
            self.__do_extract_relations(sample_list, group_predictions, group_entity_ids)


    def __do_extract_relations(self, sample_list, group_predictions, group_entity_ids):
        id_of_non_entity = self.__entity_type_set.id_of_non_entity

        for sample_idx, sample in enumerate(sample_list.samples):
            masked_tokens = []
            masked_tokens_align = []
            # build a masked token sequence for every entity that can act as a relation head
            head_entities = [entity_ for entity_ in sample.entities if entity_.ent_type.label in valid_relations]
            for entity_ in head_entities:
                ent_masked_tokens = []
                ent_masked_tokens_align = []
                last_preds = [id_of_non_entity for _ in group_predictions]
                last_ent_ids = [-1 for _ in group_entity_ids]
                for token_idx, token in enumerate(sample.tokens):
                    # close entity spans that ended at the previous token
                    for group, last_pred, last_ent_id in zip(group_predictions, last_preds, last_ent_ids):
                        pred = group[sample_idx][token_idx]
                        if last_pred != pred and last_pred != id_of_non_entity:
                            mask = "[/SUB]" if last_ent_id == entity_.id else "[/OBJ]"
                            ent_masked_tokens.extend([f"[/ENT-{last_pred:02d}]", mask])
                            ent_masked_tokens_align.extend([str(last_ent_id), str(last_ent_id)])

                    # open new entity spans that start at this token
                    for group, ent_ids, last_pred, last_ent_id in zip(group_predictions, group_entity_ids, last_preds, last_ent_ids):
                        pred = group[sample_idx][token_idx]
                        ent_id = ent_ids[sample_idx][token_idx]
                        if last_pred != pred and pred != id_of_non_entity:
                            mask = "[SUB]" if ent_id == entity_.id else "[OBJ]"
                            ent_masked_tokens.extend([mask, f"[ENT-{pred:02d}]"])
                            ent_masked_tokens_align.extend([str(ent_id), str(ent_id)])

                    ent_masked_tokens.append(token.text)
                    ent_masked_tokens_align.append(token.text)
                    for idx, group in enumerate(group_predictions):
                        last_preds[idx] = group[sample_idx][token_idx]
                    for idx, group in enumerate(group_entity_ids):
                        last_ent_ids[idx] = group[sample_idx][token_idx]

                # close entity spans that are still open after the last token of the sample
                for last_pred, last_ent_id in zip(last_preds, last_ent_ids):
                    if last_pred != id_of_non_entity:
                        mask = "[/SUB]" if last_ent_id == entity_.id else "[/OBJ]"
                        ent_masked_tokens.extend([f"[/ENT-{last_pred:02d}]", mask])
                        ent_masked_tokens_align.extend([str(last_ent_id), str(last_ent_id)])

                masked_tokens.append(ent_masked_tokens)
                masked_tokens_align.append(ent_masked_tokens_align)

            # skip samples without a potential head entity to avoid predicting on an empty batch
            if not head_entities:
                sample.relations = []
                continue

            rel_predictions = self.__get_predictions(masked_tokens, "[REL]")
            self.generate_response_relations(sample, head_entities, masked_tokens_align, rel_predictions)


    def generate_response_entities(self, sample: Sample, predictions_per_tokens: List[int], grp_idx: int):
        entities = []
        entity_ids = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        idx = grp_idx * 1000  # offset entity ids per group so they stay unique across groups (assumes < 1000 entities per group)
        for token, prediction in zip(sample.tokens, predictions_per_tokens):
            if id_of_non_entity == prediction:
                entity_ids.append(-1)
                continue
            idx += 1
            entities.append(self.__build_entity(idx, prediction, token))
            entity_ids.append(idx)

        if self.__merge_entities:
            entities = self.__do_merge_entities(copy.deepcopy(entities))
            # let consecutive tokens with the same prediction share a single entity id
            prev_pred = id_of_non_entity
            for idx, pred in enumerate(predictions_per_tokens):
                if prev_pred == pred and idx > 0:
                    entity_ids[idx] = entity_ids[idx-1]
                prev_pred = pred

        sample.entities += entities

        tags = sample.tags if len(sample.tags) > 0 else [self.__entity_type_set.id_of_non_entity] * len(sample.tokens)
        for token_idx, tok in enumerate(sample.tokens):
            for ent in entities:
                if ent.start <= tok.start < ent.end:
                    tags[token_idx] = ent.ent_type.idx
        self.__logger.debug(tags)
        sample.tags = tags

        return entity_ids


    def generate_response_relations(self, sample: Sample, head_entities: List[Entity], masked_tokens_align: List[List[str]], rel_predictions: List[List[int]]):
        relations = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        idx = 0
        for entity_, align_per_ent, prediction_per_ent in zip(head_entities, masked_tokens_align, rel_predictions):
            for token, prediction in zip(align_per_ent, prediction_per_ent):
                if id_of_non_entity == prediction:
                    continue
                try:
                    tail = int(token)
                except ValueError:
                    # the alignment entry is plain token text, not an entity id
                    continue
                if not self.__validate_relation(sample.entities, entity_.id, tail, prediction):
                    continue
                idx += 1
                relations.append(self.__build_relation(idx, entity_.id, tail, prediction))

        sample.relations = relations


    def __validate_relation(self, entities: List[Entity], head: int, tail: int, prediction: int):
        if head == tail: return False
        head_ents = [ent.ent_type.label for ent in entities if ent.id==head]
        tail_ents = [ent.ent_type.label for ent in entities if ent.id==tail]

        if not head_ents or not tail_ents:
            return False

        return tail_ents[0] in valid_relations[head_ents[0]]


    def __build_entity(self, idx: int, prediction: int, token: Token) -> Entity:
        return Entity(
            id=idx,
            text=token.text,
            start=token.start,
            end=token.end,
            ent_type=EntityType(
                idx=prediction, 
                label=self.__entity_type_id_to_label_mapping[prediction]
                )
        )

    def __build_relation(self, idx: int, head: int, tail: int, prediction: int) -> Relation:
        return Relation(
            id=idx,
            head=head,
            tail=tail,
            rel_type=EntityType(
                idx=prediction, 
                label=self.__entity_type_id_to_label_mapping[prediction]
                )
        )

    def __do_merge_entities(self, input_ents_):
        """ Merge adjacent entities of the same type (at most one character apart) into a single entity """
        out_ents = []
        current_ent = None

        for ent in input_ents_:
            if current_ent is None:
                current_ent = ent
            else:
                idx_diff = ent.start - current_ent.end
                if ent.ent_type.idx == current_ent.ent_type.idx and idx_diff <= 1:
                    current_ent.end = ent.end
                    current_ent.text += (" " if idx_diff == 1 else "") + ent.text
                else:
                    out_ents.append(current_ent)
                    current_ent = ent
        
        if current_ent is not None:
            out_ents.append(current_ent)

        return out_ents


    def __get_predictions(self, token_lists: List[List[str]], trigger: str) -> List[List[int]]:
        """ Get predictions of Transformer Sequence Labeling model """
        if self.__split_len > 0:
            token_lists_split = self.__do_split_sentences(token_lists, self.__split_len)
            predictions = []
            for sample_token_lists in token_lists_split:
                sample_token_lists_trigger = [[trigger]+sample for sample in sample_token_lists]
                val_encodings = self.__tokenizer(sample_token_lists_trigger, is_split_into_words=True, padding='max_length', truncation=True)  # return_tensors="pt"
                val_labels = []
                for i in range(len(sample_token_lists_trigger)):
                    word_ids = val_encodings.word_ids(batch_index=i)
                    label_ids = [0 for _ in word_ids]
                    val_labels.append(label_ids)

                val_dataset = TokenClassificationDataset(val_encodings, val_labels)

                predictions_raw, _, _ = self.__trainer.predict(val_dataset)

                predictions_align = self.__align_predictions(predictions_raw, val_encodings)
                # per-token confidence scores (currently unused)
                confidence = [[max(token) for token in sample] for sample in predictions_align]
                # argmax label id per token; [1:] drops the prediction for the leading trigger token
                predictions_sample = [[token.index(max(token)) for token in sample][1:] for sample in predictions_align]
                predictions_part = []
                for tok, pred in zip(sample_token_lists_trigger, predictions_sample):
                    if trigger == "[REL]" and "[SUB]" not in tok:
                        # split parts that do not contain the [SUB]-marked head entity cannot contribute relations
                        predictions_part += [self.__entity_type_set.id_of_non_entity] * len(pred)
                    else:
                        predictions_part += pred
                predictions.append(predictions_part)
        else:
            token_lists_trigger = [[trigger]+sample for sample in token_lists]
            val_encodings = self.__tokenizer(token_lists_trigger, is_split_into_words=True, padding='max_length', truncation=True)  # return_tensors="pt"
            val_labels = []
            for i in range(len(token_lists_trigger)):
                word_ids = val_encodings.word_ids(batch_index=i)
                label_ids = [0 for _ in word_ids]
                val_labels.append(label_ids)

            val_dataset = TokenClassificationDataset(val_encodings, val_labels)

            predictions_raw, _, _ = self.__trainer.predict(val_dataset)

            predictions_align = self.__align_predictions(predictions_raw, val_encodings)
            # per-token confidence scores (currently unused)
            confidence = [[max(token) for token in sample] for sample in predictions_align]
            # argmax label id per token; [1:] drops the prediction for the leading trigger token
            predictions = [[token.index(max(token)) for token in sample][1:] for sample in predictions_align]

        return predictions

    def __do_split_sentences(self, tokens_: List[List[str]], split_len_=200) -> List[List[List[str]]]:
        """ Split long token lists into shorter lists of roughly equal length """
        res_tokens = []

        for tok_lst in tokens_:
            res_tokens_sample = []
            length = len(tok_lst)
            if length > split_len_:
                num_lists = length // split_len_ + (1 if (length % split_len_) > 0 else 0)
                new_length = int(length / num_lists) + 1
                self.__logger.info(f"Splitting a list of {length} elements into {num_lists} lists of length {new_length}..")
                start_idx = 0
                for i in range(num_lists):
                    end_idx = min(start_idx + new_length, length)
                    if "\n" in tok_lst[start_idx]: tok_lst[start_idx] = "."
                    if "\n" in tok_lst[end_idx-1]: tok_lst[end_idx-1] = "."
                    res_tokens_sample.append(tok_lst[start_idx:end_idx])
                    start_idx = end_idx

                res_tokens.append(res_tokens_sample)
            else:
                res_tokens.append([tok_lst])

        return res_tokens
    

    def __align_predictions(self, predictions, tokenized_inputs, sum_all_tokens=False) -> List[List[List[float]]]:
        """ Align predicted labels from Transformer Tokenizer """
        confidence = []
        id_of_non_entity = self.__entity_type_set.id_of_non_entity
        for i, tagset in enumerate(predictions):

            word_ids = tokenized_inputs.word_ids(batch_index=i)

            previous_word_idx = None
            token_confidence = []
            for k, word_idx in enumerate(word_ids):
                try:
                    tok_conf = [value for value in tagset[k]]
                except TypeError:
                    # use the object itself if it's not iterable
                    tok_conf = tagset[k]

                if word_idx is not None:
                    # add non-entity predictions if there is a gap in word ids (usually caused by a newline token)
                    if previous_word_idx is not None:
                        diff = word_idx - previous_word_idx
                        for _ in range(diff - 1):
                            tmp = [0 for _ in tok_conf]
                            tmp[id_of_non_entity] = 1.0
                            token_confidence.append(tmp)

                    # add confidence value if this is the first token of the word
                    if word_idx != previous_word_idx:
                        token_confidence.append(tok_conf)
                    else:
                        # if sum_all_tokens=True the confidence for all tokens of one word will be summarized
                        if sum_all_tokens:
                            token_confidence[-1] = [a + b for a, b in zip(token_confidence[-1], tok_conf)]

                previous_word_idx = word_idx

            confidence.append(token_confidence)

        return confidence
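

# Minimal usage sketch (illustrative, not part of the module's API). It assumes that
# InferenceConfiguration and SampleList are pydantic models that can be loaded via
# parse_obj, mirroring how EntityTypeSet is loaded above, and that "config.json" and
# "samples.json" are placeholder paths supplied by the caller.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    with open("config.json", mode='r', encoding="utf-8") as f:
        config = InferenceConfiguration.parse_obj(json.load(f))
    with open("samples.json", mode='r', encoding="utf-8") as f:
        sample_list = SampleList.parse_obj(json.load(f))

    inference = TransformersInference(config)
    inference.run_inference(sample_list)  # fills sample.entities, sample.tags and (optionally) sample.relations

    for sample in sample_list.samples:
        print([(entity.text, entity.ent_type.label) for entity in sample.entities])
        if config.extract_relations:
            print([(relation.head, relation.tail, relation.rel_type.label) for relation in sample.relations])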