import sys
from typing import Dict
sys.path.insert(0, 'gradio-modified')

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

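# Pick the compute device: use the GPU only when at least 4 GiB of its memory is free.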
if torch.cuda.is_available():
    total = torch.cuda.get_device_properties(0).total_memory
    allocated = torch.cuda.memory_allocated(0)
    free = total - allocated  # approximate free GPU memory in bytes
    device = 'cuda' if free >= 2**32 else 'cpu'  # require at least 4 GiB free
else:
    device = 'cpu'
    torch._C._jit_set_bailout_depth(0)  # internal TorchScript knob: avoid repeated recompilation on CPU

print('Use device:', device)


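# TorchScript colorization network, scripted separately for each device; it takes
# noise maps, a grayscale guide, a color-hint image, and a hint-mode flag.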
net = torch.jit.load(f'weights/pkp-v1.{device}.jit.pt')

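# Lab color-space normalization helpers: L is centered at 50 and scaled by 100, ab is scaled by 110.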
class BaseColor(nn.Module):
    def __init__(self):
        super(BaseColor, self).__init__()

        self.l_cent = 50.
        self.l_norm = 100.
        self.ab_norm = 110.

    def normalize_l(self, in_l):
        return (in_l - self.l_cent) / self.l_norm

    def unnormalize_l(self, in_l):
        return in_l * self.l_norm + self.l_cent

    def normalize_ab(self, in_ab):
        return in_ab / self.ab_norm

    def unnormalize_ab(self, in_ab):
        return in_ab * self.ab_norm


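# Generator following the ECCV 2016 "Colorful Image Colorization" architecture (Zhang et al.):
# predicts a distribution over 313 quantized ab bins, regresses it to ab values, and upsamples 4x.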
class ECCVGenerator(BaseColor):
    def __init__(self, norm_layer=nn.BatchNorm2d):
        super(ECCVGenerator, self).__init__()

        model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[norm_layer(64),]

        model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[norm_layer(128),]

        model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[norm_layer(256),]

        model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[norm_layer(512),]

        model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[norm_layer(512),]

        model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[norm_layer(512),]

        model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[norm_layer(512),]

        model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]

        model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]

        self.model1 = nn.Sequential(*model1)
        self.model2 = nn.Sequential(*model2)
        self.model3 = nn.Sequential(*model3)
        self.model4 = nn.Sequential(*model4)
        self.model5 = nn.Sequential(*model5)
        self.model6 = nn.Sequential(*model6)
        self.model7 = nn.Sequential(*model7)
        self.model8 = nn.Sequential(*model8)

        self.softmax = nn.Softmax(dim=1)
        self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')

    def forward(self, input_l):
        conv1_2 = self.model1(self.normalize_l(input_l))
        conv2_2 = self.model2(conv1_2)
        conv3_3 = self.model3(conv2_2)
        conv4_3 = self.model4(conv3_3)
        conv5_3 = self.model5(conv4_3)
        conv6_3 = self.model6(conv5_3)
        conv7_3 = self.model7(conv6_3)
        conv8_3 = self.model8(conv7_3)
        out_reg = self.model_out(self.softmax(conv8_3))

        x = self.unnormalize_ab(self.upsample4(out_reg))
        zeros = torch.zeros_like(x[:, :1, :, :])
        x = torch.cat([x, zeros], dim=1)  # append a zero channel so the 2-channel ab output becomes 3 channels
        return x


# model_net = torch.load('weights/colorizer.pt')
model_net = ECCVGenerator()
model_net.load_state_dict(torch.load('weights/colorizer (1).pt', map_location=torch.device('cpu')))
model_net.to(device).eval()  # move to the selected device and freeze BatchNorm statistics


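# Preprocess an upload for the hint-based model: grayscale, short side scaled to 256 px,
# then white-padded so height and width are multiples of 64.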
def resize_original(img: Image.Image):
    if img is None:
        return None, None  # this handler fills two outputs: the canvas and the stored copy
    if isinstance(img, dict):
        img = img["image"]

    guide_img = img.convert('L')
    scale = 256 / min(guide_img.size)
    guide_img = guide_img.resize([int(round(s * scale)) for s in guide_img.size], Image.Resampling.LANCZOS)

    guide = np.asarray(guide_img)
    h, w = guide.shape[-2:]
    rows = int(np.ceil(h / 64)) * 64
    cols = int(np.ceil(w / 64)) * 64
    ph_1 = (rows - h) // 2
    ph_2 = rows - h - ph_1
    pw_1 = (cols - w) // 2
    pw_2 = cols - w - pw_1
    guide = np.pad(guide, ((ph_1, ph_2), (pw_1, pw_2)), mode='constant', constant_values=255)
    guide_img = Image.fromarray(guide)

    return gr.Image.update(value=guide_img.convert('RGBA')), guide_img.convert('RGBA')


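# Preprocess an upload for the automatic model: squash it to the fixed 256x256 input size.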
def resize_original2(img: Image.Image):
    if img is None:
        return img
    if isinstance(img, dict):
        img = img["image"]

    img = img.resize((256, 256))  # PIL expects a (width, height) tuple

    return img


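# Hint-based colorization. The guide is the grayscale input scaled to [-1, 1]; hint pixels
# count only where the drawn mask is (nearly) opaque, and uncovered pixels are set to -2
# so the model can tell them apart from valid colors.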
def colorize(img: Dict[str, Image.Image], guide_img: Image.Image, seed: int, hint_mode: str):
    if not isinstance(img, dict):
        return gr.update(visible=True)

    if hint_mode == "Roughly Hint":
        hint_mode_int = 0
    elif hint_mode == "Precisely Hint":
        hint_mode_int = 1
    
    guide_img = guide_img.convert('L')
    hint_img = img["mask"].convert('RGBA')  # the bundled gradio-modified build lets the mask carry color

    guide = torch.from_numpy(np.asarray(guide_img))[None,None].float().to(device) / 255.0 * 2 - 1
    hint = torch.from_numpy(np.asarray(hint_img)).permute(2,0,1)[None].float().to(device) / 255.0 * 2 - 1
    hint_alpha = (hint[:,-1:] > 0.99).float()
    hint = hint[:,:3] * hint_alpha - 2 * (1 - hint_alpha)

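    # Seeded noise makes results reproducible; the scripted net consumes 17 noise maps
    # at 1/8 of the hint resolution.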
    np.random.seed(int(seed))
    b, c, h, w = hint.shape
    h //= 8
    w //= 8
    noises = [torch.from_numpy(np.random.randn(b, c, h, w)).float().to(device) for _ in range(16+1)]

    with torch.inference_mode():
        sample = net(noises, guide, hint, hint_mode_int)
        out = sample[0].cpu().numpy().transpose([1,2,0])
        out = np.uint8(((out + 1) / 2 * 255).clip(0,255))
    
    return Image.fromarray(out).convert('RGB')


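# Automatic colorization with the ECCV generator: feed the grayscale image scaled to
# [-1, 1] and map the 3-channel output back to an RGB image.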
def colorize2(img: Image.Image, model_option: str):
    if img is None:
        return gr.update(visible=True)

    if model_option == "Model 1":
        model_int = 0
    elif model_option == "Model 2":
        model_int = 1  # currently unused: both options run the same network

    gray = img.convert('L')
    input_t = torch.from_numpy(np.asarray(gray))[None, None].float().to(device) / 255.0 * 2 - 1
    with torch.inference_mode():
        out2 = model_net(input_t)[0]  # (3, H, W)
        out2 = out2.cpu().numpy().transpose([1, 2, 0])
        out2 = np.uint8(((out2 + 1) / 2 * 255).clip(0, 255))

    return Image.fromarray(out2).convert('RGB')


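# Gradio UI: the first row is hint-based colorization, the second row is automatic colorization.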
with gr.Blocks() as demo:
    gr.Markdown('''<center><h1>Image Colorization With Hint</h1></center>
<h2>Colorize your images/sketches with hint points.</h2>
<br />
''')
    with gr.Row():
        with gr.Column():
            inp = gr.Image(
                source="upload", 
                tool="sketch", # tool="color-sketch", # color-sketch upload image mixed with the original
                type="pil", 
                label="Sketch", 
                interactive=True,
                elem_id="sketch-canvas"
            )
            inp_store = gr.Image(
                type="pil",
                interactive=False,
                visible=False
            )
        with gr.Column():
            seed = gr.Slider(1, 2**32, step=1, label="Seed", interactive=True, randomize=True)
            hint_mode = gr.Radio(["Roughly Hint", "Precisely Hint"], value="Roughly Hint", label="Hint Mode")
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Image(type="pil", label="Output", interactive=False)
    with gr.Row():
        with gr.Column():
            inp2 = gr.Image(
                source="upload", 
                type="pil", 
                label="Sketch",
                interactive=True
            )
            inp_store2 = gr.Image(
                type="pil",
                interactive=False,
                visible=False
            )
        with gr.Column():
            # seed = gr.Slider(1, 2**32, step=1, label="Seed", interactive=True, randomize=True)
            model_option = gr.Radio(["Model 1", "Model 2"], value="Model 1", label="Model")
            btn2 = gr.Button("Run Colorization")
        with gr.Column():
            output2 = gr.Image(type="pil", label="Output2", interactive=False)
    gr.Markdown('''
After uploading an image, add color hints at specific points, then run the model. Average inference time is about 52 seconds.<br />
''')
    gr.Markdown('''Authors: <a href="https://www.linkedin.com/in/chakshu-dhannawat/">Chakshu Dhannawat</a>, <a href="https://www.linkedin.com/in/navlika-singh-963120204/">Navlika Singh</a>, <a href="https://www.linkedin.com/in/akshat-jain-103550201/">Akshat Jain</a>''')
    inp.upload(
        resize_original, 
        inp, 
        [inp, inp_store],
    )
    inp2.upload(
        resize_original2,
        inp2,
        inp2
    )
    btn.click(
        colorize, 
        [inp, inp_store, seed, hint_mode],
        output
    )
    btn2.click(
        colorize2, 
        [inp2, model_option],
        output2
    )

if __name__ == "__main__":
    demo.launch()