import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch
import numpy as np
import os, time, random
import argparse
from torch.utils.data import Dataset, DataLoader
from PIL import Image as PILImage
from glob import glob
from tqdm import tqdm
import rawpy
import colour_demosaicing

from .InvISP.model.model import InvISPNet
from .utils.common import Notify
from datasets.noise import (
    camera_params,
    addGStarNoise,
    addPStarNoise,
    addQuantNoise,
    addRowNoise,
    sampleK,
)


class NoiseSimulator:
    """Synthesize noisy RAW data from RGB images using an invertible ISP.

    ``rgb2raw`` runs the InvISP network in reverse, then undoes gamma, white
    balance and demosaicing to produce an RGGB Bayer mosaic.
    ``raw2noisyRaw`` injects sensor noise (helpers from ``datasets.noise``)
    at a randomly sampled exposure ratio. ``raw2rgb``, ``raw2packedRaw`` and
    ``raw2demosaicRaw`` turn a mosaic back into network-ready arrays.

    Methods that accept ``batched=True`` expect a leading batch axis.
    """

    def __init__(self, device, ckpt_path="./datasets/InvISP/pretrained/canon.pth"):
        """Load the InvISP network and set the simulation parameters.

        Args:
            device: torch device (or device string) the network runs on.
            ckpt_path: path to the InvISP ``state_dict`` checkpoint.
        """
        self.device = device

        # load Invertible ISP Network
        self.net = (
            InvISPNet(channel_in=3, channel_out=3, block_num=8).to(self.device).eval()
        )
        # Load onto the CPU first so a checkpoint saved on a GPU also loads on
        # CPU-only hosts; load_state_dict copies the tensors onto the model.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        # NOTE(review): strict=False silently ignores missing/unexpected keys.
        self.net.load_state_dict(state_dict, strict=False)
        print(
            Notify.INFO, "Loaded ISPNet checkpoint: {}".format(ckpt_path), Notify.ENDC
        )

        # White-balance gains. The first three entries are applied to the
        # R, G, B channels in rgb2raw / raw2rgb; the last entry is unused here.
        self.wb = np.array([2020.0, 1024.0, 1458.0, 1024.0])

        # use Canon EOS 5D4 noise parameters provided by ELD
        self.camera_params = camera_params

        # exposure-time ratio is drawn uniformly from [ratio_min, ratio_max]
        self.ratio_min = 50
        self.ratio_max = 150

    def invDemosaic(self, img):
        """Re-mosaic an RGB image into an RGGB Bayer pattern.

        Args:
            img: array of shape [H, W, 3].

        Returns:
            float64 array of shape [H, W]. R is taken at (even, even), G at
            (even, odd) and (odd, even), and B at (odd, odd).
        """
        # every pixel is overwritten by exactly one of the four phases below
        raw_img = np.empty(img.shape[:2])
        raw_img[::2, ::2] = img[::2, ::2, 0]
        raw_img[::2, 1::2] = img[::2, 1::2, 1]
        raw_img[1::2, ::2] = img[1::2, ::2, 1]
        raw_img[1::2, 1::2] = img[1::2, 1::2, 2]
        return raw_img

    def demosaicNearest(self, img):
        """Nearest-neighbour demosaic of an RGGB mosaic.

        Each 2x2 Bayer cell is filled with its own R and B samples. Even rows
        take the (even, odd) green sample; odd rows take the (odd, even) one.

        Args:
            img: array of shape [H, W] with even H and W.

        Returns:
            float64 array of shape [H, W, 3].
        """
        out = np.ones((img.shape[0], img.shape[1], 3))
        red = img[::2, ::2]
        blue = img[1::2, 1::2]
        greens = (img[::2, 1::2], img[1::2, ::2])  # indexed by row phase
        for dy in (0, 1):
            for dx in (0, 1):
                out[dy::2, dx::2, 0] = red
                out[dy::2, dx::2, 1] = greens[dy]
                out[dy::2, dx::2, 2] = blue
        return out

    def demosaic(self, img):
        """Bilinear demosaic of an RGGB mosaic [H, W] into [H, W, 3]."""
        return colour_demosaicing.demosaicing_CFA_Bayer_bilinear(img, "RGGB")

    def path2rgb(self, path):
        """Load an image file as a float64 torch tensor scaled by 1/255.

        The file handle is closed before returning.
        """
        with PILImage.open(path) as image:
            pixels = np.array(image)
        return torch.from_numpy(pixels / 255.0)

    def rgb2raw(self, rgb, batched=False):
        """Convert an RGB image to an RGGB Bayer mosaic with the reverse InvISP.

        Args:
            rgb: torch tensor of shape [H, W, 3] (or [B, H, W, 3] when
                ``batched``); presumably in [0, 1] as returned by ``path2rgb``.
            batched: whether ``rgb`` carries a leading batch axis.

        Returns:
            numpy array of shape [H, W] (or [B, H, W]): the mosaic after
            inverse gamma, inverse white balance and black-level offset.
        """
        # 1. rgb -> invnet (the network consumes NCHW)
        if not batched:
            rgb = rgb.unsqueeze(0)
        rgb = rgb.permute(0, 3, 1, 2).float().to(self.device)
        with torch.no_grad():
            reconstruct_raw = self.net(rgb, rev=True)

        pred_raw = torch.clamp(reconstruct_raw.detach().permute(0, 2, 3, 1), 0, 1)
        if not batched:
            pred_raw = pred_raw[0, ...]
        pred_raw = pred_raw.cpu().numpy()

        # 2. -> inv gamma: (x * 16383 ** (1 / 2.2)) ** 2.2 == x ** 2.2 * 16383
        norm_value = np.power(16383, 1 / 2.2)
        pred_raw *= norm_value
        pred_raw = np.power(pred_raw, 2.2)

        # 3. -> inv white balance (gains normalized so the largest is 1)
        wb = self.wb / self.wb.max()
        pred_raw = pred_raw / wb[:-1]

        # 4. -> add black level
        pred_raw = pred_raw + self.camera_params["black_level"]

        # 5. -> inv demosaic (back to a single-channel RGGB mosaic)
        return self._map_samples(self.invDemosaic, pred_raw, batched)

    def raw2noisyRaw(self, raw, ratio_dec=1, batched=False):
        """Add simulated low-light sensor noise to a clean RAW mosaic.

        Each frame is divided by a random exposure ratio, passed through
        ``addPStarNoise``, ``addGStarNoise``, ``addRowNoise`` and
        ``addQuantNoise``, then multiplied back by the same ratio.

        Args:
            raw: numpy array [H, W] (or [B, H, W] when ``batched``).
            ratio_dec: scales the sampled ratio's distance from 1, i.e.
                ``ratio = (U(ratio_min, ratio_max) - 1) * ratio_dec + 1``.
                Applied identically in batched and unbatched mode; each batch
                item draws its own ratio.
            batched: whether ``raw`` carries a leading batch axis.

        Returns:
            A new array of the same shape; ``raw`` itself is not modified.
        """
        if not batched:
            return self._add_noise(raw, self._sample_ratio(ratio_dec))

        # write into a copy so the output keeps the input's dtype
        noisy = raw.copy()
        for i in range(raw.shape[0]):
            noisy[i] = self._add_noise(raw[i], self._sample_ratio(ratio_dec))
        return noisy

    def raw2rgb(self, raw, batched=False):
        """Render an RGGB mosaic to RGB with the forward InvISP network.

        Args:
            raw: numpy mosaic [H, W] (or [B, H, W] when ``batched``) that still
                includes the black level.
            batched: whether ``raw`` carries a leading batch axis.

        Returns:
            numpy array [H, W, 3] (or [B, H, W, 3]) clamped to [0, 1].
        """
        # 1. -> demosaic
        raw = self._map_samples(self.demosaic, raw, batched)

        # 2. -> subtract black level
        raw = self._subtract_black_level(raw)

        # 3. -> white balance
        wb = self.wb / self.wb.max()
        raw = raw * wb[:-1]

        # 4. -> gamma (the inverse of step 2 in rgb2raw)
        norm_value = np.power(16383, 1 / 2.2)
        raw = np.power(raw, 1 / 2.2)
        raw /= norm_value

        # 5. -> ispnet (consumes NCHW)
        if not batched:
            raw = raw[np.newaxis, ...]
        input_raw_img = torch.Tensor(raw).permute(0, 3, 1, 2).float().to(self.device)
        with torch.no_grad():
            reconstruct_rgb = torch.clamp(self.net(input_raw_img), 0, 1)

        pred_rgb = reconstruct_rgb.detach().permute(0, 2, 3, 1)
        if not batched:
            pred_rgb = pred_rgb[0, ...]
        return pred_rgb.cpu().numpy()

    def raw2packedRaw(self, raw, batched=False):
        """Normalize an RGGB mosaic and pack each 2x2 cell into 4 channels.

        Args:
            raw: mosaic [H, W] (or [B, H, W] when ``batched``) with even H and
                W. It is not modified; integer inputs are promoted to float64
                so the black-level subtraction cannot wrap around.
            batched: whether ``raw`` carries a leading batch axis. The packing
                always acts on the last two axes, so this only documents intent.

        Returns:
            array [H/2, W/2, 4] (or [B, H/2, W/2, 4]) with channels ordered as
            the samples at (0,0), (0,1), (1,1), (1,0), i.e. R, G, B, G.
        """
        # 1. -> subtract black level and normalize
        # NOTE: divides by max_value (not max_value - black_level), so values
        # lie in [0, (max_value - black_level) / max_value]; kept as-is to
        # preserve the existing normalization.
        raw = self._subtract_black_level(raw) / self.camera_params["max_value"]

        # 2. pack
        offsets = ((0, 0), (0, 1), (1, 1), (1, 0))
        return np.stack([raw[..., dy::2, dx::2] for dy, dx in offsets], axis=-1)

    def raw2demosaicRaw(self, raw, batched=False):
        """Demosaic an RGGB mosaic and normalize it like ``raw2packedRaw``.

        Args:
            raw: mosaic [H, W] (or [B, H, W] when ``batched``).
            batched: whether ``raw`` carries a leading batch axis.

        Returns:
            array [H, W, 3] (or [B, H, W, 3]), black level removed, clipped
            and divided by ``max_value``.
        """
        # 1. -> demosaic
        raw = self._map_samples(self.demosaic, raw, batched)

        # 2. -> subtract black level and normalize
        return self._subtract_black_level(raw) / self.camera_params["max_value"]

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _map_samples(fn, arr, batched):
        """Apply ``fn`` to ``arr``, or to each entry along axis 0 if batched."""
        if not batched:
            return fn(arr)
        return np.stack([fn(sample) for sample in arr], axis=0)

    def _subtract_black_level(self, raw):
        """Return ``raw - black_level`` clipped to [0, max - black]; never in place."""
        black = self.camera_params["black_level"]
        raw = np.asarray(raw)
        if not np.issubdtype(raw.dtype, np.floating):
            # promote integers first so the subtraction cannot wrap around
            raw = raw.astype(np.float64)
        return np.clip(raw - black, 0, self.camera_params["max_value"] - black)

    def _sample_ratio(self, ratio_dec=1):
        """Draw an exposure ratio: ``(U(ratio_min, ratio_max) - 1) * ratio_dec + 1``."""
        return (random.uniform(self.ratio_min, self.ratio_max) - 1) * ratio_dec + 1

    def _add_noise(self, raw, ratio):
        """Return one frame divided by ``ratio``, noised, and scaled back up."""
        params = self.camera_params
        profile = params["Profile-1"]

        raw = raw / ratio
        K = sampleK(params["Kmin"], params["Kmax"])
        q = 1 / (params["max_value"] - params["black_level"])

        raw = addPStarNoise(raw, K)
        raw = addGStarNoise(raw, K, params["G_shape"], profile["G_scale"])
        raw = addRowNoise(raw, K, profile["R_scale"])
        raw = addQuantNoise(raw, q)
        return raw * ratio