"""Gradio demo: colorize a grayscale image/sketch guided by user-drawn colour hints."""
import sys
from typing import Dict

# Make the bundled, patched copy of gradio importable ahead of any installed one
# (the patch lets the sketch tool upload a colourful mask).
sys.path.insert(0, 'gradio-modified')

import gradio as gr
import numpy as np
from PIL import Image
import torch

# Minimum unallocated GPU memory (bytes) required to run on CUDA: 2**32 = 4 GiB.
_MIN_FREE_GPU_BYTES = 2**32
# Length (pixels) the shorter side of the uploaded image is resized to.
_SHORT_SIDE = 256
# Height and width of the guide image are padded up to a multiple of this.
_PAD_MULTIPLE = 64
# Legacy NumPy seeding (np.random.seed / RandomState) accepts seeds in [0, 2**32 - 1].
_MAX_SEED = 2**32 - 1
# Radio-button label -> integer hint mode passed to the TorchScript model.
_HINT_MODES = {"Roughly Hint": 0, "Precisely Hint": 1}


def _select_device() -> str:
    """Return 'cuda' if GPU 0 has at least _MIN_FREE_GPU_BYTES unallocated, else 'cpu'."""
    if not torch.cuda.is_available():
        return 'cpu'
    total = torch.cuda.get_device_properties(0).total_memory
    allocated = torch.cuda.memory_allocated(0)
    # total - allocated: device memory not occupied by this process's tensors.
    # NOTE(review): this does not see memory used by other processes.
    free = total - allocated
    return 'cpu' if free < _MIN_FREE_GPU_BYTES else 'cuda'


device = _select_device()
# Private TorchScript setting. NOTE(review): presumably set to avoid profiling-
# executor re-specialisation of the loaded graph — confirm.
torch._C._jit_set_bailout_depth(0)
print('Use device:', device)
# One serialized model per device type, e.g. weights/pkp-v1.cpu.jit.pt.
net = torch.jit.load(f'weights/pkp-v1.{device}.jit.pt')


def _pad_to_multiple(arr: np.ndarray, multiple: int, fill: int) -> np.ndarray:
    """Pad a 2-D array with `fill`, split evenly on both sides, so each dim is a multiple of `multiple`."""
    h, w = arr.shape[-2:]
    pad_h = int(np.ceil(h / multiple)) * multiple - h
    pad_w = int(np.ceil(w / multiple)) * multiple - w
    # Any odd leftover pixel goes to the bottom/right edge.
    return np.pad(
        arr,
        ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)),
        mode='constant',
        constant_values=fill,
    )


def resize_original(img: Image.Image):
    """Prepare an uploaded image as the grayscale guide for colorization.

    The image is converted to grayscale and resized (LANCZOS) so its shorter side is
    _SHORT_SIDE pixels. It is then padded with white (255) so both dimensions are
    multiples of _PAD_MULTIPLE.

    Args:
        img: the uploaded PIL image, or the sketch component's dict value (whose
            "image" entry is used), or None.

    Returns:
        A pair for the two wired outputs ``[inp, inp_store]``:
        ``(gr.Image.update(value=<guide RGBA>), <guide RGBA>)``, or ``(None, None)``
        when there is no image.
    """
    if img is None:
        # Two outputs are wired to this handler, so two values must be returned.
        return None, None
    if isinstance(img, dict):
        img = img["image"]
    guide_img = img.convert('L')
    scale = _SHORT_SIDE / min(guide_img.size)
    new_size = tuple(int(round(s * scale)) for s in guide_img.size)
    guide_img = guide_img.resize(new_size, Image.Resampling.LANCZOS)
    guide = _pad_to_multiple(np.asarray(guide_img), _PAD_MULTIPLE, 255)
    guide_rgba = Image.fromarray(guide).convert('RGBA')
    return gr.Image.update(value=guide_rgba), guide_rgba


def colorize(img: Dict[str, Image.Image], guide_img: Image.Image, seed: int, hint_mode: str):
    """Run the colorization model on the stored guide image and the drawn hints.

    Args:
        img: value of the sketch canvas; its "mask" entry holds the hint strokes.
            Pixels with alpha above ~0.99 count as hints.
        guide_img: the padded guide image produced by ``resize_original``.
        seed: seed for the noise inputs. It is reduced modulo 2**32 to fit
            NumPy's legacy seeding range.
        hint_mode: one of the keys of _HINT_MODES.

    Returns:
        The colorized RGB PIL image, or ``gr.update(visible=True)`` (output left
        unchanged) when no image / guide is available.

    Raises:
        ValueError: if ``hint_mode`` is not a known mode.
    """
    if not isinstance(img, dict) or guide_img is None:
        return gr.update(visible=True)
    try:
        hint_mode_int = _HINT_MODES[hint_mode]
    except KeyError:
        raise ValueError(f"Unknown hint mode: {hint_mode!r}") from None

    guide_gray = guide_img.convert('L')
    hint_img = img["mask"].convert('RGBA')  # the patched gradio supplies a colourful mask

    # np.array (copy) instead of np.asarray: PIL arrays are read-only, which torch warns about.
    # Scale pixel values from [0, 255] to [-1, 1].
    guide = torch.from_numpy(np.array(guide_gray))[None, None].float().to(device) / 255.0 * 2 - 1
    hint = torch.from_numpy(np.array(hint_img)).permute(2, 0, 1)[None].float().to(device) / 255.0 * 2 - 1
    # Opaque mask pixels are hints; everything else is marked with the value -2.
    hint_alpha = (hint[:, -1:] > 0.99).float()
    hint = hint[:, :3] * hint_alpha - 2 * (1 - hint_alpha)

    # A local RandomState yields the same draws as np.random.seed + np.random.randn,
    # without mutating global state shared by concurrent requests.
    rng = np.random.RandomState(int(seed) % (_MAX_SEED + 1))
    b, c, h, w = hint.shape
    noise_shape = (b, c, h // 8, w // 8)
    noises = [torch.from_numpy(rng.randn(*noise_shape)).float().to(device) for _ in range(16 + 1)]

    with torch.inference_mode():
        sample = net(noises, guide, hint, hint_mode_int)
    # Map the first output from [-1, 1] back to uint8 HWC pixels.
    out = sample[0].cpu().numpy().transpose([1, 2, 0])
    out = np.uint8(((out + 1) / 2 * 255).clip(0, 255))
    return Image.fromarray(out).convert('RGB')


with gr.Blocks() as demo:
    gr.Markdown('''

Image Colorization With Hint

Colorize your images/sketches with hint points.


''')
    with gr.Row():
        with gr.Column():
            # NOTE: tool="color-sketch" is an alternative; it mixes the upload with the original.
            inp = gr.Image(
                source="upload",
                tool="sketch",
                type="pil",
                label="Sketch",
                interactive=True,
                elem_id="sketch-canvas",
            )
            # Hidden holder for the padded guide image produced on upload.
            inp_store = gr.Image(
                type="pil",
                interactive=False,
            )
            inp_store.visible = False
        with gr.Column():
            seed = gr.Slider(1, _MAX_SEED, step=1, label="Seed", interactive=True, randomize=True)
            hint_mode = gr.Radio(["Roughly Hint", "Precisely Hint"], value="Roughly Hint", label="Hint Mode")
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Image(type="pil", label="Output", interactive=False)
            gr.Markdown(''' Upon uploading an image, kindly give color hints at specific points, and then run the model. Average inference time is about 52 seconds. ''')

    # On upload, replace the canvas with the padded guide and keep a copy in inp_store.
    inp.upload(
        resize_original,
        inp,
        [inp, inp_store],
    )
    btn.click(
        colorize,
        [inp, inp_store, seed, hint_mode],
        output,
    )

if __name__ == "__main__":
    demo.launch()