import argparse
import collections
import os
import sys

import numpy as np
import PIL.Image as Image
import cv2
import skimage
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from tqdm import tqdm
from shutil import copyfile

sys.path.append('./TokenCut/model')
sys.path.append('./TokenCut/unsupervised_saliency_detection')
import dino
import object_discovery as tokencut
import utils
import bilateral_solver
import metric

from tokencut_image_dataset import RobustnessDataset

basewidth = 224


def mask_color_compose(org, mask, mask_color=[173, 216, 230]):
    # Tint the foreground region (mask > 0.5) with a light-blue overlay for visualization.
    mask_fg = mask > 0.5
    rgb = np.copy(org)
    rgb[mask_fg] = (rgb[mask_fg] * 0.3 + np.array(mask_color) * 0.7).astype(np.uint8)
    return Image.fromarray(rgb)
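
# Minimal usage sketch for mask_color_compose; the file names below are hypothetical:
#   org = Image.open('example.JPEG').convert('RGB')
#   mask = np.zeros((org.size[1], org.size[0]), dtype=np.float32)  # H x W mask in [0, 1]
#   mask_color_compose(org, mask).save('example_overlay.JPEG')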

# ImageNet mean/std normalization, matching the preprocessing of the DINO backbone.
ToTensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225)),
])

def get_tokencut_binary_map(img_pth, backbone, patch_size, tau, resize_size):
    # Extract DINO patch features and run TokenCut's Ncut to get a coarse binary mask.
    I = Image.open(img_pth).convert('RGB')
    I = I.resize(resize_size)

    I_resize, w, h, feat_w, feat_h = utils.resize_pil(I, patch_size)

    feat = backbone(ToTensor(I_resize).unsqueeze(0).cuda())[0]

    seed, bipartition, eigvec = tokencut.ncut(feat, [feat_h, feat_w], [patch_size, patch_size], [h, w], tau)
    return bipartition, eigvec


parser = argparse.ArgumentParser(description='Generate segmentation maps')

parser.add_argument('--img_path', metavar='path', help='path to the input image')
parser.add_argument('--out_dir', type=str, help='output directory')

parser.add_argument('--vit-arch', type=str, default='base', choices=['base', 'small'], help='which ViT architecture')
parser.add_argument('--vit-feat', type=str, default='k', choices=['k', 'q', 'v', 'kqv'], help='which ViT features')
parser.add_argument('--patch-size', type=int, default=16, choices=[16, 8], help='patch size')
parser.add_argument('--tau', type=float, default=0.2, help='tau for thresholding the graph')

parser.add_argument('--sigma-spatial', type=float, default=16, help='sigma spatial in the bilateral solver')
parser.add_argument('--sigma-luma', type=float, default=16, help='sigma luma in the bilateral solver')
parser.add_argument('--sigma-chroma', type=float, default=8, help='sigma chroma in the bilateral solver')

parser.add_argument('--dataset', type=str, default=None, choices=['ECSSD', 'DUTS', 'DUT', None], help='which dataset?')
parser.add_argument('--nb-vis', type=int, default=100, choices=[1, 200], help='number of visualizations')


ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag'))

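# Example invocation (script name, image path, and output directory are hypothetical):
#   python generate_tokencut_maps.py --img_path ./images/sample.JPEG --out_dir ./outputs \
#       --tau 0.2 --sigma-spatial 16 --sigma-luma 16 --sigma-chroma 8
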
if __name__ == '__main__':

    args = parser.parse_args()

    # The backbone is fixed to DINO ViT-Base/16; any CLI values for --vit-arch and
    # --patch-size are overridden here.
    url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    feat_dim = 768
    args.patch_size = 16
    args.vit_arch = 'base'

    backbone = dino.ViTFeat(url, feat_dim, args.vit_arch, args.vit_feat, args.patch_size)
    msg = 'Load {} pre-trained feature...'.format(args.vit_arch)
    print(msg)
    backbone.eval()
    backbone.cuda()

    with torch.no_grad():

        img_pth = args.img_path
        img = Image.open(img_pth).convert('RGB')

        # Coarse TokenCut mask from DINO features, then refinement with the bilateral solver.
        bipartition, eigvec = get_tokencut_binary_map(img_pth, backbone, args.patch_size, args.tau, img.size)
        output_solver, binary_solver = bilateral_solver.bilateral_solver_output(img_pth, bipartition,
                                                                                sigma_spatial=args.sigma_spatial,
                                                                                sigma_luma=args.sigma_luma,
                                                                                sigma_chroma=args.sigma_chroma,
                                                                                resize_size=img.size)

        # If the refined mask disagrees with the coarse mask (IoU < 0.5), the solver has
        # likely flipped foreground and background, so invert it.
        mask1 = torch.from_numpy(bipartition).cuda()
        mask2 = torch.from_numpy(binary_solver).cuda()
        if metric.IoU(mask1, mask2) < 0.5:
            binary_solver = binary_solver * -1

        img_name = img_pth.split("/")[-1]
        out_name = os.path.join(args.out_dir, img_name)
        out_lost = os.path.join(args.out_dir, img_name.replace('.JPEG', '_tokencut.JPEG'))
        out_bfs = os.path.join(args.out_dir, img_name.replace('.JPEG', '_tokencut_bfs.JPEG'))
        out_gt = os.path.join(args.out_dir, img_name.replace('.JPEG', '_gt.JPEG'))

        org = Image.open(img_pth).convert('RGB')

        # Save visualizations of the coarse (TokenCut) and refined (bilateral solver) masks.
        mask_color_compose(org, bipartition).save(out_lost)
        mask_color_compose(org, binary_solver).save(out_bfs)

        # Save the raw masks for later evaluation.
        torch.save(bipartition, os.path.join(args.out_dir, img_name.replace('.JPEG', '_tokencut.pt')))
        torch.save(binary_solver, os.path.join(args.out_dir, img_name.replace('.JPEG', '_tokencut_bfs.pt')))
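
# The saved .pt files hold the raw mask arrays (torch.save pickles the numpy objects),
# so they can be reloaded for evaluation, e.g. (output path is hypothetical):
#   bipartition = torch.load('./outputs/sample_tokencut.pt')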