""" Pytorch Distributed utils
    This piece of code was heavily inspired by the equivalent of Fairseq-py
    https://github.com/pytorch/fairseq
"""
import os
import signal
import math
import pickle
import torch.distributed
from datetime import timedelta
from onmt.translate.translator import build_translator
from onmt.transforms import get_transforms_cls
from onmt.constants import CorpusTask
from onmt.utils.logging import init_logger, logger
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.inputter import IterOnDevice


def is_master(opt, device_id):
    """Tell whether ``device_id`` is mapped to global rank 0.

    Args:
        opt: options object exposing a ``gpu_ranks`` sequence.
        device_id: index into ``opt.gpu_ranks``.

    Returns:
        bool: True when ``opt.gpu_ranks[device_id] == 0``.
    """
    global_rank = opt.gpu_ranks[device_id]
    return global_rank == 0


def multi_init(opt, device_id):
    """Join the torch.distributed process group for the process on ``device_id``.

    The group is reached over TCP at ``opt.master_ip:opt.master_port`` with
    backend ``opt.gpu_backend``, world size ``opt.world_size`` and a 60 second
    timeout. The rank requested is ``opt.gpu_ranks[device_id]``. Logging is
    disabled on every rank that is not the master.

    Returns:
        int: the rank reported by ``torch.distributed.get_rank()`` after init.
    """
    address = f"tcp://{opt.master_ip}:{opt.master_port}"
    torch.distributed.init_process_group(
        backend=opt.gpu_backend,
        init_method=address,
        world_size=opt.world_size,
        rank=opt.gpu_ranks[device_id],
        timeout=timedelta(seconds=60),
    )
    assigned_rank = torch.distributed.get_rank()
    master = is_master(opt, device_id)
    if not master:
        # Only the master rank keeps the shared logger active.
        logger.disabled = True
    return assigned_rank


def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=104857600):
    """All-reduce and rescale tensors in chunks of the specified size.

    Small tensors are packed into one flat staging buffer so that several of
    them are summed by a single collective call. A tensor larger than the
    buffer is reduced on its own. Every tensor is then divided by
    ``rescale_denom``. Tensors are updated in place; nothing is returned.

    Which collective calls are issued depends on the sizes and order of
    ``tensors``, so every participating rank must pass matching lists.

    Args:
        tensors: list of Tensors to all-reduce (modified in place). ``view(-1)``
            is used on them, so they are assumed contiguous — TODO confirm
            at call sites.
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes
    """
    if not tensors:
        # Nothing to reduce (and no tensors[0] to take dtype/device from).
        return

    # The staging buffer takes dtype/device from the first tensor and spans
    # `buffer_size` bytes of that dtype.
    buffer_t = tensors[0].new(
        math.ceil(buffer_size / tensors[0].element_size())
    ).zero_()
    buffer = []

    def all_reduce_buffer():
        # Pack the pending tensors contiguously into the staging buffer.
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset : offset + numel].copy_(t.view(-1))
            offset += numel

        # Reduce and rescale only the filled prefix. The rest of the staging
        # buffer holds stale data that is never copied back, so dividing it
        # would be wasted work.
        filled_view = buffer_t[:offset]
        torch.distributed.all_reduce(filled_view, async_op=False)
        filled_view.div_(rescale_denom)

        # Unpack the reduced values back into the original tensors.
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset : offset + numel])
            offset += numel

    filled = 0
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # tensor is bigger than buffer, all-reduce and rescale directly
            torch.distributed.all_reduce(t, async_op=False)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # buffer is full: flush it, then start a new one with this tensor
            all_reduce_buffer()
            buffer = [t]
            filled = sz
        else:
            # add tensor to buffer
            buffer.append(t)
            filled += sz

    if len(buffer) > 0:
        all_reduce_buffer()


def all_gather_list(data, max_size=4096):
    """Gathers arbitrary data from all nodes into a list.

    ``data`` is pickled into a fixed-size byte buffer: a 2-byte base-255
    length header followed by the payload. The buffers of all ranks are
    exchanged with ``torch.distributed.all_gather`` and each payload is
    unpickled again. The buffers are cached on the function object and reused
    while ``max_size`` and the world size stay the same.

    Args:
        data: any picklable object.
        max_size: size in bytes of the exchange buffer. It must match on every
            rank and be smaller than ``255 * 256``.

    Returns:
        list: the unpickled ``data`` of every rank, ordered by rank.

    Raises:
        ValueError: if the pickled ``data`` plus its 2-byte header exceeds
            ``max_size``.
    """
    world_size = torch.distributed.get_world_size()
    # Use the current CUDA device when available (as before); fall back to
    # CPU so CPU-only process groups can use this helper too.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    cached_in = getattr(all_gather_list, "_in_buffer", None)
    # Compare element counts: Tensor.size() returns a torch.Size, which never
    # equals an int and would force a reallocation on every call.
    if (
        cached_in is None
        or cached_in.numel() != max_size
        or len(all_gather_list._out_buffers) != world_size
    ):
        all_gather_list._in_buffer = torch.empty(
            max_size, dtype=torch.uint8, device=device
        )
        all_gather_list._out_buffers = [
            torch.empty(max_size, dtype=torch.uint8, device=device)
            for _ in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError("encoded data exceeds max_size: {}".format(enc_size + 2))
    assert max_size < 255 * 256
    # Two-byte base-255 header; each byte fits in uint8 because
    # enc_size < max_size < 255 * 256.
    in_buffer[0] = enc_size // 255
    in_buffer[1] = enc_size % 255
    in_buffer[2 : enc_size + 2] = torch.tensor(list(enc), dtype=torch.uint8)

    torch.distributed.all_gather(out_buffers, in_buffer)

    results = []
    for out_buffer in out_buffers:
        size = (255 * out_buffer[0].item()) + out_buffer[1].item()
        payload = bytes(out_buffer[2 : size + 2].tolist())
        # NOTE(review): unpickles bytes received from peer ranks; only safe
        # when every rank in the process group is trusted.
        results.append(pickle.loads(payload))
    return results


class ErrorHandler(object):
    """A class that listens for exceptions in children processes and propagates
    the tracebacks to the parent process.

    Children report a failure by putting ``(rank, traceback_str)`` on the
    shared ``error_queue``. A daemon thread waits for the first report and
    sends SIGUSR1 to this (parent) process. The signal handler then sends
    SIGINT to every registered child and raises an ``Exception`` whose message
    contains the child's traceback.
    """

    def __init__(self, error_queue):
        """Install the SIGUSR1 handler and start the listener thread.

        Args:
            error_queue: queue on which children put ``(rank, traceback_str)``
                tuples.
        """
        import threading

        self.error_queue = error_queue
        self.children_pids = []
        # Install the handler before the listener can possibly fire. Otherwise
        # an early error report would send SIGUSR1 while its default action
        # (terminate the process) is still in effect.
        signal.signal(signal.SIGUSR1, self.signal_handler)
        self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
        self.error_thread.start()

    def add_child(self, pid):
        """Register a child process id to be interrupted on error."""
        self.children_pids.append(pid)

    def error_listener(self):
        """Wait for the first child error, then signal the parent process."""
        (rank, original_trace) = self.error_queue.get()
        # Put the report back so signal_handler can read it.
        self.error_queue.put((rank, original_trace))
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, signalnum, stackframe):
        """Interrupt all children and re-raise the reported child traceback.

        Raises:
            Exception: always, with the child's traceback in the message.
        """
        for pid in self.children_pids:
            try:
                os.kill(pid, signal.SIGINT)  # kill children processes
            except ProcessLookupError:
                # The child is already gone (the reporting child returns right
                # after queuing its error). Keep going so the real traceback
                # is raised below instead of this lookup error.
                pass
        (rank, original_trace) = self.error_queue.get()
        msg = """\n\n-- Tracebacks above this line can probably
                 be ignored --\n\n"""
        msg += original_trace
        raise Exception(msg)


def spawned_train(process_fn, opt, device_id, error_queue):
    """Run ``process_fn`` on ``device_id`` inside a spawned training process.

    Joins the distributed process group, checks that the assigned rank is the
    expected ``opt.gpu_ranks[device_id]``, then calls
    ``process_fn(opt, device_id=device_id)``.

    A KeyboardInterrupt is swallowed; the parent's error handler stops
    children with SIGINT. Any other exception is formatted and put on
    ``error_queue`` as ``(opt.gpu_ranks[device_id], traceback_str)`` instead
    of propagating.

    Args:
        process_fn: callable run as ``process_fn(opt, device_id=device_id)``.
        opt: options object (needs ``gpu_ranks`` plus what multi_init reads).
        device_id: local device index / index into ``opt.gpu_ranks``.
        error_queue: queue used to report failures to the parent process.
    """
    try:
        gpu_rank = multi_init(opt, device_id)
        if gpu_rank != opt.gpu_ranks[device_id]:
            raise AssertionError("An error occurred in Distributed initialization")
        process_fn(opt, device_id=device_id)
    except KeyboardInterrupt:
        pass  # killed by parent, do nothing
    except Exception:
        # propagate exception to parent process, keeping original traceback
        import traceback

        error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))


def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result):
    """Run various functions for translation in spawned process on `device_id`.

    Setup:
    1. Join the process group and check the assigned rank.
    2. Select the CUDA device and initialise logging.
    3. Build the translator and the transform classes.

    Then loop over instructions read from ``queue_instruct``, each a sequence
    whose first item is the command:

    * ``"stop"``: leave the loop and return.
    * ``"infer_list"``: translate the sources given in ``instruction[1]``.
    * ``"infer_file"``: build the inference iterator from ``opt`` alone
      (presumably its configured source — confirm in build_dynamic_dataset_iter).
    * anything else is ignored.

    For every translation, ``scores`` and then ``preds`` are put on
    ``queue_result``. A KeyboardInterrupt (used by the parent to stop
    children) is swallowed. Any other exception is reported on
    ``error_queue`` as ``(opt.gpu_ranks[device_id], traceback_str)``.
    """
    try:
        gpu_rank = multi_init(opt, device_id)
        if gpu_rank != opt.gpu_ranks[device_id]:
            raise AssertionError("An error occurred in Distributed initialization")
        torch.cuda.set_device(device_id)
        init_logger(opt.log_file)
        translator = build_translator(opt, device_id, logger=logger, report_score=True)
        transforms_cls = get_transforms_cls(opt._all_transform)
        print("Device_id: ", device_id, " translator built")
        while True:
            instruction = queue_instruct.get()
            command = instruction[0]
            if command == "stop":
                break
            if command == "infer_list":
                # Only list inference passes explicit source sentences.
                iter_kwargs = {"src": instruction[1]}
            elif command == "infer_file":
                iter_kwargs = {}
            else:
                # Unknown instructions are ignored.
                continue
            infer_iter = build_dynamic_dataset_iter(
                opt,
                transforms_cls,
                translator.vocabs,
                task=CorpusTask.INFER,
                **iter_kwargs,
            )
            infer_iter = IterOnDevice(infer_iter, device_id)
            scores, preds = translator._translate(
                infer_iter, infer_iter.transform, opt.attn_debug, opt.align_debug
            )
            queue_result.put(scores)
            queue_result.put(preds)

    except KeyboardInterrupt:
        pass  # killed by parent, do nothing
    except Exception:
        # propagate exception to parent process, keeping original traceback
        import traceback

        error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))