import base64
import json
import os
import re
import time
import uuid
from io import BytesIO
from pathlib import Path

import cv2

# For inpainting
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

import argparse
import io
import multiprocessing
from typing import Union

import torch

# Best-effort: switch off the TorchScript fusers. These are private torch
# hooks that may be missing in some torch builds, so a failure is ignored.
# Only ordinary exceptions are swallowed (not KeyboardInterrupt/SystemExit).
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
except Exception:
    pass

from src.helper import (
    download_model,
    load_img,
    norm_img,
    numpy_to_bytes,
    pad_img_to_modulo,
    resize_max_size,
)

# Let the numeric backends use every available core.
NUM_THREADS = str(multiprocessing.cpu_count())

os.environ["OMP_NUM_THREADS"] = NUM_THREADS
os.environ["OPENBLAS_NUM_THREADS"] = NUM_THREADS
os.environ["MKL_NUM_THREADS"] = NUM_THREADS
os.environ["VECLIB_MAXIMUM_THREADS"] = NUM_THREADS
os.environ["NUMEXPR_NUM_THREADS"] = NUM_THREADS
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]

#BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")

# For Seam-carving
from scipy import ndimage as ndi

SEAM_COLOR = np.array([255, 200, 200])    # seam visualization color (BGR)
SHOULD_DOWNSIZE = True                    # if True, downsize image for faster carving
DOWNSIZE_WIDTH = 500                      # resized image width if SHOULD_DOWNSIZE is True
ENERGY_MASK_CONST = 100000.0              # large energy value for protective masking
MASK_THRESHOLD = 10                       # minimum pixel intensity for binary mask
USE_FORWARD_ENERGY = True                 # if True, use forward energy algorithm

# The TorchScript inpainting model is loaded once, at import time, on the CPU.
device = torch.device("cpu")
model_path = "./assets/big-lama.pt"
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
model.eval()


########################################
# UTILITY CODE
########################################


def visualize(im, boolmask=None, rotate=False):
    """Show ``im`` in an OpenCV window and return the displayed image.

    Args:
        im: Image array; a ``uint8`` copy of it is what gets displayed.
        boolmask: Optional array; every position where it is False is
            painted with ``SEAM_COLOR``.
        rotate: If True, the image is passed through
            ``rotate_image(vis, False)`` before being shown.

    Returns:
        The ``uint8`` image that was displayed.
    """
    vis = im.astype(np.uint8)
    if boolmask is not None:
        vis[np.logical_not(boolmask)] = SEAM_COLOR
    if rotate:
        vis = rotate_image(vis, False)
    cv2.imshow("visualization", vis)
    cv2.waitKey(1)
    return vis


def resize(image, width):
    """Resize ``image`` to ``width`` columns, keeping the aspect ratio.

    The image is converted to ``float32`` before ``cv2.resize`` is called,
    so the result is a float32 array.
    """
    h, w = image.shape[:2]
    dim = (width, int(h * width / float(w)))
    image = image.astype('float32')
    return cv2.resize(image, dim)


def rotate_image(image, clockwise):
    """Rotate ``image`` by a quarter turn using ``np.rot90``.

    ``clockwise=True`` applies one ``rot90`` step, ``clockwise=False`` applies
    three, so the two calls undo each other.
    """
    k = 1 if clockwise else 3
    return np.rot90(image, k)


########################################
# ENERGY FUNCTIONS
########################################


def backward_energy(im):
    """
    Simple gradient magnitude energy map.

    ``im`` must have a third (channel) axis: the squared x/y gradients are
    summed over ``axis=2``.
    """
    xgrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=1, mode='wrap')
    ygrad = ndi.convolve1d(im, np.array([1, 0, -1]), axis=0, mode='wrap')

    grad_mag = np.sqrt(np.sum(xgrad**2, axis=2) + np.sum(ygrad**2, axis=2))

    # vis = visualize(grad_mag)
    # cv2.imwrite("backward_energy_demo.jpg", vis)

    return grad_mag


def forward_energy(im):
    """
    Forward energy algorithm as described in "Improved Seam Carving for Video Retargeting"
    by Rubinstein, Shamir, Avidan.

    Vectorized code adapted from
    https://github.com/axu2/improved-seam-carving.

    The image is converted to grayscale (``cv2.COLOR_BGR2GRAY``) first; the
    returned energy map has shape ``(h, w)``.
    """
    h, w = im.shape[:2]
    im = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)

    energy = np.zeros((h, w))
    m = np.zeros((h, w))

    # Neighbours of every pixel: up, left and right (with wrap-around).
    U = np.roll(im, 1, axis=0)
    L = np.roll(im, 1, axis=1)
    R = np.roll(im, -1, axis=1)

    # Cost of the new edges created by removing a pixel, depending on whether
    # the seam arrives from above, from the upper-left or from the upper-right.
    cU = np.abs(R - L)
    cL = np.abs(U - L) + cU
    cR = np.abs(U - R) + cU

    for i in range(1, h):
        mU = m[i-1]
        mL = np.roll(mU, 1)
        mR = np.roll(mU, -1)

        mULR = np.array([mU, mL, mR])
        cULR = np.array([cU[i], cL[i], cR[i]])
        mULR += cULR

        argmins = np.argmin(mULR, axis=0)
        m[i] = np.choose(argmins, mULR)
        energy[i] = np.choose(argmins, cULR)

    # vis = visualize(energy)
    # cv2.imwrite("forward_energy_demo.jpg", vis)

    return energy


########################################
# SEAM HELPER FUNCTIONS
########################################


def add_seam(im, seam_idx):
    """
    Add a vertical seam to a 3-channel color image at the indices provided
    by averaging the pixels values to the left and right of the seam.

    Code adapted from https://github.com/vivianhylee/seam-carving.

    Args:
        im: ``(h, w, c)`` image; only the first three channels are used.
        seam_idx: Column index of the seam for every row.

    Returns:
        A new ``(h, w + 1, 3)`` float array.
    """
    h, w = im.shape[:2]
    output = np.zeros((h, w + 1, 3))
    for row in range(h):
        col = seam_idx[row]
        line = im[row, :, :3]
        if col == 0:
            # No left neighbour: keep pixel 0, then insert the average of
            # pixels 0 and 1 right after it. The tail must start at col + 2,
            # otherwise it overwrites the freshly inserted pixel.
            p = np.mean(line[col: col + 2], axis=0)
            output[row, col] = line[col]
            output[row, col + 1] = p
            output[row, col + 2:] = line[col + 1:]
        else:
            # Insert the average of the seam pixel and its left neighbour
            # at the seam position; everything from the seam on shifts right.
            p = np.mean(line[col - 1: col + 1], axis=0)
            output[row, : col] = line[: col]
            output[row, col] = p
            output[row, col + 1:] = line[col:]

    return output


def add_seam_grayscale(im, seam_idx):
    """
    Add a vertical seam to a grayscale image at the indices provided
    by averaging the pixels values to the left and right of the seam.

    Args:
        im: ``(h, w)`` image.
        seam_idx: Column index of the seam for every row.

    Returns:
        A new ``(h, w + 1)`` float array.
    """
    h, w = im.shape[:2]
    output = np.zeros((h, w + 1))
    for row in range(h):
        col = seam_idx[row]
        if col == 0:
            # Same layout as add_seam: pixel 0, inserted average, then the
            # remaining pixels starting at col + 2 so the average survives.
            p = np.mean(im[row, col: col + 2])
            output[row, col] = im[row, col]
            output[row, col + 1] = p
            output[row, col + 2:] = im[row, col + 1:]
        else:
            p = np.mean(im[row, col - 1: col + 1])
            output[row, : col] = im[row, : col]
            output[row, col] = p
            output[row, col + 1:] = im[row, col:]

    return output


def remove_seam(im, boolmask):
    """Drop one pixel per row from a 3-channel image.

    ``boolmask`` is an ``(h, w)`` boolean array with exactly one False entry
    per row (the seam); the result has shape ``(h, w - 1, 3)``.
    """
    h, w = im.shape[:2]
    boolmask3c = np.stack([boolmask] * 3, axis=2)
    return im[boolmask3c].reshape((h, w - 1, 3))


def remove_seam_grayscale(im, boolmask):
    """Drop one pixel per row from a 2-D image; result is ``(h, w - 1)``."""
    h, w = im.shape[:2]
    return im[boolmask].reshape((h, w - 1))


def get_minimum_seam(im, mask=None, remove_mask=None):
    """
    DP algorithm for finding the seam of minimum energy. Code adapted from
    https://karthikkaranth.me/blog/implementing-seam-carving-with-python/

    Args:
        im: Image whose energy map is computed (forward or backward energy,
            depending on ``USE_FORWARD_ENERGY``).
        mask: Optional protective mask; pixels above ``MASK_THRESHOLD`` get a
            large positive energy.
        remove_mask: Optional removal mask; pixels above ``MASK_THRESHOLD``
            get a large negative energy.

    Returns:
        ``(seam_idx, boolmask)``: the seam column per row as an array, and an
        ``(h, w)`` boolean array that is False on the seam.
    """
    h, w = im.shape[:2]
    energyfn = forward_energy if USE_FORWARD_ENERGY else backward_energy
    M = energyfn(im)

    if mask is not None:
        M[mask > MASK_THRESHOLD] = ENERGY_MASK_CONST

    # give removal mask priority over protective mask by using larger negative value
    if remove_mask is not None:
        M[remove_mask > MASK_THRESHOLD] = -ENERGY_MASK_CONST * 100

    seam_idx, boolmask = compute_shortest_path(M, im, h, w)

    return np.array(seam_idx), boolmask


def compute_shortest_path(M, im, h, w):
    """Find the minimum-cost vertical path through the energy map ``M``.

    ``M`` is updated in place: after the call it holds the accumulated path
    costs. ``im`` is not used; it is kept for signature compatibility.

    Args:
        M: ``(h, w)`` energy map.
        im: Unused.
        h: Number of rows of ``M``.
        w: Number of columns of ``M``.

    Returns:
        ``(seam_idx, boolmask)``: list with the seam column for each row
        (top to bottom) and an ``(h, w)`` boolean array that is False on the
        seam and True elsewhere.
    """
    backtrack = np.zeros_like(M, dtype=np.int_)

    # populate DP matrix, one whole row at a time. Each row only depends on
    # the previous one, so the per-pixel loop can be replaced by array ops.
    cols = np.arange(w)
    # Rows of `candidates` are the upper-left, upper and upper-right costs.
    # The two corners that fall outside the image stay at +inf forever.
    candidates = np.full((3, w), np.inf)
    for i in range(1, h):
        prev = M[i - 1]
        candidates[0, 1:] = prev[:-1]
        candidates[1] = prev
        candidates[2, :-1] = prev[1:]

        # argmin returns the first minimum, i.e. ties prefer left, then
        # centre, then right, the same order as scanning M[i-1, j-1:j+2].
        parent = cols + np.argmin(candidates, axis=0) - 1
        # Guard against an out-of-image parent if a row holds only +inf.
        parent = np.clip(parent, 0, w - 1)

        backtrack[i] = parent
        M[i] += prev[parent]

    # backtrack to find path
    seam_idx = []
    boolmask = np.ones((h, w), dtype=np.bool_)
    j = np.argmin(M[-1])
    for i in range(h-1, -1, -1):
        boolmask[i, j] = False
        seam_idx.append(j)
        j = backtrack[i, j]

    seam_idx.reverse()
    return seam_idx, boolmask


########################################
# MAIN ALGORITHM
########################################


def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
    """Remove ``num_remove`` minimum-energy vertical seams from ``im``.

    The optional protective ``mask`` is carved along with the image so the
    two stay aligned.

    Returns:
        ``(im, mask)`` after carving.
    """
    for _ in range(num_remove):
        seam_idx, boolmask = get_minimum_seam(im, mask)
        if vis:
            visualize(im, boolmask, rotate=rot)
        im = remove_seam(im, boolmask)
        if mask is not None:
            mask = remove_seam_grayscale(mask, boolmask)
    return im, mask


def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
    """Widen ``im`` by ``num_add`` columns by duplicating low-energy seams.

    The seams are first found by removing them from a scratch copy, then
    inserted into the real image in the order they were found.

    Returns:
        ``(im, mask)`` after insertion.
    """
    seams_record = []
    temp_im = im.copy()
    temp_mask = mask.copy() if mask is not None else None

    for _ in range(num_add):
        seam_idx, boolmask = get_minimum_seam(temp_im, temp_mask)
        if vis:
            visualize(temp_im, boolmask, rotate=rot)

        seams_record.append(seam_idx)
        temp_im = remove_seam(temp_im, boolmask)
        if temp_mask is not None:
            temp_mask = remove_seam_grayscale(temp_mask, boolmask)

    # Reverse so that pop() hands the seams back in discovery order.
    seams_record.reverse()

    for _ in range(num_add):
        seam = seams_record.pop()
        im = add_seam(im, seam)
        if vis:
            visualize(im, rotate=rot)
        if mask is not None:
            mask = add_seam_grayscale(mask, seam)

        # update the remaining seam indices
        for remaining_seam in seams_record:
            remaining_seam[np.where(remaining_seam >= seam)] += 2

    return im, mask


########################################
# MAIN DRIVER FUNCTIONS
########################################


def seam_carve(im, dy, dx, mask=None, vis=False):
    """Change the size of ``im`` by ``dy`` rows and ``dx`` columns.

    Negative values remove seams, positive values insert them. Height
    changes are done by rotating the image, carving vertical seams and
    rotating back.

    Returns:
        The resized image as a float64 array.
    """
    im = im.astype(np.float64)
    h, w = im.shape[:2]
    assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w

    if mask is not None:
        mask = mask.astype(np.float64)

    output = im

    if dx < 0:
        output, mask = seams_removal(output, -dx, mask, vis)

    elif dx > 0:
        output, mask = seams_insertion(output, dx, mask, vis)

    if dy < 0:
        output = rotate_image(output, True)
        if mask is not None:
            mask = rotate_image(mask, True)
        output, mask = seams_removal(output, -dy, mask, vis, rot=True)
        output = rotate_image(output, False)

    elif dy > 0:
        output = rotate_image(output, True)
        if mask is not None:
            mask = rotate_image(mask, True)
        output, mask = seams_insertion(output, dy, mask, vis, rot=True)
        output = rotate_image(output, False)

    return output


def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
    """Remove the region marked by ``rmask`` and restore the original size.

    Seams are removed until no pixel of ``rmask`` exceeds ``MASK_THRESHOLD``;
    afterwards the same number of seams is inserted again.

    Args:
        im: Input image.
        rmask: Removal mask (pixels above ``MASK_THRESHOLD`` are removed).
        mask: Optional protective mask.
        vis: If True, every step is shown via ``visualize``.
        horizontal_removal: If True, the image is rotated first so that the
            seams run horizontally in the original orientation.

    Returns:
        The processed image as a float64 array.
    """
    im = im.astype(np.float64)
    rmask = rmask.astype(np.float64)
    if mask is not None:
        mask = mask.astype(np.float64)
    output = im

    h, w = im.shape[:2]

    if horizontal_removal:
        output = rotate_image(output, True)
        rmask = rotate_image(rmask, True)
        if mask is not None:
            mask = rotate_image(mask, True)

    while np.any(rmask > MASK_THRESHOLD):
        seam_idx, boolmask = get_minimum_seam(output, mask, rmask)
        if vis:
            visualize(output, boolmask, rotate=horizontal_removal)
        output = remove_seam(output, boolmask)
        rmask = remove_seam_grayscale(rmask, boolmask)
        if mask is not None:
            mask = remove_seam_grayscale(mask, boolmask)

    # After the rotation the carved axis is the original height.
    num_add = (h if horizontal_removal else w) - output.shape[1]
    output, mask = seams_insertion(output, num_add, mask, vis, rot=horizontal_removal)
    if horizontal_removal:
        output = rotate_image(output, False)

    return output


def s_image(im, mask, vs, hs, mode="resize"):
    """Entry point for seam carving on an RGBA image and an RGBA mask.

    Args:
        im: RGBA image; it is converted with ``cv2.COLOR_RGBA2RGB``.
        mask: RGBA array; the working mask is ``255 - alpha channel``.
        vs: Column delta, passed to ``seam_carve`` as ``dx`` ("resize" mode).
        hs: Row delta, passed to ``seam_carve`` as ``dy`` ("resize" mode).
        mode: ``"resize"`` to retarget the image, ``"remove"`` to remove the
            masked object.

    Returns:
        The processed image. When the input is wider than ``DOWNSIZE_WIDTH``
        and ``SHOULD_DOWNSIZE`` is set, processing happens on the downsized
        image.

    Raises:
        ValueError: If ``mode`` is neither ``"resize"`` nor ``"remove"``.
    """
    im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
    mask = 255 - mask[:, :, 3]
    h, w = im.shape[:2]

    if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
        im = resize(im, width=DOWNSIZE_WIDTH)
        mask = resize(mask, width=DOWNSIZE_WIDTH)

    # image resize mode
    if mode == "resize":
        dy = hs  # reverse
        dx = vs  # reverse
        assert dy is not None and dx is not None
        output = seam_carve(im, dy, dx, mask, False)

    # object removal mode
    elif mode == "remove":
        # The mask is used as the removal mask, with no protective mask,
        # and seams are removed horizontally.
        output = object_removal(im, mask, None, False, True)

    else:
        raise ValueError(
            f"unknown mode {mode!r}; expected 'resize' or 'remove'"
        )

    return output


##### Inpainting helper code

def run(image, mask):
    """
    image: [C, H, W]
    mask: [1, H, W]
    return: BGR IMAGE

    Both inputs are padded with ``pad_img_to_modulo(..., mod=8)`` before the
    model call, and the result is cropped back to the original height and
    width. The model output is scaled by 255, clipped to uint8 and passed
    through ``cv2.COLOR_BGR2RGB`` before being returned.
    """
    origin_height, origin_width = image.shape[1:]
    image = pad_img_to_modulo(image, mod=8)
    mask = pad_img_to_modulo(mask, mod=8)

    # Binarize: every non-zero mask value becomes 1.
    mask = (mask > 0) * 1
    image = torch.from_numpy(image).unsqueeze(0).to(device)
    mask = torch.from_numpy(mask).unsqueeze(0).to(device)

    start = time.time()
    with torch.no_grad():
        inpainted_image = model(image, mask)

    print(f"process time: {(time.time() - start)*1000}ms")
    cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
    cur_res = cur_res[0:origin_height, 0:origin_width, :]
    cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
    cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
    return cur_res


def get_args_parser():
    """Parse the command line (``--port``, ``--device``, ``--debug``).

    Returns:
        The parsed ``argparse.Namespace`` (not the parser itself).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", default=8080, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--debug", action="store_true")
    return parser.parse_args()


def process_inpaint(image, mask):
    """Inpaint the masked region of an RGBA image.

    Args:
        image: RGBA image; converted with ``cv2.COLOR_RGBA2RGB``.
        mask: RGBA array; the inpainting mask is ``255 - alpha channel``.

    Returns:
        The result of ``run`` passed through ``cv2.COLOR_BGR2RGB`` once more.
    """
    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    original_shape = image.shape
    interpolation = cv2.INTER_CUBIC

    # The size limit is pinned to the longest side of the input.
    #size_limit: Union[int, str] = request.form.get("sizeLimit", "1080")
    #if size_limit == "Original":
    size_limit = max(image.shape)
    #else:
    #    size_limit = int(size_limit)

    print(f"Origin image shape: {original_shape}")
    image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
    print(f"Resized image shape: {image.shape}")
    image = norm_img(image)

    mask = 255 - mask[:, :, 3]
    mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
    mask = norm_img(mask)

    res_np_img = run(image, mask)

    # run() already swapped channels once; this swaps them again.
    return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)