import unittest

from model.base_model import SummModel
from model import SUPPORTED_SUMM_MODELS

from pipeline import assemble_model_pipeline

from evaluation.base_metric import SummMetric
from evaluation import SUPPORTED_EVALUATION_METRICS

from dataset.st_dataset import SummInstance, SummDataset
from dataset import SUPPORTED_SUMM_DATASETS
from dataset.dataset_loaders import ScisummnetDataset, ArxivDataset

from helpers import print_with_color, retrieve_random_test_instances

import random
import time
from typing import List, Union, Tuple
import sys
import re


class IntegrationTests(unittest.TestCase):
    """Smoke-test every compatible dataset + model + evaluation-metric pipeline."""

    def get_prediction(
        self, model: SummModel, dataset: SummDataset, test_instances: List[SummInstance]
    ) -> Tuple[Union[List[str], List[List[str]]], Union[List[str], List[List[str]]]]:
        """
        Get summary prediction given model and dataset instances.

        :param SummModel `model`: Model for summarization task.
        :param SummDataset `dataset`: Dataset for summarization task.
        :param List[SummInstance] `test_instances`: Instances from `dataset` to summarize.
        :returns Tuple of (whatever `model.summarize` returns for the instances,
            list of target summaries with one entry per instance in `test_instances`).
        """
        # For ScisummnetDataset only the first element of each instance's `source`
        # is summarized; every other dataset passes `source` through unchanged.
        src = (
            [ins.source[0] for ins in test_instances]
            if isinstance(dataset, ScisummnetDataset)
            else [ins.source for ins in test_instances]
        )
        tgt = [ins.summary for ins in test_instances]
        # Query-based datasets also hand each instance's query to the model.
        query = (
            [ins.query for ins in test_instances] if dataset.is_query_based else None
        )
        prediction = model.summarize(src, query)
        return prediction, tgt

    def get_eval_dict(self, metric: SummMetric, prediction: List[str], tgt: List[str]):
        """
        Run evaluation metric on summary prediction.

        :param SummMetric `metric`: Evaluation metric.
        :param List[str] `prediction`: Summary prediction instances.
        :param List[str] `tgt`: Target (reference) summaries from the dataset.
        :returns Whatever `metric.evaluate(prediction, tgt)` returns (printed by the
            caller as the score dictionary).
        """
        score_dict = metric.evaluate(prediction, tgt)
        return score_dict

    def _run_pipeline(
        self, model, model_name, dataset, dataset_instances, evaluation_metrics
    ):
        """Summarize one random instance with `model` and score it with every metric."""
        test_instances = retrieve_random_test_instances(
            dataset_instances=dataset_instances, num_instances=1
        )
        print_with_color(
            f"{'#' * 20} Testing: {dataset.dataset_name} dataset, {model_name} model {'#' * 20}",
            "35",
        )
        prediction, tgt = self.get_prediction(model, dataset, test_instances)
        print(f"Prediction: {prediction}\nTarget: {tgt}\n")
        for metric in evaluation_metrics:
            print_with_color(f"{metric.metric_name} metric", "35")
            score_dict = self.get_eval_dict(metric, prediction, tgt)
            print(score_dict)

        print_with_color(
            f"{'#' * 20} Test for {dataset.dataset_name} dataset, {model_name} model COMPLETE {'#' * 20}\n\n",
            "32",
        )

    def test_all(self):
        """
        Runs integration test on all compatible dataset + model + evaluation metric pipelines supported by SummerTime.

        Each dataset/model pipeline runs inside its own `subTest`. Under a runner that
        supports subtests (such as `unittest.main`), a failing pipeline is reported
        and the remaining pipelines still run. Under other runners the first failure
        propagates as before.
        """

        print_with_color("\nInitializing all evaluation metrics...", "35")
        evaluation_metrics = []
        for eval_cls in SUPPORTED_EVALUATION_METRICS:
            # # TODO: Temporarily skipping Rouge/RougeWE metrics to avoid local bug.
            # if eval_cls in [Rouge, RougeWe]:
            #     continue
            print(eval_cls)
            evaluation_metrics.append(eval_cls())

        print_with_color("\n\nBeginning integration tests...", "35")
        for dataset_cls in SUPPORTED_SUMM_DATASETS:
            # TODO: Temporarily skipping Arxiv (size/time)
            if dataset_cls in [ArxivDataset]:
                continue
            dataset = dataset_cls()
            # Test instances are drawn from the training split; skip datasets without one.
            if dataset.train_set is None:
                continue

            dataset_instances = list(dataset.train_set)
            print(
                f"\n{dataset.dataset_name} has a training set of {len(dataset_instances)} examples"
            )
            print_with_color(
                f"Initializing all matching model pipelines for {dataset.dataset_name} dataset...",
                "35",
            )
            # matching_model_instances = assemble_model_pipeline(dataset_cls, list(filter(lambda m: m != PegasusModel, SUPPORTED_SUMM_MODELS)))
            matching_model_instances = assemble_model_pipeline(
                dataset_cls, SUPPORTED_SUMM_MODELS
            )
            for model, model_name in matching_model_instances:
                # Isolate each pipeline so one failure does not hide the others.
                with self.subTest(dataset=dataset.dataset_name, model=model_name):
                    self._run_pipeline(
                        model, model_name, dataset, dataset_instances, evaluation_metrics
                    )


if __name__ == "__main__":
    # The only accepted CLI argument is an optional all-digit random seed.
    argc = len(sys.argv)
    usage_ok = argc <= 1 or (
        argc == 2 and re.match(r"^\d+$", sys.argv[1]) is not None
    )
    if not usage_ok:
        print("Usage: python tests/integration_test.py [seed]", file=sys.stderr)
        sys.exit(1)

    # Without a seed, derive one from the clock. Otherwise pop the seed off argv
    # so unittest.main(), which parses sys.argv by default, never sees it.
    if argc == 1:
        seed = int(time.time())
    else:
        seed = int(sys.argv.pop())
    random.seed(seed)
    print_with_color(f"(to reproduce) random seeded with {seed}\n", "32")
    unittest.main()