|
import sys |
|
from typing import Dict |
|
sys.path.insert(0, 'gradio-modified') |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import torch.nn as nn |
|
from PIL import Image |
|
|
|
import torch |
|
|
|
# Minimum GPU headroom, in bytes (4 GiB), required before the app uses CUDA.
MIN_FREE_GPU_MEMORY = 2**32


def select_device(min_free_bytes=MIN_FREE_GPU_MEMORY):
    """Return ``'cuda'`` if GPU 0 has enough headroom, otherwise ``'cpu'``.

    Headroom is the total memory of device 0 minus the memory this process
    currently has allocated to tensors (``torch.cuda.memory_allocated``).
    Memory held by other processes is not accounted for.

    Args:
        min_free_bytes: CUDA is chosen only when the headroom is at least
            this many bytes.

    Returns:
        ``'cuda'`` or ``'cpu'``.
    """
    if not torch.cuda.is_available():
        return 'cpu'
    total = torch.cuda.get_device_properties(0).total_memory
    allocated = torch.cuda.memory_allocated(0)
    return 'cuda' if total - allocated >= min_free_bytes else 'cpu'


device = select_device()

# Private TorchScript setting kept from the original script for the JIT
# network loaded from disk. NOTE(review): the source's indentation was lost,
# so it is unclear whether this was meant to run only on the CPU path --
# confirm. The hasattr guard keeps import working on PyTorch builds that do
# not expose this private binding.
if hasattr(torch._C, '_jit_set_bailout_depth'):
    torch._C._jit_set_bailout_depth(0)
|
|
|
# Report the chosen backend, then load the TorchScript network that was
# exported for it (the checkpoint file name carries the device suffix).
print(f'Use device: {device}')
net = torch.jit.load(f'weights/pkp-v1.{device}.jit.pt')
|
|
|
class BaseColor(nn.Module):
    """Base module carrying the L*/a*b* scaling used by the colorizer.

    Lightness is mapped with ``(L - 50) / 100`` and chroma with ``ab / 110``;
    the ``unnormalize_*`` methods apply the inverse mappings.
    """

    def __init__(self):
        super().__init__()
        # Centre and scale for the lightness channel, scale for chroma.
        self.l_cent = 50.
        self.l_norm = 100.
        self.ab_norm = 110.

    def normalize_l(self, in_l):
        """Shift ``in_l`` by ``l_cent`` and divide by ``l_norm``."""
        centred = in_l - self.l_cent
        return centred / self.l_norm

    def unnormalize_l(self, in_l):
        """Inverse of ``normalize_l``: scale by ``l_norm``, add ``l_cent``."""
        scaled = in_l * self.l_norm
        return scaled + self.l_cent

    def normalize_ab(self, in_ab):
        """Divide chroma values by ``ab_norm``."""
        scaled = in_ab / self.ab_norm
        return scaled

    def unnormalize_ab(self, in_ab):
        """Inverse of ``normalize_ab``: multiply by ``ab_norm``."""
        restored = in_ab * self.ab_norm
        return restored
|
|
|
|
|
|
|
class ECCVGenerator(BaseColor):
    """Colorization network mapping a lightness map to two chroma channels.

    Stages ``model1``..``model7`` are conv/ReLU stacks each ending in
    ``norm_layer``; ``model8`` upsamples by 2 with a transposed conv and
    projects to 313 channels (presumably the quantized a*b* bins of the
    ECCV16 colorization model -- confirm against the weights' origin).
    ``model_out`` turns the softmax over those channels into 2 channels,
    which are upsampled 4x and un-normalized.

    The attribute names and the index of every layer inside each
    ``nn.Sequential`` determine the ``state_dict`` keys, so they must stay
    fixed for saved weights to load.
    """

    def __init__(self, norm_layer=nn.BatchNorm2d):
        super().__init__()

        def conv_relu(c_in, c_out, stride=1, dilation=1):
            # 3x3 conv whose padding equals its dilation (spatial size is
            # preserved at stride 1), followed by an in-place ReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride,
                          padding=dilation, dilation=dilation, bias=True),
                nn.ReLU(True),
            ]

        def stage(widths, last_stride=1, dilation=1):
            # conv_relu over consecutive channel widths; only the final conv
            # may downsample. Every stage ends with a norm layer.
            pairs = list(zip(widths, widths[1:]))
            layers = []
            for idx, (c_in, c_out) in enumerate(pairs):
                stride = last_stride if idx == len(pairs) - 1 else 1
                layers.extend(conv_relu(c_in, c_out, stride, dilation))
            layers.append(norm_layer(widths[-1]))
            return nn.Sequential(*layers)

        # Layers are created in the same order as the stages run, which also
        # fixes the order in which parameters consume the RNG at init time.
        self.model1 = stage((1, 64, 64), last_stride=2)
        self.model2 = stage((64, 128, 128), last_stride=2)
        self.model3 = stage((128, 256, 256, 256), last_stride=2)
        self.model4 = stage((256, 512, 512, 512))
        self.model5 = stage((512, 512, 512, 512), dilation=2)
        self.model6 = stage((512, 512, 512, 512), dilation=2)
        self.model7 = stage((512, 512, 512, 512))
        self.model8 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),
            nn.ReLU(True),
            *conv_relu(256, 256),
            *conv_relu(256, 256),
            nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),
        )

        self.softmax = nn.Softmax(dim=1)
        self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')

    def forward(self, input_l):
        """Predict un-normalized chroma from a lightness tensor.

        ``input_l`` must have a single channel (the first conv takes 1 input
        channel); it is normalized with ``normalize_l``. With spatial sizes
        divisible by 8 the 2-channel output matches the input resolution
        (three stride-2 stages, one 2x transposed conv, then a 4x upsample).
        """
        feats = self.normalize_l(input_l)
        for block in (self.model1, self.model2, self.model3, self.model4,
                      self.model5, self.model6, self.model7, self.model8):
            feats = block(feats)
        ab_reg = self.model_out(self.softmax(feats))
        return self.unnormalize_ab(self.upsample4(ab_reg))
|
|
|
|
|
|
|
# Hint-free colorizer used by colorize2. The checkpoint is mapped to CPU
# first so weights saved from a GPU session also load on CPU-only hosts;
# the module is then moved to the selected device and switched to eval mode
# so BatchNorm uses its running statistics instead of per-batch ones.
model = ECCVGenerator()
model.load_state_dict(torch.load('weights/colorizer.pt', map_location='cpu'))
model.to(device).eval()
|
|
|
|
|
def resize_original(img: Image.Image):
    """Turn an uploaded image into the padded grayscale hint canvas.

    The image is converted to grayscale and resized (LANCZOS) so its shorter
    side is 256 pixels. It is then padded with white (255), split as evenly as
    possible between the two sides, until both dimensions are multiples of 64.

    Args:
        img: The uploaded image, the ``{"image": ...}`` dict produced by the
            sketch tool, or ``None``.

    Returns:
        A pair ``(canvas_update, guide)``. ``canvas_update`` is a
        ``gr.Image.update`` that replaces the canvas with the padded image.
        ``guide`` is the same padded image in RGBA mode. Returns
        ``(None, None)`` when there is no image. The upload event has two
        outputs, so a single ``None`` would not match.
    """
    if isinstance(img, dict):
        img = img.get("image")
    if img is None:
        return None, None

    guide_img = img.convert('L')
    scale = 256 / min(guide_img.size)
    new_size = tuple(round(side * scale) for side in guide_img.size)
    guide_img = guide_img.resize(new_size, Image.Resampling.LANCZOS)

    guide = np.asarray(guide_img)
    h, w = guide.shape[-2:]
    # Amount needed to round each side up to the next multiple of 64.
    pad_h = -h % 64
    pad_w = -w % 64
    guide = np.pad(
        guide,
        ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)),
        mode='constant',
        constant_values=255,
    )
    padded = Image.fromarray(guide).convert('RGBA')
    return gr.Image.update(value=padded), padded
|
|
|
|
|
def colorize(img: Dict[str, Image.Image], guide_img: Image.Image, seed: int, hint_mode: str):
    """Run the hint-guided TorchScript network on the sketch canvas.

    Args:
        img: Value of the sketch canvas. Only a dict is accepted. Its
            ``"mask"`` entry holds the user's strokes and is used as the hint.
        guide_img: The image stored by ``resize_original`` (the padded
            grayscale guide); it is converted to a single channel here.
        seed: Seed for the noise fed to the network. It is reduced modulo 2**32
            because NumPy rejects larger seeds.
        hint_mode: Radio label, ``"Roughly Hint"`` or ``"Precisely Hint"``.

    Returns:
        The colorized image as an RGB PIL image. Returns
        ``gr.update(visible=True)`` when the canvas value is not a dict or no
        guide image has been stored yet.
    """
    if not isinstance(img, dict) or guide_img is None:
        return gr.update(visible=True)

    # NOTE(review): both labels map to 0, exactly as in the original code, so
    # the radio currently has no effect -- confirm whether "Precisely Hint"
    # should pass a different mode to `net`. Unknown labels fall back to 0
    # instead of leaving the mode unset.
    hint_mode_int = {"Roughly Hint": 0, "Precisely Hint": 0}.get(hint_mode, 0)

    guide_img = guide_img.convert('L')
    hint_img = img["mask"].convert('RGBA')

    # Scale 8-bit pixels to [-1, 1]. np.array copies the PIL buffer so that
    # torch.from_numpy receives a writable array.
    guide = torch.from_numpy(np.array(guide_img))[None, None].float().to(device) / 255.0 * 2 - 1
    hint = torch.from_numpy(np.array(hint_img)).permute(2, 0, 1)[None].float().to(device) / 255.0 * 2 - 1
    # Keep the RGB hint only where the stroke is (almost) fully opaque; every
    # other pixel is marked with -2, outside the [-1, 1] color range.
    hint_alpha = (hint[:, -1:] > 0.99).float()
    hint = hint[:, :3] * hint_alpha - 2 * (1 - hint_alpha)

    # A private RandomState seeded this way produces the same draws as
    # np.random.seed + np.random.randn, without mutating NumPy's global RNG.
    # NumPy only accepts seeds in [0, 2**32), so out-of-range values such as
    # 2**32 are wrapped instead of raising.
    rng = np.random.RandomState(int(seed) % 2**32)
    b, c, h, w = hint.shape
    noises = [
        torch.from_numpy(rng.randn(b, c, h // 8, w // 8)).float().to(device)
        for _ in range(16 + 1)
    ]

    with torch.inference_mode():
        sample = net(noises, guide, hint, hint_mode_int)
        # Map the first sample from (presumably) [-1, 1] CHW to 8-bit HWC.
        out = sample[0].cpu().numpy().transpose([1, 2, 0])
        out = np.uint8(((out + 1) / 2 * 255).clip(0, 255))

    return Image.fromarray(out).convert('RGB')
|
|
|
|
|
# sRGB (D65) <-> CIE L*a*b* conversion used by colorize2: the ECCV generator
# consumes the L* channel and its two outputs are treated as a*/b*.
_RGB_TO_XYZ = np.array([[0.412453, 0.357580, 0.180423],
                        [0.212671, 0.715160, 0.072169],
                        [0.019334, 0.119193, 0.950227]])
_XYZ_TO_RGB = np.linalg.inv(_RGB_TO_XYZ)
_D65_WHITE = np.array([0.95047, 1.0, 1.08883])


def _rgb_to_lab(rgb):
    """Convert an (..., 3) sRGB array with values in [0, 1] to L*a*b* (D65)."""
    linear = np.where(rgb > 0.04045, ((rgb + 0.055) / 1.055) ** 2.4, rgb / 12.92)
    xyz = linear @ _RGB_TO_XYZ.T / _D65_WHITE
    f = np.where(xyz > 0.008856, np.cbrt(xyz), 7.787 * xyz + 16.0 / 116.0)
    fx, fy, fz = f[..., 0], f[..., 1], f[..., 2]
    return np.stack([116.0 * fy - 16.0, 500.0 * (fx - fy), 200.0 * (fy - fz)], axis=-1)


def _lab_to_rgb(lab):
    """Convert an (..., 3) L*a*b* (D65) array to sRGB clipped to [0, 1]."""
    fy = (lab[..., 0] + 16.0) / 116.0
    fx = fy + lab[..., 1] / 500.0
    fz = np.maximum(fy - lab[..., 2] / 200.0, 0.0)
    f = np.stack([fx, fy, fz], axis=-1)
    xyz = np.where(f > 0.2068966, f ** 3, (f - 16.0 / 116.0) / 7.787) * _D65_WHITE
    # Clip before the gamma curve so np.power never sees negative values.
    linear = np.clip(xyz @ _XYZ_TO_RGB.T, 0.0, None)
    rgb = np.where(linear > 0.0031308, 1.055 * np.power(linear, 1 / 2.4) - 0.055, 12.92 * linear)
    return np.clip(rgb, 0.0, 1.0)


def colorize2(img: Image.Image, model_option: str):
    """Colorize an image without hints using the ECCV generator ``model``.

    The lightness (L*) of a 256x256 bicubic resize is fed to the model. The
    predicted chroma is bilinearly upsampled to the input resolution and
    recombined with the full-resolution L* channel.

    Args:
        img: A PIL image, or a ``{"image": ...}`` dict as produced by the sketch
            canvas (only ``"image"`` is used), or ``None``.
        model_option: Radio label. NOTE(review): only one colorizer
            (``model``) is loaded in this file, so every option uses it.

    Returns:
        The colorized RGB PIL image. Returns ``gr.update(visible=True)`` when
        there is no image.
    """
    if isinstance(img, dict):
        img = img.get("image")
    if img is None:
        return gr.update(visible=True)

    rgb_img = img.convert('RGB')
    full_l = _rgb_to_lab(np.asarray(rgb_img, dtype=np.float64) / 255.0)[..., 0]
    small_rgb = rgb_img.resize((256, 256), Image.Resampling.BICUBIC)
    small_l = _rgb_to_lab(np.asarray(small_rgb, dtype=np.float64) / 255.0)[..., 0]

    # Run on whatever device the model lives on, and make sure BatchNorm
    # uses running statistics rather than single-image batch statistics.
    model.eval()
    model_device = next(model.parameters()).device
    tens_l = torch.from_numpy(small_l).float()[None, None].to(model_device)
    with torch.inference_mode():
        out_ab = model(tens_l)
        out_ab = nn.functional.interpolate(out_ab, size=full_l.shape, mode='bilinear', align_corners=False)
        ab = out_ab[0].cpu().numpy().transpose(1, 2, 0)

    lab = np.concatenate([full_l[..., None], ab], axis=-1)
    out = np.uint8((_lab_to_rgb(lab) * 255).round())
    return Image.fromarray(out).convert('RGB')
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown('''<center><h1>Image Colorization With Hint</h1></center>

<h2>Colorize your images/sketches with hint points.</h2>

<br />

''')
    # Row 1: hint-guided colorization (sketch canvas -> colorize).
    with gr.Row():
        with gr.Column():
            inp = gr.Image(
                source="upload",
                tool="sketch",
                type="pil",
                label="Sketch",
                interactive=True,
                elem_id="sketch-canvas",
            )
            # Hidden store for the padded guide image produced on upload.
            inp_store = gr.Image(type="pil", interactive=False, visible=False)
        with gr.Column():
            # NumPy seeds must be below 2**32, so cap the slider accordingly.
            seed = gr.Slider(1, 2**32 - 1, step=1, label="Seed", interactive=True, randomize=True)
            hint_mode = gr.Radio(["Roughly Hint", "Precisely Hint"], value="Roughly Hint", label="Hint Mode")
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Image(type="pil", label="Output", interactive=False)
    # Row 2: hint-free colorization (plain upload -> colorize2).
    with gr.Row():
        with gr.Column():
            inp2 = gr.Image(
                source="upload",
                type="pil",
                label="Sketch",
                interactive=True,
            )
            inp_store2 = gr.Image(type="pil", interactive=False, visible=False)
        with gr.Column():
            model_option = gr.Radio(["Model 1", "Model 2"], value="Model 1", label="Model")
            btn2 = gr.Button("Run Colorization")
        with gr.Column():
            output2 = gr.Image(type="pil", label="Output2", interactive=False)
    gr.Markdown('''

Upon uploading an image, kindly give color hints at specific points, and then run the model. Average inference time is about 52 seconds.<br />

''')
    gr.Markdown('''Authors: <a href=\"https://www.linkedin.com/in/chakshu-dhannawat/">Chakshu Dhannawat</a>, <a href=\"https://www.linkedin.com/in/navlika-singh-963120204/">Navlika Singh</a>,<a href=\"https://www.linkedin.com/in/akshat-jain-103550201/"> Akshat Jain</a>''')

    # Replace the uploaded image with the padded grayscale canvas and keep a
    # copy as the guide for colorize.
    inp.upload(resize_original, inp, [inp, inp_store])
    btn.click(colorize, [inp, inp_store, seed, hint_mode], output)
    # The second row's own upload feeds the hint-free colorizer.
    btn2.click(colorize2, [inp2, model_option], output2)


if __name__ == "__main__":
    demo.launch()
|
|