import sys
import cv2
import utils
import numpy as np

import torch
from PIL import Image

from utils import convert_state_dict
from models import restormer_arch
from data.preprocess.crop_merge_image import stride_integral

sys.path.append("./data/MBD/")
from data.MBD.infer import net1_net2_infer_single_im


def dewarp_prompt(img):
    """Build the dewarping prompt for ``img`` and zero out unmasked pixels.

    ``img`` is modified in place: every pixel where the mask returned by
    ``net1_net2_infer_single_im`` equals 0 is set to 0.

    Returns:
        ``(img, prompt)``. ``prompt`` concatenates, along the last axis,
        ``utils.getBasecoord(256, 256) / 256`` and the mask resized to
        256x256 and divided by 255.
    """
    doc_mask = net1_net2_infer_single_im(img, "data/MBD/checkpoint/mbd.pkl")
    img[doc_mask == 0] = 0
    grid = utils.getBasecoord(256, 256) / 256
    mask_small = cv2.resize(doc_mask, (256, 256)) / 255
    prompt = np.concatenate((grid, mask_small[..., np.newaxis]), -1)
    return img, prompt


def deshadow_prompt(img):
    """Estimate the smooth background of ``img``, used as the deshadowing prompt.

    The image is resized to 1024x1024. Each channel is then dilated with a
    7x7 kernel of ones and median-filtered with aperture 21. The merged
    result is resized back to the input's height and width.

    Args:
        img: image array. ``cv2.medianBlur`` with aperture 21 only accepts
            8-bit data, so this presumably expects a uint8 image.

    Returns:
        The background estimate, with the same height/width as ``img``.
    """
    h, w = img.shape[:2]
    work = cv2.resize(img, (1024, 1024))
    kernel = np.ones((7, 7), np.uint8)
    # Dilation (a local maximum) suppresses thin dark structures (presumably
    # text strokes); the wide median filter then smooths what remains.
    bg_planes = [
        cv2.medianBlur(cv2.dilate(plane, kernel), 21) for plane in cv2.split(work)
    ]
    background = cv2.merge(bg_planes)
    return cv2.resize(background, (w, h))


def deblur_prompt(img):
    """Return the high-frequency prompt used for deblurring.

    The absolute horizontal and vertical Sobel responses are averaged with
    equal weights. The result is collapsed to a single gray channel and then
    replicated back to three channels.
    """
    grad_x = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 1, 0))  # back to uint8
    grad_y = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 0, 1))
    edges = cv2.addWeighted(grad_x, 0.5, grad_y, 0.5, 0)
    gray = cv2.cvtColor(edges, cv2.COLOR_BGR2GRAY)
    return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)


def appearance_prompt(img):
    """Build the appearance-enhancement prompt for ``img``.

    The image is resized to 1024x1024. For each channel, a background is
    estimated with a 7x7 dilation followed by a median filter of aperture 21.
    The channel is then mapped to ``255 - |plane - background|`` and min-max
    normalized to [0, 255] as uint8. The merged planes are resized back to
    the input's height and width.

    Returns:
        The normalized difference image, with the same height/width as ``img``.
    """
    h, w = img.shape[:2]
    work = cv2.resize(img, (1024, 1024))
    kernel = np.ones((7, 7), np.uint8)
    norm_planes = []
    for plane in cv2.split(work):
        background = cv2.medianBlur(cv2.dilate(plane, kernel), 21)
        diff = 255 - cv2.absdiff(plane, background)
        norm_planes.append(
            cv2.normalize(
                diff,
                None,
                alpha=0,
                beta=255,
                norm_type=cv2.NORM_MINMAX,
                dtype=cv2.CV_8UC1,
            )
        )
    result_norm = cv2.merge(norm_planes)
    return cv2.resize(result_norm, (w, h))


def binarization_promptv2(img):
    """Build the 3-channel binarization prompt for ``img``.

    Channels, stacked along the last axis:
      0. the ``thresh`` output of ``utils.SauvolaModBinarization``, cast to uint8;
      1. a gray high-frequency map (average of absolute Sobel x/y responses);
      2. the ``result`` output of ``utils.SauvolaModBinarization``, hard-thresholded
         in place: values > 155 become 255 and values <= 155 become 0.
    """
    result, thresh = utils.SauvolaModBinarization(img)
    thresh = thresh.astype(np.uint8)
    result[result > 155] = 255
    result[result <= 155] = 0

    sobel_x = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 1, 0))  # back to uint8
    sobel_y = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 0, 1))
    high_frequency = cv2.cvtColor(
        cv2.addWeighted(sobel_x, 0.5, sobel_y, 0.5, 0), cv2.COLOR_BGR2GRAY
    )
    return np.stack((thresh, high_frequency, result), axis=-1)


def dewarping(model, im_org, device):
    """Run the dewarping task on ``im_org``.

    The masked image (from ``dewarp_prompt``) and its prompt are fed to the
    model at 256x256. The first two output channels are added to the base
    coordinate grid and smoothed by 15 passes of a 3x3 box blur. They are
    then resized to the input size, scaled by ``(w, h)``, and used with
    ``cv2.remap`` on the original, unmasked ``im_org``.

    ``model.float()`` converts the model to float32 in place.

    Returns:
        ``(prompt0, prompt1, prompt2, dewarped_image)``. The prompt channels
        are uint8 and resized to the size of ``im_org``.
    """
    INPUT_SIZE = 256
    im_masked, prompt_org = dewarp_prompt(im_org.copy())
    h, w = im_masked.shape[:2]

    def to_batch(arr):
        # HWC numpy -> 1xCxHxW float32 tensor on the target device.
        return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).float().to(device)

    image_tensor = to_batch(cv2.resize(im_masked, (INPUT_SIZE, INPUT_SIZE)) / 255.0)
    prompt_tensor = to_batch(prompt_org)
    in_im = torch.cat((image_tensor, prompt_tensor), dim=1)

    grid = utils.getBasecoord(INPUT_SIZE, INPUT_SIZE) / INPUT_SIZE
    model = model.float()
    with torch.no_grad():
        flow = model(in_im)[0][:2].permute(1, 2, 0).cpu().numpy() + grid

    for _ in range(15):
        flow = cv2.blur(flow, (3, 3), borderType=cv2.BORDER_REPLICATE)
    # The grid was divided by INPUT_SIZE (presumably normalized coordinates);
    # scaling by (w, h) turns it into pixel coordinates for cv2.remap.
    flow = (cv2.resize(flow, (w, h)) * (w, h)).astype(np.float32)
    out_im = cv2.remap(im_org, flow[:, :, 0], flow[:, :, 1], cv2.INTER_LINEAR)

    prompt_u8 = cv2.resize((prompt_org * 255).astype(np.uint8), im_org.shape[:2][::-1])
    return prompt_u8[:, :, 0], prompt_u8[:, :, 1], prompt_u8[:, :, 2], out_im


def appearance(model, im_org, device):
    """Run the appearance-enhancement task on ``im_org``.

    The image is concatenated with ``appearance_prompt(im_org)`` along the
    channel axis. Two paths follow, depending on size:

    * Longer side < 1600: the input is padded with ``stride_integral(.., 8)``,
      and the prediction is cropped by the returned padding.
    * Otherwise: the input is resized to 1600x1600. The ratio of the resized
      input to the prediction is resized back to full size and divided out
      of ``im_org``.

    ``model.half()`` converts the model to fp16 in place.
    NOTE(review): fp16 inference may be unsupported or slow on CPU devices
    depending on the PyTorch version — confirm for CPU use.

    Returns:
        ``(prompt0, prompt1, prompt2, enhanced_uint8_image)``.
    """
    MAX_SIZE = 1600
    h, w = im_org.shape[:2]
    prompt = appearance_prompt(im_org)
    in_im = np.concatenate((im_org, prompt), -1)

    small = max(w, h) < MAX_SIZE
    if small:
        in_im, padding_h, padding_w = stride_integral(in_im, 8)
    else:
        in_im = cv2.resize(in_im, (MAX_SIZE, MAX_SIZE))

    tensor = torch.from_numpy((in_im / 255.0).transpose(2, 0, 1)).unsqueeze(0)
    tensor = tensor.half().to(device)
    model = model.half()
    with torch.no_grad():
        pred = torch.clamp(model(tensor), 0, 1)[0].permute(1, 2, 0).cpu().numpy()
        pred = (pred * 255).astype(np.uint8)

        if small:
            out_im = pred[padding_h:, padding_w:]
        else:
            pred[pred == 0] = 1  # avoid division by zero below
            ratio = cv2.resize(im_org, (MAX_SIZE, MAX_SIZE)).astype(float) / pred.astype(float)
            ratio = cv2.resize(ratio, (w, h))
            ratio[ratio == 0] = 0.00001
            out_im = np.clip(im_org.astype(float) / ratio, 0, 255).astype(np.uint8)

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def deshadowing(model, im_org, device):
    """Run the deshadowing task on ``im_org``.

    The image is stacked with ``deshadow_prompt(im_org)`` on the channel axis.
    Inputs whose longer side is below 1600 are padded to a multiple of 8
    with ``stride_integral``, and the prediction is cropped back. Larger
    inputs are resized to 1600x1600. In that case the prediction yields a
    per-pixel ratio (resized input / prediction), which is upsampled to full
    size and divided out of the original image.

    ``model.half()`` converts the model to fp16 in place.
    NOTE(review): fp16 inference may be unsupported or slow on CPU devices
    depending on the PyTorch version — confirm for CPU use.

    Returns:
        ``(prompt0, prompt1, prompt2, deshadowed_uint8_image)``.
    """
    MAX_SIZE = 1600
    h, w = im_org.shape[:2]
    prompt = deshadow_prompt(im_org)
    stacked = np.concatenate((im_org, prompt), -1)

    fits = max(w, h) < MAX_SIZE
    if not fits:
        stacked = cv2.resize(stacked, (MAX_SIZE, MAX_SIZE))
    else:
        stacked, padding_h, padding_w = stride_integral(stacked, 8)

    net_in = torch.from_numpy((stacked / 255.0).transpose(2, 0, 1)).unsqueeze(0)
    net_in = net_in.half().to(device)
    model = model.half()
    with torch.no_grad():
        out = torch.clamp(model(net_in), 0, 1)
        out = (out[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

        if fits:
            out_im = out[padding_h:, padding_w:]
        else:
            # Raising zeros to 1 keeps the division finite.
            denom = np.maximum(out, 1).astype(float)
            shadow_map = cv2.resize(im_org, (MAX_SIZE, MAX_SIZE)).astype(float) / denom
            shadow_map = cv2.resize(shadow_map, (w, h))
            shadow_map = np.where(shadow_map == 0, 0.00001, shadow_map)
            out_im = np.clip(im_org.astype(float) / shadow_map, 0, 255).astype(np.uint8)

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def deblurring(model, im_org, device):
    """Run the deblurring task on ``im_org``.

    The image is padded with ``stride_integral(.., 8)``, and the high-frequency
    prompt is computed on the padded image. The 6-channel stack is run
    through the model in fp16, and the prediction is cropped by the returned
    padding. The model is moved to ``device``, set to eval mode and converted
    to fp16 in place.
    NOTE(review): fp16 inference may be unsupported or slow on CPU devices
    depending on the PyTorch version — confirm for CPU use.

    Returns:
        ``(prompt0, prompt1, prompt2, deblurred_uint8_image)``. The prompt
        channels keep the padded size; only the image is cropped.
    """
    padded, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = deblur_prompt(padded)
    stacked = np.concatenate((padded, prompt), -1) / 255.0
    net_in = torch.from_numpy(stacked.transpose(2, 0, 1)).unsqueeze(0).half().to(device)

    model.to(device)
    model.eval()
    model = model.half()
    with torch.no_grad():
        restored = torch.clamp(model(net_in), 0, 1)
        restored = (restored[0].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        out_im = restored[padding_h:, padding_w:]

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def binarization(model, im_org, device):
    """Run the binarization task on ``im_org``.

    The image is padded with ``stride_integral(.., 8)`` and stacked with
    ``binarization_promptv2`` of the padded image. The model's first two
    output channels are treated as per-pixel class scores. The winning class
    index (0 or 1) is mapped to 0/255, and the mask is cropped by the padding.
    ``model.half()`` converts the model to fp16 in place.

    Returns:
        ``(prompt0, prompt1, prompt2, binary_uint8_image)``. The prompt
        channels keep the padded size; only the image is cropped.
    """
    im, padding_h, padding_w = stride_integral(im_org, 8)
    prompt = binarization_promptv2(im)
    h, w = im.shape[:2]
    in_im = np.concatenate((im, prompt), -1)

    in_im = in_im / 255.0
    in_im = torch.from_numpy(in_im.transpose(2, 0, 1)).unsqueeze(0)
    in_im = in_im.to(device)
    model = model.half()
    in_im = in_im.half()
    with torch.no_grad():
        logits = model(in_im)[:, :2, :, :]
        # Softmax is monotonic, so the argmax of the raw logits selects the
        # same class without the extra pass. It also avoids ties that
        # rounding fp16 probabilities could introduce between close scores.
        labels = logits.argmax(dim=1)
        pred = (labels[0].cpu().numpy() * 255).astype(np.uint8)
        pred = cv2.resize(pred, (w, h))
        out_im = pred[padding_h:, padding_w:]

    return prompt[:, :, 0], prompt[:, :, 1], prompt[:, :, 2], out_im


def model_init(model_path, device):
    """Build the Restormer model and load its weights.

    Args:
        model_path: checkpoint file whose ``"model_state"`` entry holds the
            state dict; it is passed through ``convert_state_dict`` before loading.
        device: target device for the returned model (e.g. ``"cpu"``,
            ``"cuda:0"``, ``"cuda:1"``).

    Returns:
        The model in eval mode, moved to ``device``.
    """
    model = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        dual_pixel_task=True,
    )

    # The freshly built model's parameters live on the CPU, so load the
    # checkpoint there too. This works for any target device instead of a
    # fixed GPU, and `.to(device)` below performs the single transfer.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(convert_state_dict(checkpoint["model_state"]))

    model.eval()
    return model.to(device)


def resize(image, max_size):
    """Downscale ``image`` so its longer side is at most ``max_size``.

    The aspect ratio is preserved; the shorter side is truncated to an int
    and clamped to at least 1. Images already within the limit are returned
    unchanged (the same object). Resampling uses Pillow's LANCZOS filter, so
    ``image`` must be an array that ``PIL.Image.fromarray`` accepts
    (presumably uint8 HxW or HxWxC).

    Args:
        image: image array.
        max_size: maximum allowed length of the longer side, in pixels.

    Returns:
        The (possibly) resized image as a numpy array.
    """
    h, w = image.shape[:2]
    if max(h, w) <= max_size:
        return image
    # max(1, ...) keeps extreme aspect ratios from producing a zero-sized
    # target, which Pillow rejects.
    if h > w:
        h_new = max_size
        w_new = max(1, int(w * h_new / h))
    else:
        w_new = max_size
        h_new = max(1, int(h * w_new / w))
    pil_image = Image.fromarray(image).resize((w_new, h_new), Image.Resampling.LANCZOS)
    return np.array(pil_image)


def inference_one_image(model, image, tasks, device):
    """Apply the requested restoration ``tasks`` to a BGR ``image``.

    Tasks always run in a fixed order, regardless of their order in
    ``tasks``: "dewarping", then (after capping the longer side at 1536 px)
    "deshadowing", "appearance", "deblurring", "binarization". When
    "dewarping" is the only task, its output is returned without the resize.

    Returns:
        The final processed image.
    """
    if "dewarping" in tasks:
        image = dewarping(model, image, device)[-1]
        # Dewarping alone: skip the resize step entirely.
        if len(tasks) == 1:
            return image

    image = resize(image, 1536)

    stages = (
        ("deshadowing", deshadowing),
        ("appearance", appearance),
        ("deblurring", deblurring),
        ("binarization", binarization),
    )
    for task_name, stage in stages:
        if task_name in tasks:
            image = stage(model, image, device)[-1]

    return image