Spaces:
Running
Running
import numpy as np | |
import torch | |
import cv2 | |
def dt(a):
    """Run an exact Euclidean distance transform on ``a`` scaled to uint8.

    ``a`` is multiplied by 255 and cast to ``uint8``. ``cv2.distanceTransform``
    then gives, for every pixel, the L2 distance to the nearest pixel whose
    scaled value is 0. Those zero pixels are the sources of the transform.

    NOTE(review): ``a`` presumably holds values in [0, 1]. Values are not
    clipped before the uint8 cast, so larger inputs do not saturate at 255.

    Args:
        a: 2-D array to transform.

    Returns:
        The distance map produced by ``cv2.distanceTransform``, with the same
        height and width as ``a``.
    """
    as_uint8 = (a * 255).astype(np.uint8)
    # DIST_MASK_PRECISE is OpenCV's named value for mask size 0 (exact L2).
    return cv2.distanceTransform(as_uint8, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
def trimap_transform(trimap, L=320, scales=(0.02, 0.08, 0.16)):
    """Encode the first two trimap channels as multi-scale Gaussian "click" maps.

    For each of ``trimap[:, :, 0]`` and ``trimap[:, :, 1]`` (presumably the
    background and foreground channels -- confirm against the caller), ``dt``
    measures the Euclidean distance ``d`` from every pixel to the nearest pixel
    where that channel is ~1. For each ``s`` in ``scales`` the distance is
    mapped to ``exp(-d**2 / (2 * (s * L)**2))``.

    Args:
        trimap: array indexed as ``trimap[:, :, k]`` for k in {0, 1}. Extra
            channels are ignored. Values are assumed to lie in [0, 1], because
            ``dt`` scales by 255 before casting to uint8 -- TODO confirm.
        L: reference length. Each Gaussian's sigma is ``s * L``.
        scales: sigma multipliers, applied in order. The default reproduces
            the previously hard-coded (0.02, 0.08, 0.16).

    Returns:
        ``np.ndarray`` of shape ``(2 * len(scales), H, W)``. All scales for
        channel 0 come first, then all scales for channel 1.
    """
    clicks = []
    for k in range(2):
        channel = trimap[:, :, k]
        if not np.any(channel):
            # Nothing is marked in this channel, so the distance transform
            # would have no zero pixel. OpenCV then reports enormous distances
            # and every Gaussian is effectively 0. Skip the work and emit
            # zeros, float32 to match cv2.distanceTransform's output.
            empty = np.zeros(channel.shape, dtype=np.float32)
            clicks.extend([empty] * len(scales))
            continue
        # Negative squared distance to the nearest pixel where the channel is ~1.
        dt_mask = -dt(1 - channel) ** 2
        for s in scales:
            clicks.append(np.exp(dt_mask / (2 * ((s * L) ** 2))))
    return np.array(clicks)
# ImageNet per-channel normalisation statistics. Channel order is RGB, so
# inputs must be RGB rather than BGR. Both are float32 CPU tensors of shape
# (1, 3, 1, 1), so they broadcast against images whose channel axis is
# dimension -3, e.g. (N, 3, H, W) batches.
imagenet_norm_std = torch.tensor(
    [0.229, 0.224, 0.225], dtype=torch.float32, device="cpu"
).view(1, 3, 1, 1)
imagenet_norm_mean = torch.tensor(
    [0.485, 0.456, 0.406], dtype=torch.float32, device="cpu"
).view(1, 3, 1, 1)
def normalise_image(image, mean=imagenet_norm_mean, std=imagenet_norm_std):
    """Standardise ``image`` channel-wise as ``(image - mean) / std``.

    The default statistics are the module-level tensors of shape (1, 3, 1, 1),
    created on the CPU. When ``image`` is a tensor, any tensor-valued
    ``mean``/``std`` is first moved to ``image``'s device. Inputs on other
    devices (e.g. CUDA) therefore work instead of raising a device-mismatch
    error. Dtypes are not touched, so PyTorch's usual type promotion decides
    the result dtype exactly as before.

    Args:
        image: image data to normalise. With the default statistics it must
            broadcast against shape (1, 3, 1, 1), e.g. an (N, 3, H, W) tensor.
        mean: per-channel mean to subtract.
        std: per-channel standard deviation to divide by.

    Returns:
        The result of the broadcasted ``(image - mean) / std``.
    """
    if isinstance(image, torch.Tensor):
        # Match only the device. Matching dtype too would, for example,
        # truncate the float statistics to 0 for a uint8 image.
        if isinstance(mean, torch.Tensor):
            mean = mean.to(image.device)
        if isinstance(std, torch.Tensor):
            std = std.to(image.device)
    return (image - mean) / std