#!/usr/bin/env python
"""Training on a single process."""
import torch
import sys

from onmt.utils.logging import init_logger, logger
from onmt.utils.parse import ArgumentParser
from onmt.constants import CorpusTask
from onmt.transforms import (
    make_transforms,
    save_transforms,
    get_specials,
    get_transforms_cls,
)
from onmt.inputters import build_vocab, IterOnDevice
from onmt.inputters.inputter import dict_to_vocabs, vocabs_to_dict
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.text_corpus import save_transformed_sample
from onmt.model_builder import build_model
from onmt.models.model_saver import load_checkpoint
from onmt.utils.optimizers import Optimizer
from onmt.utils.misc import set_random_seed
from onmt.trainer import build_trainer
from onmt.models import build_model_saver
from onmt.modules.embeddings import prepare_pretrained_embeddings


def prepare_transforms_vocabs(opt, transforms_cls):
    """Prepare or dump transforms before training."""
    # if transform + options set in 'valid' we need to copy in main
    # transform / options for scoring considered as inference
    validset_transforms = opt.data.get("valid", {}).get("transforms", None)
    if validset_transforms:
        opt.transforms = validset_transforms
        if opt.data.get("valid", {}).get("tgt_prefix", None):
            opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None)
        if opt.data.get("valid", {}).get("src_prefix", None):
            opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None)
        if opt.data.get("valid", {}).get("tgt_suffix", None):
            opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None)
        if opt.data.get("valid", {}).get("src_suffix", None):
            opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None)
    specials = get_specials(opt, transforms_cls)

    vocabs = build_vocab(opt, specials)

    # maybe prepare pretrained embeddings, if any
    prepare_pretrained_embeddings(opt, vocabs)

    if opt.dump_transforms or opt.n_sample != 0:
        transforms = make_transforms(opt, transforms_cls, vocabs)
    if opt.dump_transforms:
        save_transforms(transforms, opt.save_data, overwrite=opt.overwrite)
    if opt.n_sample != 0:
        logger.warning(
            "`-n_sample` != 0: Training will not be started. "
            f"Stop after saving {opt.n_sample} samples/corpus."
        )
        save_transformed_sample(opt, transforms, n_sample=opt.n_sample)
        logger.info("Sample saved, please check it before restart training.")
        sys.exit()
    logger.info(
        "The first 10 tokens of the vocabs are:"
        f"{vocabs_to_dict(vocabs)['src'][0:10]}"
    )
    logger.info(f"The decoder start token is: {opt.decoder_start_token}")
    return vocabs


def _init_train(opt):
    """Common initilization stuff for all training process.
    We need to build or rebuild the vocab in 3 cases:
    - training from scratch (train_from is false)
    - resume training but transforms have changed
    - resume training but vocab file has been modified
    """
    ArgumentParser.validate_prepare_opts(opt)
    transforms_cls = get_transforms_cls(opt._all_transform)
    if opt.train_from:
        # Load checkpoint if we resume from a previous training.
        checkpoint = load_checkpoint(ckpt_path=opt.train_from)
        vocabs = dict_to_vocabs(checkpoint["vocab"])
        if (
            hasattr(checkpoint["opt"], "_all_transform")
            and len(
                opt._all_transform.symmetric_difference(
                    checkpoint["opt"]._all_transform
                )
            )
            != 0
        ):
            _msg = "configured transforms is different from checkpoint:"
            new_transf = opt._all_transform.difference(checkpoint["opt"]._all_transform)
            old_transf = checkpoint["opt"]._all_transform.difference(opt._all_transform)
            if len(new_transf) != 0:
                _msg += f" +{new_transf}"
            if len(old_transf) != 0:
                _msg += f" -{old_transf}."
            logger.warning(_msg)
            vocabs = prepare_transforms_vocabs(opt, transforms_cls)
        if opt.update_vocab:
            logger.info("Updating checkpoint vocabulary with new vocabulary")
            vocabs = prepare_transforms_vocabs(opt, transforms_cls)
    else:
        checkpoint = None
        vocabs = prepare_transforms_vocabs(opt, transforms_cls)

    return checkpoint, vocabs, transforms_cls


def configure_process(opt, device_id):
    if device_id >= 0:
        torch.cuda.set_device(device_id)
    set_random_seed(opt.seed, device_id >= 0)


def _get_model_opts(opt, checkpoint=None):
    """Get `model_opt` to build model, may load from `checkpoint` if any."""
    if checkpoint is not None:
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        if opt.override_opts:
            logger.info("Over-ride model option set to true - use with care")
            args = list(opt.__dict__.keys())
            model_args = list(model_opt.__dict__.keys())
            for arg in args:
                if arg in model_args and getattr(opt, arg) != getattr(model_opt, arg):
                    logger.info(
                        "Option: %s , value: %s overriding model: %s"
                        % (arg, getattr(opt, arg), getattr(model_opt, arg))
                    )
            model_opt = opt
        else:
            model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
            ArgumentParser.update_model_opts(model_opt)
            ArgumentParser.validate_model_opts(model_opt)
            if opt.tensorboard_log_dir == model_opt.tensorboard_log_dir and hasattr(
                model_opt, "tensorboard_log_dir_dated"
            ):
                # ensure tensorboard output is written in the directory
                # of previous checkpoints
                opt.tensorboard_log_dir_dated = (
                    model_opt.tensorboard_log_dir_dated
                )  # noqa: E501
            # Override checkpoint's update_embeddings as it defaults to false
            model_opt.update_vocab = opt.update_vocab
            # Override checkpoint's freezing settings as it defaults to false
            model_opt.freeze_encoder = opt.freeze_encoder
            model_opt.freeze_decoder = opt.freeze_decoder
    else:
        model_opt = opt
    return model_opt


def main(opt, device_id):
    """Start training on `device_id`."""
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.

    configure_process(opt, device_id)
    init_logger(opt.log_file)
    checkpoint, vocabs, transforms_cls = _init_train(opt)
    model_opt = _get_model_opts(opt, checkpoint=checkpoint)

    # Build model.
    model = build_model(model_opt, opt, vocabs, checkpoint, device_id)

    model.count_parameters(log=logger.info)
    trainable = {
        "torch.float32": 0,
        "torch.float16": 0,
        "torch.uint8": 0,
        "torch.int8": 0,
    }
    non_trainable = {
        "torch.float32": 0,
        "torch.float16": 0,
        "torch.uint8": 0,
        "torch.int8": 0,
    }
    for n, p in model.named_parameters():
        if p.requires_grad:
            trainable[str(p.dtype)] += p.numel()
        else:
            non_trainable[str(p.dtype)] += p.numel()
    logger.info("Trainable parameters = %s" % str(trainable))
    logger.info("Non trainable parameters = %s" % str(non_trainable))
    logger.info(" * src vocab size = %d" % len(vocabs["src"]))
    logger.info(" * tgt vocab size = %d" % len(vocabs["tgt"]))
    if "src_feats" in vocabs:
        for i, feat_vocab in enumerate(vocabs["src_feats"]):
            logger.info(f"* src_feat {i} vocab size = {len(feat_vocab)}")

    # Build optimizer.
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    del checkpoint

    # Build model saver
    model_saver = build_model_saver(model_opt, opt, model, vocabs, optim, device_id)

    trainer = build_trainer(
        opt, device_id, model, vocabs, optim, model_saver=model_saver
    )

    offset = max(0, device_id) if opt.parallel_mode == "data_parallel" else 0
    stride = max(1, len(opt.gpu_ranks)) if opt.parallel_mode == "data_parallel" else 1

    _train_iter = build_dynamic_dataset_iter(
        opt,
        transforms_cls,
        vocabs,
        task=CorpusTask.TRAIN,
        copy=opt.copy_attn,
        stride=stride,
        offset=offset,
    )
    train_iter = IterOnDevice(_train_iter, device_id)

    valid_iter = build_dynamic_dataset_iter(
        opt, transforms_cls, vocabs, task=CorpusTask.VALID, copy=opt.copy_attn
    )

    if valid_iter is not None:
        valid_iter = IterOnDevice(valid_iter, device_id)

    if len(opt.gpu_ranks):
        logger.info("Starting training on GPU: %s" % opt.gpu_ranks)
    else:
        logger.info("Starting training on CPU, could be very slow")
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps,
    )

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()