# -*- coding: utf-8 -*-
import os
import math
import codecs
import torch
import pyonmttok
from onmt.constants import DefaultTokens
from onmt.transforms import TransformPipe


class IterOnDevice(torch.utils.data.IterableDataset):
    """Send items from `iterable` to `device_id` and yield them.

    Each item produced by `iterable` is expected to be a dict-like batch
    whose values (except the one stored under "src_ex_vocab") expose a
    ``.to(device)`` method.
    """

    def __init__(self, iterable, device_id):
        # Zero-argument super() is bound to this instance. The one-argument
        # form `super(IterOnDevice)` builds an *unbound* super object, so
        # calling `__init__` on it never reached the parent initializer.
        super().__init__()
        self.iterable = iterable
        self.device_id = device_id
        # temporary as long as translation_server and scoring_preparator still use lists
        # NOTE: `self.transform` is only defined when the wrapped iterable
        # exposes a `transforms` mapping.
        if hasattr(iterable, "transforms"):
            self.transform = TransformPipe.build_from(
                [iterable.transforms[name] for name in iterable.transforms]
            )

    @staticmethod
    def batch_to_device(tensor_batch, device_id):
        """Move `tensor_batch` to `device_id`, cpu if `device_id` < 0.

        The batch is modified in place (nothing is returned). The entry
        stored under "src_ex_vocab" is left untouched.
        """
        device = torch.device(device_id) if device_id >= 0 else torch.device("cpu")
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for key, value in tensor_batch.items():
            if key != "src_ex_vocab":
                tensor_batch[key] = value.to(device)

    def __iter__(self):
        """Yield every batch of the wrapped iterable after moving it."""
        for tensor_batch in self.iterable:
            self.batch_to_device(tensor_batch, self.device_id)
            yield tensor_batch


def build_vocab(opt, specials):
    """Build vocabs dict to be stored in the checkpoint
        based on vocab files having each line [token, count]
    Args:
        opt: options object. Attributes read here: src_vocab, tgt_vocab,
            src_vocab_size, tgt_vocab_size, src_words_min_frequency,
            tgt_words_min_frequency, share_vocab, vocab_size_multiple,
            n_src_feats, default_specials, data_task, decoder_start_token
        specials: dict whose "src" and "tgt" entries are lists of extra
            special tokens, concatenated to opt.default_specials
    Return:
        vocabs: {'src': pyonmttok.Vocab, 'tgt': pyonmttok.Vocab,
                 'src_feats' : [pyonmttok.Vocab, ...] (only if n_src_feats > 0),
                 'data_task': opt.data_task,
                 'decoder_start_token': opt.decoder_start_token
                }
        When opt.share_vocab is set, 'tgt' is the very same object as 'src'
        and the target vocab file is not read.
    """

    def _pad_vocab_to_multiple(vocab, multiple):
        # Append filler tokens until len(vocab) is a multiple of `multiple`.
        vocab_size = len(vocab)
        if vocab_size % multiple == 0:
            return vocab
        target_size = int(math.ceil(vocab_size / multiple)) * multiple
        for i in range(target_size - vocab_size):
            vocab.add_token(DefaultTokens.VOCAB_PAD + str(i))
        return vocab

    def _build_side_vocab(vocab_path, min_frequency, max_size, side_specials):
        # Shared recipe for the source and target sides: read the tokens,
        # work out which specials are missing, build, then pad.
        tokens = _read_vocab_file(vocab_path, min_frequency)
        missing_specials = [
            item for item in (default_specials + side_specials) if item not in tokens
        ]
        if DefaultTokens.SEP in missing_specials and (
            "<0x0A>" in tokens or "Ċ" in tokens
        ):
            # this is a hack: if the special separator ⦅newline⦆ is returned
            # because of the "docify" transform.get_specials we don't add it
            # if the corresponding newline code is already included in the
            # sentencepiece or BPE-with-gpt2-pretok vocab of THIS side.
            missing_specials.remove(DefaultTokens.SEP)
        vocab = pyonmttok.build_vocab_from_tokens(
            tokens, maximum_size=max_size, special_tokens=missing_specials
        )
        vocab.default_id = vocab[DefaultTokens.UNK]
        if opt.vocab_size_multiple > 1:
            vocab = _pad_vocab_to_multiple(vocab, opt.vocab_size_multiple)
        return vocab

    default_specials = opt.default_specials
    vocabs = {}

    src_vocab = _build_side_vocab(
        opt.src_vocab,
        opt.src_words_min_frequency,
        opt.src_vocab_size,
        specials["src"],
    )
    vocabs["src"] = src_vocab

    if opt.share_vocab:
        vocabs["tgt"] = src_vocab
    else:
        # The newline check is now done against the target tokens; it used
        # to look up "Ċ" in the source vocabulary by mistake.
        vocabs["tgt"] = _build_side_vocab(
            opt.tgt_vocab,
            opt.tgt_words_min_frequency,
            opt.tgt_vocab_size,
            specials["tgt"],
        )

    if opt.n_src_feats > 0:
        src_feats_vocabs = []
        for i in range(opt.n_src_feats):
            # Feature vocab files sit next to the source vocab file and are
            # named "<src_vocab>_feat<i>"; no frequency or size filtering.
            src_f_vocab = _read_vocab_file(f"{opt.src_vocab}_feat{i}", 1)
            src_f_vocab = pyonmttok.build_vocab_from_tokens(
                src_f_vocab,
                maximum_size=0,
                minimum_frequency=1,
                special_tokens=default_specials,
            )
            src_f_vocab.default_id = src_f_vocab[DefaultTokens.UNK]
            if opt.vocab_size_multiple > 1:
                src_f_vocab = _pad_vocab_to_multiple(
                    src_f_vocab, opt.vocab_size_multiple
                )
            src_feats_vocabs.append(src_f_vocab)
        vocabs["src_feats"] = src_feats_vocabs

    vocabs["data_task"] = opt.data_task
    vocabs["decoder_start_token"] = opt.decoder_start_token

    return vocabs


def _read_vocab_file(vocab_path, min_count):
    """Loads a vocabulary from the given path.

    Args:
        vocab_path (str): Path to utf-8 text file containing vocabulary.
            Each token should be on a line, may followed with a count number
            seperate by space if `with_count`. No extra whitespace is allowed.
        min_count (int): retains only tokens with min_count frequency.
    """

    if not os.path.exists(vocab_path):
        raise RuntimeError("Vocabulary not found at {}".format(vocab_path))
    else:
        with codecs.open(vocab_path, "rb", "utf-8") as f:
            lines = [line.strip("\n") for line in f if line.strip("\n")]
            first_line = lines[0].split(None, 1)
            has_count = len(first_line) == 2 and first_line[-1].isdigit()
            if has_count:
                vocab = []
                for line in lines:
                    if int(line.split(None, 1)[1]) >= min_count:
                        vocab.append(line.split(None, 1)[0])
            else:
                vocab = lines
            return vocab


def vocabs_to_dict(vocabs):
    """
    Flatten a dict of vocab objects into plain data for the checkpoint.

    Every vocab object is replaced by its `ids_to_tokens` attribute.
    "src_feats" is copied only when present, and "decoder_start_token"
    falls back to DefaultTokens.BOS when the input does not carry one.
    """
    plain = {
        "src": vocabs["src"].ids_to_tokens,
        "tgt": vocabs["tgt"].ids_to_tokens,
    }
    if "src_feats" in vocabs:
        plain["src_feats"] = [feat.ids_to_tokens for feat in vocabs["src_feats"]]
    plain["data_task"] = vocabs["data_task"]
    plain["decoder_start_token"] = (
        vocabs["decoder_start_token"]
        if "decoder_start_token" in vocabs
        else DefaultTokens.BOS
    )
    return plain


def dict_to_vocabs(vocabs_dict):
    """
    Rebuild vocab objects from the plain dict stored in a checkpoint.

    Each token list is passed to pyonmttok.build_vocab_from_tokens. When
    the "src" and "tgt" token lists compare equal, a single object is
    built and used for both entries. "decoder_start_token" falls back to
    DefaultTokens.BOS when the stored dict does not carry one.
    """
    make_vocab = pyonmttok.build_vocab_from_tokens

    restored = {"data_task": vocabs_dict["data_task"]}
    restored["decoder_start_token"] = (
        vocabs_dict["decoder_start_token"]
        if "decoder_start_token" in vocabs_dict
        else DefaultTokens.BOS
    )

    source_vocab = make_vocab(vocabs_dict["src"])
    restored["src"] = source_vocab
    # Equal token lists mean the vocabulary was shared: reuse the object.
    is_shared = vocabs_dict["src"] == vocabs_dict["tgt"]
    restored["tgt"] = source_vocab if is_shared else make_vocab(vocabs_dict["tgt"])

    if "src_feats" in vocabs_dict:
        restored["src_feats"] = [
            make_vocab(feat_tokens) for feat_tokens in vocabs_dict["src_feats"]
        ]
    return restored