# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import pdb
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn


class Trainer(nn.Module):
    """Helper class to train a deep network for one epoch per call.

    Override `forward_backward` in a subclass for your actual needs: it
    receives one batch (already passed through `todevice`) and must return
    ``(loss, details)``. ``__call__`` only zeroes the gradients before it and
    steps the optimizer after it, so the override is expected to run the
    backward pass itself.

    Usage:
        train = Trainer(net, loader, loss, optimizer)
        for epoch in range(n_epochs):
            train()
    """

    def __init__(self, net, loader, loss, optimizer):
        super().__init__()
        self.net = net
        self.loader = loader
        self.loss_func = loss
        self.optimizer = optimizer

    def iscuda(self):
        """Return True when the network's first parameter is not on the CPU.

        A network without parameters is treated as living on the CPU instead
        of raising ``StopIteration``.
        """
        first_param = next(self.net.parameters(), None)
        return first_param is not None and first_param.device != torch.device("cpu")

    def todevice(self, x):
        """Recursively move the tensors of a batch next to the network.

        Dicts are rebuilt with the same keys; tuples and lists both come back
        as lists. Tensor leaves are sent to the current CUDA device when
        `iscuda()` is true, and to the CPU otherwise. Any other leaf (ints,
        strings, None, ...) is returned unchanged.
        """
        if isinstance(x, dict):
            return {k: self.todevice(v) for k, v in x.items()}
        if isinstance(x, (tuple, list)):
            return [self.todevice(v) for v in x]
        if not isinstance(x, torch.Tensor):
            # non-tensor metadata cannot be moved; pass it through as-is
            return x

        if self.iscuda():
            # non_blocking lets the copy overlap host work when the source is pinned
            return x.contiguous().cuda(non_blocking=True)
        return x.cpu()

    def __call__(self):
        """Run one training epoch over ``self.loader`` and print a loss summary.

        Returns:
            The mean of the ``"loss"`` values reported in ``details`` by
            `forward_backward` over the epoch, or ``nan`` when none were
            reported (e.g. the loader was empty).

        Raises:
            RuntimeError: if `forward_backward` returns a NaN loss.
        """
        self.net.train()

        stats = defaultdict(list)
        for inputs in tqdm(self.loader):
            inputs = self.todevice(inputs)

            # compute gradient and do model update
            self.optimizer.zero_grad()

            loss, details = self.forward_backward(inputs)
            if torch.isnan(loss):
                # stop before optimizer.step() applies NaN gradients to the weights
                raise RuntimeError("Loss is NaN")

            self.optimizer.step()

            for key, val in details.items():
                stats[key].append(val)

        self._print_summary(stats)
        # .get avoids inserting an empty "loss" entry; empty -> nan, not ZeroDivisionError
        return self._mean(stats.get("loss", []))  # return average loss

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of ``values``, or ``nan`` if it is empty."""
        if not values:
            return float("nan")
        return sum(values) / len(values)

    @classmethod
    def _print_summary(cls, stats):
        """Print, for each logged key, the mean over the first and last ~10%
        of the epoch's iterations followed by the mean over the whole epoch."""
        print(" Summary of losses during this epoch:")
        for loss_name, vals in stats.items():
            # head/tail window: about a tenth of the epoch, never fewer than 1 value
            window = 1 + len(vals) // 10
            print(f"  - {loss_name:20}:", end="")
            print(
                f" {cls._mean(vals[:window]):.3f} --> {cls._mean(vals[-window:]):.3f}"
                f" (avg: {cls._mean(vals):.3f})"
            )

    def forward_backward(self, inputs):
        """Compute the loss for one batch and back-propagate it.

        Must be overridden by subclasses.

        Args:
            inputs: one batch from ``self.loader``, after `todevice`.

        Returns:
            ``(loss, details)`` where ``loss`` is a tensor (checked with
            ``torch.isnan``) and ``details`` maps names to numeric values
            logged every iteration; its ``"loss"`` entry is what ``__call__``
            averages and returns.
        """
        raise NotImplementedError()