#!/usr/bin/env python3
import torch
from torch import nn

from .inference import make_seg_postprocessor
from .loss import make_seg_loss_evaluator
import time


def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    """
    Build a 3x3 ``nn.Conv2d`` with one pixel of zero padding on each side.

    With ``stride=1`` the spatial size of the input is preserved.

    Arguments:
        in_planes (int): number of input channels.
        out_planes (int): number of output channels.
        stride (int): convolution stride.
        has_bias (bool): whether the convolution has a learnable bias.
    """
    conv_options = {
        "kernel_size": 3,
        "stride": stride,
        "padding": 1,
        "bias": has_bias,
    }
    return nn.Conv2d(in_planes, out_planes, **conv_options)


def conv3x3_bn_relu(in_planes, out_planes, stride=1, has_bias=False):
    """
    3x3 convolution (padding 1) followed by BatchNorm2d and an in-place ReLU.

    Arguments:
        in_planes (int): number of input channels.
        out_planes (int): number of output channels.
        stride (int): convolution stride.
        has_bias (bool): whether the convolution has a learnable bias. This
            argument used to be accepted but ignored; it is now forwarded to
            the conv. The default (``False``) is unchanged.
    """
    return nn.Sequential(
        conv3x3(in_planes, out_planes, stride, has_bias=has_bias),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )


class SEGHead(nn.Module):
    """
    Pixel-level segmentation head built on four backbone feature levels.

    ``x[-2]``, ``x[-3]``, ``x[-4]`` and ``x[-5]`` are each reduced to 64
    channels by a 3x3 conv and nearest-upsampled by 8x, 4x, 2x and 1x
    respectively, then concatenated into a 256-channel ``fuse`` tensor.
    (This assumes each level has half the resolution of the next lower one;
    TODO confirm against the backbone.) ``seg_out`` maps ``fuse`` to a
    single-channel sigmoid map at 4x the resolution of ``fuse``. When
    ``cfg.MODEL.SEG.USE_PPM`` is set, ``x[-2]`` first passes through a
    pyramid pooling module (PPM).
    """

    def __init__(self, in_channels, cfg):
        """
        Arguments:
            in_channels (int): number of channels of each input feature map
                (every level of ``x`` used in ``forward`` must have this width).
            cfg: config node; only ``cfg.MODEL.SEG.USE_PPM`` is read here.
        """
        super(SEGHead, self).__init__()
        self.cfg = cfg
        # Width of the incoming feature maps. This was previously hard-coded to
        # 256, while ``seg_out`` wrongly used ``in_channels``, so the head only
        # worked when in_channels == 256. For that case nothing changes.
        ndim = in_channels
        # Each level is squeezed to ``fpn_dim`` channels before fusion, so the
        # fused tensor always has ``4 * fpn_dim`` channels.
        fpn_dim = 64
        fuse_dim = 4 * fpn_dim
        self.fpn_out5 = nn.Sequential(
            conv3x3(ndim, fpn_dim), nn.Upsample(scale_factor=8, mode="nearest")
        )
        self.fpn_out4 = nn.Sequential(
            conv3x3(ndim, fpn_dim), nn.Upsample(scale_factor=4, mode="nearest")
        )
        self.fpn_out3 = nn.Sequential(
            conv3x3(ndim, fpn_dim), nn.Upsample(scale_factor=2, mode="nearest")
        )
        self.fpn_out2 = conv3x3(ndim, fpn_dim)
        self.seg_out = nn.Sequential(
            conv3x3_bn_relu(fuse_dim, 64, 1),
            # Each kernel-2 / stride-2 transposed conv exactly doubles H and W.
            nn.ConvTranspose2d(64, 64, 2, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 2, 2),
            nn.Sigmoid(),
        )
        if self.cfg.MODEL.SEG.USE_PPM:
            # Pyramid pooling over x[-2]. Pool to each scale, resize back,
            # project each branch to 512 channels, concatenate with the input,
            # and reduce back to ``ndim`` channels so ``fpn_out5`` can consume it.
            pool_scales = (2, 4, 8)
            fc_dim = ndim
            self.ppm_pooling = nn.ModuleList(
                nn.AdaptiveAvgPool2d(scale) for scale in pool_scales
            )
            self.ppm_conv = nn.ModuleList(
                nn.Sequential(
                    nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                    nn.BatchNorm2d(512),
                    nn.ReLU(inplace=True),
                )
                for _ in pool_scales
            )
            self.ppm_last_conv = conv3x3_bn_relu(
                fc_dim + len(pool_scales) * 512, ndim, 1
            )
            self.ppm_conv.apply(self.weights_init)
            self.ppm_last_conv.apply(self.weights_init)
        for module in (
            self.fpn_out5,
            self.fpn_out4,
            self.fpn_out3,
            self.fpn_out2,
            self.seg_out,
        ):
            module.apply(self.weights_init)

    def forward(self, x):
        """
        Arguments:
            x (sequence[Tensor]): at least 5 feature maps in NCHW layout;
                only ``x[-5]`` .. ``x[-2]`` are used.

        Returns:
            out (Tensor): N x 1 x 4H x 4W sigmoid map, where H x W is the
                spatial size of ``x[-5]``.
            fuse (Tensor): N x 256 x H x W concatenation of the four reduced,
                upsampled levels.
        """
        if self.cfg.MODEL.SEG.USE_PPM:
            conv5 = x[-2]
            input_size = conv5.size()
            ppm_out = [conv5]
            for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
                # Resize the pooled map back to conv5's H x W before projecting.
                ppm_out.append(pool_conv(nn.functional.interpolate(
                    pool_scale(conv5),
                    (input_size[2], input_size[3]),
                    mode='bilinear', align_corners=False)))
            ppm_out = torch.cat(ppm_out, 1)
            f = self.ppm_last_conv(ppm_out)
        else:
            f = x[-2]
        p5 = self.fpn_out5(f)
        p4 = self.fpn_out4(x[-3])
        p3 = self.fpn_out3(x[-4])
        p2 = self.fpn_out2(x[-5])
        fuse = torch.cat((p5, p4, p3, p2), 1)
        out = self.seg_out(fuse)
        return out, fuse

    def weights_init(self, m):
        """
        Initializer intended for ``Module.apply``.

        Any module whose class name contains "Conv" (e.g. Conv2d,
        ConvTranspose2d) gets a Kaiming-normal weight. Its bias keeps PyTorch's
        default init. Any module whose class name contains "BatchNorm" gets
        weight 1 and bias 1e-4. All other modules are left untouched.
        """
        classname = m.__class__.__name__
        if "Conv" in classname:
            nn.init.kaiming_normal_(m.weight.data)
        elif "BatchNorm" in classname:
            m.weight.data.fill_(1.0)
            m.bias.data.fill_(1e-4)


class SEGModule(torch.nn.Module):
    """
    Segmentation branch. ``SEGHead`` predicts a per-pixel probability map from
    the backbone features, and a post-processor converts that map to boxes.
    In training mode the module also computes the segmentation loss against
    the ground-truth maps.
    """

    def __init__(self, cfg):
        """
        Arguments:
            cfg: config node. A clone is kept in ``self.cfg``; the head,
                post-processors and loss evaluator are built from ``cfg``
                itself.
        """
        super(SEGModule, self).__init__()

        self.cfg = cfg.clone()

        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        head = SEGHead(in_channels, cfg)

        # Separate post-processors for training and inference.
        box_selector_train = make_seg_postprocessor(cfg, is_train=True)
        box_selector_test = make_seg_postprocessor(cfg, is_train=False)

        loss_evaluator = make_seg_loss_evaluator(cfg)

        self.head = head
        self.box_selector_train = box_selector_train
        self.box_selector_test = box_selector_test
        self.loss_evaluator = loss_evaluator

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): batch of images; only ``get_sizes()`` is used.
            features (sequence[Tensor]): backbone feature maps; the head reads
                ``features[-5]`` .. ``features[-2]``.
            targets (Tensor, optional): segmentation ground-truth maps; used
                only in training mode.

        Returns:
            A pair ``(result, [fuse_feature])``, where ``fuse_feature`` is the
            fused feature tensor produced by the head.
            - Training mode: ``result`` is ``(boxes, {"loss_seg": loss})``.
            - Eval mode: ``result`` is ``(boxes, seg_results)``, where
              ``seg_results`` is a dict with keys ``"rotated_boxes"``,
              ``"polygons"``, ``"preds"`` and ``"scores"``.
            ``boxes`` is whatever the post-processor returns (presumably one
            BoxList per image; confirm in ``inference.py``).
        """
        preds, fuse_feature = self.head(features)
        image_shapes = images.get_sizes()
        if self.training:
            return self._forward_train(preds, targets, image_shapes), [fuse_feature]
        else:
            return self._forward_test(preds, image_shapes), [fuse_feature]

    def _forward_train(self, preds, targets, image_shapes):
        """Return ``(boxes, {"loss_seg": loss})`` for one training step."""
        # Convert the predicted map into boxes without building an autograd
        # graph; only the segmentation loss below is computed with gradients.
        with torch.no_grad():
            boxes = self.box_selector_train(preds, image_shapes, targets)
        loss_seg = self.loss_evaluator(preds, targets)
        losses = {"loss_seg": loss_seg}
        return boxes, losses

    def _forward_test(self, preds, image_shapes):
        """
        Return ``(boxes, seg_results)`` for inference. The raw ``preds`` map
        is passed through in ``seg_results`` alongside the post-processor
        outputs.
        """
        boxes, rotated_boxes, polygons, scores = self.box_selector_test(
            preds, image_shapes
        )
        seg_results = {
            'rotated_boxes': rotated_boxes,
            'polygons': polygons,
            'preds': preds,
            'scores': scores,
        }
        return boxes, seg_results


def build_segmentation(cfg):
    """
    Construct the segmentation branch described by ``cfg``.

    Arguments:
        cfg: config node, forwarded unchanged to ``SEGModule``.

    Returns:
        SEGModule: a newly constructed segmentation module.
    """
    segmentation_module = SEGModule(cfg)
    return segmentation_module