"""Module that contains iterators used for dynamic data."""
import torch
from itertools import cycle
from onmt.constants import CorpusTask
from onmt.inputters.text_corpus import get_corpora, build_corpora_iters
from onmt.inputters.text_utils import (
    text_sort_key,
    process,
    numericalize,
    tensorify,
    _addcopykeys,
)
from onmt.transforms import make_transforms
from onmt.utils.logging import init_logger, logger
from onmt.utils.misc import RandomShuffler
from torch.utils.data import DataLoader


class MixingStrategy(object):
    """Base class for strategies that combine several corpora into one stream.

    Args:
        iterables (dict): corpus name -> iterable of examples.
        weights (dict): corpus name -> weight; must have exactly the same
            keys as ``iterables``.

    Subclasses implement ``__iter__``.
    """

    def __init__(self, iterables, weights):
        """Validate the key sets, then keep both mappings on the instance."""
        self._valid_iterable(iterables, weights)
        self.iterables, self.weights = iterables, weights

    def _valid_iterable(self, iterables, weights):
        """Raise ValueError unless ``iterables`` and ``weights`` share keys."""
        if iterables.keys() == weights.keys():
            return
        raise ValueError(f"keys in {iterables} & {weights} should be equal.")

    def __iter__(self):
        raise NotImplementedError


class SequentialMixer(MixingStrategy):
    """Chain corpora one after another, exhausting each before the next.

    Corpora are visited in the insertion order of ``weights``; a corpus with
    weight ``k`` is iterated from the start ``k`` times in a row. Iteration
    ends after the last pass (assuming the underlying iterables are finite).
    """

    def _iter_datasets(self):
        """Yield each corpus name repeated as many times as its weight."""
        for name, repeat in self.weights.items():
            yield from [name] * repeat

    def __iter__(self):
        for name in self._iter_datasets():
            yield from self.iterables[name]


class WeightedMixer(MixingStrategy):
    """A mixing strategy that mixes data by weight and iterates infinitely.

    Corpus names are visited in a fixed, endlessly repeated cycle in which
    each name appears ``weight`` consecutive times (in the order of
    ``weights``); every visit takes one example from that corpus. An
    exhausted corpus is restarted with ``iter()`` and its pass counter is
    incremented and logged.

    Raises (during iteration):
        RuntimeError: if a corpus yields no example even right after being
            restarted (e.g. an empty corpus), instead of looping or yielding
            a stale example.
    """

    def __init__(self, iterables, weights):
        super().__init__(iterables, weights)
        self._iterators = {}  # corpus name -> current iterator
        self._counts = {}  # corpus name -> number of passes started
        for ds_name in self.iterables.keys():
            self._reset_iter(ds_name)

    def _logging(self):
        """Report corpora loading statistics."""
        msgs = []
        # patch to log stdout spawned processes of dataloader
        logger = init_logger()
        for ds_name, ds_count in self._counts.items():
            msgs.append(f"\t\t\t* {ds_name}: {ds_count}")
        logger.info("Weighted corpora loaded so far:\n" + "\n".join(msgs))

    def _reset_iter(self, ds_name):
        """(Re)start iterating ``ds_name`` and bump its pass counter."""
        self._iterators[ds_name] = iter(self.iterables[ds_name])
        self._counts[ds_name] = self._counts.get(ds_name, 0) + 1
        self._logging()

    def _iter_datasets(self):
        """Yield each corpus name repeated as many times as its weight."""
        for ds_name, ds_weight in self.weights.items():
            for _ in range(ds_weight):
                yield ds_name

    def __iter__(self):
        for ds_name in cycle(self._iter_datasets()):
            try:
                item = next(self._iterators[ds_name])
            except StopIteration:
                # Corpus exhausted: start a new pass over it.
                self._reset_iter(ds_name)
                try:
                    item = next(self._iterators[ds_name])
                except StopIteration:
                    raise RuntimeError(
                        f"Corpus {ds_name} yields no examples."
                    ) from None
            # Yield outside any try/finally so a failing `next` can never be
            # masked by re-yielding a stale (or unbound) `item`.
            yield item


class DynamicDatasetIter(torch.utils.data.IterableDataset):
    """Yield batches from (multiple) plain text corpora.

    Args:
        corpora (dict[str, ParallelCorpus]): collections of corpora to iterate;
        corpora_info (dict[str, dict]): corpora infos correspond to corpora;
        transforms (dict[str, Transform]): transforms may be used by corpora;
        vocabs (dict[str, Vocab]): vocab dict for convert corpora into Tensor;
        task (str): CorpusTask.TRAIN/VALID/INFER;
        batch_type (str): batching type to count on, choices=[tokens, sents];
        batch_size (int): numbers of examples in a batch;
        batch_size_multiple (int): make batch size a multiple of this;
        data_type (str): input data type, currently only text;
        bucket_size (int): accum this number of examples in a dynamic dataset;
        bucket_size_init (int): initialize the bucket with this
            amount of examples (disabled when <= 0);
        bucket_size_increment (int): increment the bucket
            size with this amount of examples, up to ``bucket_size``;
        copy (Bool): if True, will add specific items for copy_attn
        skip_empty_level (str): security level when encountering an empty
            line, one of "silent", "warning", "error";
        stride (int): iterate data files with this stride;
        offset (int): iterate data files with this offset.

    Attributes:
        sort_key (function): functions define how to sort examples;
        mixer (MixingStrategy): the strategy to iterate corpora, created by
            ``_init_datasets``;
        num_workers (int): number of DataLoader workers used to shard the
            corpora in ``_init_datasets`` (0 means no sharding).

    Raises:
        ValueError: if ``stride`` <= 0 or ``skip_empty_level`` is invalid.
    """

    def __init__(
        self,
        corpora,
        corpora_info,
        transforms,
        vocabs,
        task,
        batch_type,
        batch_size,
        batch_size_multiple,
        data_type="text",
        bucket_size=2048,
        bucket_size_init=-1,
        bucket_size_increment=0,
        copy=False,
        skip_empty_level="warning",
        stride=1,
        offset=0,
    ):
        # Zero-argument super() so the base-class initializer actually runs
        # (the one-argument form `super(Cls)` is unbound).
        super().__init__()
        self.corpora = corpora
        self.transforms = transforms
        self.vocabs = vocabs
        self.corpora_info = corpora_info
        self.task = task
        self.init_iterators = False
        self.batch_size = batch_size
        self.batch_type = batch_type
        self.batch_size_multiple = batch_size_multiple
        self.data_type = data_type
        self.device = "cpu"
        self.sort_key = text_sort_key
        self.bucket_size = bucket_size
        self.bucket_size_init = bucket_size_init
        self.bucket_size_increment = bucket_size_increment
        self.copy = copy
        if stride <= 0:
            raise ValueError(f"Invalid argument for stride={stride}.")
        self.stride = stride
        self.offset = offset
        if skip_empty_level not in ["silent", "warning", "error"]:
            raise ValueError(f"Invalid argument skip_empty_level={skip_empty_level}")
        self.skip_empty_level = skip_empty_level
        self.random_shuffler = RandomShuffler()
        # Default to "no worker sharding" so _init_datasets works even when
        # the caller does not set this explicitly (build_dynamic_dataset_iter
        # overrides it).
        self.num_workers = 0

    @classmethod
    def from_opt(cls, corpora, transforms, vocabs, opt, task, copy, stride=1, offset=0):
        """Initialize `DynamicDatasetIter` with options parsed from `opt`.

        For TRAIN/VALID the bucketing and corpora settings come from `opt`;
        for INFER a single pseudo-corpus keyed by ``CorpusTask.INFER`` is
        described with weight 1 and fixed bucketing settings.
        """
        corpora_info = {}
        batch_size = (
            opt.valid_batch_size if (task == CorpusTask.VALID) else opt.batch_size
        )
        if task != CorpusTask.INFER:
            if opt.batch_size_multiple is not None:
                batch_size_multiple = opt.batch_size_multiple
            else:
                batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
            corpora_info = opt.data
            bucket_size = opt.bucket_size
            bucket_size_init = opt.bucket_size_init
            bucket_size_increment = opt.bucket_size_increment
            skip_empty_level = opt.skip_empty_level
        else:
            batch_size_multiple = 1
            corpora_info[CorpusTask.INFER] = {"transforms": opt.transforms}
            corpora_info[CorpusTask.INFER]["weight"] = 1
            # Inference uses a fixed bucket of 16384 examples.
            bucket_size = 16384
            bucket_size_init = -1
            bucket_size_increment = 0
            skip_empty_level = "warning"
        return cls(
            corpora,
            corpora_info,
            transforms,
            vocabs,
            task,
            opt.batch_type,
            batch_size,
            batch_size_multiple,
            data_type=opt.data_type,
            bucket_size=bucket_size,
            bucket_size_init=bucket_size_init,
            bucket_size_increment=bucket_size_increment,
            copy=copy,
            skip_empty_level=skip_empty_level,
            stride=stride,
            offset=offset,
        )

    def _init_datasets(self, worker_id):
        """Build the per-corpus iterators and the mixer.

        When ``num_workers > 0`` (this is then used as the DataLoader's
        ``worker_init_fn``), the files are sharded across workers by folding
        ``worker_id`` into the stride/offset.

        Args:
            worker_id (int): index of the current DataLoader worker (0 when
                running in-process).
        """
        if self.num_workers > 0:
            stride = self.stride * self.num_workers
            offset = self.offset * self.num_workers + worker_id
        else:
            stride = self.stride
            offset = self.offset
        datasets_iterables = build_corpora_iters(
            self.corpora,
            self.transforms,
            self.corpora_info,
            skip_empty_level=self.skip_empty_level,
            stride=stride,
            offset=offset,
        )
        datasets_weights = {
            ds_name: int(self.corpora_info[ds_name]["weight"])
            for ds_name in datasets_iterables.keys()
        }
        # Training mixes corpora endlessly by weight; VALID/INFER read them
        # once, in order.
        if self.task == CorpusTask.TRAIN:
            self.mixer = WeightedMixer(datasets_iterables, datasets_weights)
        else:
            self.mixer = SequentialMixer(datasets_iterables, datasets_weights)
        self.init_iterators = True

    def _tuple_to_json_with_tokIDs(self, tuple_bucket):
        """Process raw examples and numericalize them.

        Examples that ``process`` turns into None are dropped; when
        ``self.copy`` is set, copy-attention keys are added before
        numericalization.
        """
        bucket = []
        tuple_bucket = process(self.task, tuple_bucket)
        for example in tuple_bucket:
            if example is not None:
                if self.copy:
                    example = _addcopykeys(self.vocabs, example)
                bucket.append(numericalize(self.vocabs, example))
        return bucket

    def _bucketing(self):
        """
        Add up to bucket_size examples from the mixed corpora according
        to the above strategy. example tuple is converted to json and
        tokens numericalized.

        When ``bucket_size_init > 0`` the first bucket holds that many
        examples and each following bucket grows by
        ``bucket_size_increment``, capped at ``bucket_size``.
        """
        bucket = []
        if self.bucket_size_init > 0:
            _bucket_size = self.bucket_size_init
        else:
            _bucket_size = self.bucket_size
        for ex in self.mixer:
            bucket.append(ex)
            if len(bucket) == _bucket_size:
                yield self._tuple_to_json_with_tokIDs(bucket)
                bucket = []
                if _bucket_size < self.bucket_size:
                    # Grow towards bucket_size without overshooting it.
                    _bucket_size = min(
                        _bucket_size + self.bucket_size_increment, self.bucket_size
                    )
                else:
                    _bucket_size = self.bucket_size
        if bucket:
            yield self._tuple_to_json_with_tokIDs(bucket)

    def batch_iter(self, data, batch_size, batch_type="sents", batch_size_multiple=1):
        """Yield elements from data in chunks of batch_size,
        where each chunk size is a multiple of batch_size_multiple.

        For ``CorpusTask.TRAIN``, an example whose ``ex["src"]["src"]`` is
        already in the current minibatch is skipped.

        Args:
            data (iterable[dict]): numericalized examples.
            batch_size (int): budget per batch, counted in sentences or in
                ``sentences * max_length`` tokens depending on batch_type.
            batch_type (str): "sents" or "tokens".
            batch_size_multiple (int): preferred multiple of the number of
                sentences per batch.

        Yields:
            list[dict]: minibatches of examples.

        Raises:
            ValueError: if batch_type is neither "sents" nor "tokens".
        """

        def batch_size_fn(nbsents, maxlen):
            if batch_type == "sents":
                return nbsents
            elif batch_type == "tokens":
                return nbsents * maxlen
            else:
                raise ValueError(f"Invalid argument batch_type={batch_type}")

        def max_src_tgt(ex):
            """Return the max number of tokens between src and tgt."""
            if ex["tgt"]:
                return max(len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"]))
            return len(ex["src"]["src_ids"])

        minibatch, maxlen, size_so_far, seen = [], 0, 0, set()
        for ex in data:
            src = ex["src"]["src"]
            if src not in seen or (self.task != CorpusTask.TRAIN):
                seen.add(src)
                minibatch.append(ex)
                nbsents = len(minibatch)
                maxlen = max(max_src_tgt(ex), maxlen)
                size_so_far = batch_size_fn(nbsents, maxlen)
                if size_so_far >= batch_size:
                    # Hold back the last example if it pushed the batch over
                    # budget, plus enough extra so the emitted part is a
                    # multiple of batch_size_multiple.
                    overflowed = 1 if size_so_far > batch_size else 0
                    if batch_size_multiple > 1:
                        overflowed += (nbsents - overflowed) % batch_size_multiple
                    if overflowed == 0:
                        yield minibatch
                        minibatch, maxlen, size_so_far, seen = [], 0, 0, set()
                    elif overflowed == nbsents:
                        logger.warning(
                            "The batch will be filled until we reach"
                            " %d, its size may exceed %d tokens"
                            % (batch_size_multiple, batch_size)
                        )
                    else:
                        yield minibatch[:-overflowed]
                        minibatch = minibatch[-overflowed:]
                        maxlen = max(max_src_tgt(e) for e in minibatch)
                        size_so_far = batch_size_fn(len(minibatch), maxlen)
                        # The held-back examples open the next minibatch, so
                        # their sources must still count as seen.
                        seen = {e["src"]["src"] for e in minibatch}

        if minibatch:
            yield minibatch

    def __iter__(self):
        """Yield tensorified minibatches built from successive buckets.

        Requires ``_init_datasets`` to have run (it creates ``self.mixer``).
        """
        for bucket in self._bucketing():
            # For TRAIN we need to group examples by length
            # for faster performance, but otherwise, sequential.
            if self.task == CorpusTask.TRAIN:
                bucket = sorted(bucket, key=self.sort_key)
            p_batch = list(
                self.batch_iter(
                    bucket,
                    self.batch_size,
                    batch_type=self.batch_type,
                    batch_size_multiple=self.batch_size_multiple,
                )
            )
            # For TRAIN we shuffle batches within the bucket
            # otherwise sequential
            if self.task == CorpusTask.TRAIN:
                p_batch = self.random_shuffler(p_batch)
            for minibatch in p_batch:
                # for specific case of rnn_packed need to be sorted
                # within the batch
                minibatch.sort(key=self.sort_key, reverse=True)
                tensor_batch = tensorify(self.vocabs, minibatch)
                yield tensor_batch


def build_dynamic_dataset_iter(
    opt,
    transforms_cls,
    vocabs,
    copy=False,
    task=CorpusTask.TRAIN,
    stride=1,
    offset=0,
    src=None,
    tgt=None,
    align=None,
):
    """
    Build `DynamicDatasetIter` from opt.
    If src, tgt, align are passed then the dataset is built from those lists
    instead of opt.[src, tgt, align].
    Typically this function is called for CorpusTask.[TRAIN,VALID,INFER]
    from the main train / translate scripts.
    We disable automatic batching in the DataLoader.
    The custom optimized batching is performed by the
    custom class DynamicDatasetIter inherited from IterableDataset
    (and not by a custom collate function).
    We load opt.bucket_size examples, sort them and yield
    mini-batches of size opt.batch_size.
    The bucket_size must be large enough to ensure homogeneous batches.
    Each worker will load opt.prefetch_factor mini-batches in
    advance to avoid the GPU waiting during the refilling of the bucket.

    Returns:
        None when no corpus is available for a non-TRAIN task; the
        `DynamicDatasetIter` itself for INFER or when no workers are
        configured; otherwise a `DataLoader` wrapping it.
    """
    transforms = make_transforms(opt, transforms_cls, vocabs)
    corpora = get_corpora(opt, task, src=src, tgt=tgt, align=align)
    if corpora is None:
        assert task != CorpusTask.TRAIN, "only valid corpus is ignorable."
        return None
    data_iter = DynamicDatasetIter.from_opt(
        corpora, transforms, vocabs, opt, task, copy=copy, stride=stride, offset=offset
    )
    if task == CorpusTask.INFER:
        # Inference is always consumed in this process (no DataLoader), so
        # worker sharding must be disabled: with num_workers > 0,
        # _init_datasets would scale the stride by num_workers and silently
        # skip examples.
        data_iter.num_workers = 0
    else:
        data_iter.num_workers = getattr(opt, "num_workers", 0)
    if data_iter.num_workers == 0:
        data_iter._init_datasets(0)  # when workers=0 init_fn not called
        return data_iter
    return DataLoader(
        data_iter,
        batch_size=None,
        pin_memory=True,
        multiprocessing_context="spawn",
        num_workers=data_iter.num_workers,
        worker_init_fn=data_iter._init_datasets,
        prefetch_factor=opt.prefetch_factor,
    )