import argparse
import collections
import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import skimage
import torch
from shutil import copyfile
from torchvision import transforms
from tqdm import tqdm

sys.path.append('./TokenCut/model')
sys.path.append('./TokenCut/unsupervised_saliency_detection')

import dino  # model
import object_discovery as tokencut
import utils
import bilateral_solver
import metric

from tokencut_image_dataset import RobustnessDataset

basewidth = 224


def mask_color_compose(org, mask, mask_color=[173, 216, 230]):
    """Overlay the binary mask on the original image with a light-blue tint."""
    mask_fg = mask > 0.5
    rgb = np.copy(org)
    rgb[mask_fg] = (rgb[mask_fg] * 0.3 + np.array(mask_color) * 0.7).astype(np.uint8)
    return Image.fromarray(rgb)


# Image transformation applied to all images (ImageNet normalization)
ToTensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def get_tokencut_binary_map(img_pth, backbone, patch_size, tau, resize_size):
    """Run TokenCut (Ncut on DINO patch features) and return the binary bipartition and eigenvector."""
    I = Image.open(img_pth).convert('RGB')
    I = I.resize(resize_size)
    I_resize, w, h, feat_w, feat_h = utils.resize_pil(I, patch_size)
    feat = backbone(ToTensor(I_resize).unsqueeze(0).cuda())[0]
    seed, bipartition, eigvec = tokencut.ncut(feat, [feat_h, feat_w], [patch_size, patch_size], [h, w], tau)
    return bipartition, eigvec


parser = argparse.ArgumentParser(description='Generate Seg maps')
parser.add_argument('--img_path', metavar='path', help='path to image')
parser.add_argument('--out_dir', type=str, help='output directory')
parser.add_argument('--vit-arch', type=str, default='base', choices=['base', 'small'], help='which architecture')
parser.add_argument('--vit-feat', type=str, default='k', choices=['k', 'q', 'v', 'kqv'], help='which features')
parser.add_argument('--patch-size', type=int, default=16, choices=[16, 8], help='patch size')
parser.add_argument('--tau', type=float, default=0.2, help='Tau for thresholding graph')
parser.add_argument('--sigma-spatial', type=float, default=16, help='sigma spatial in the bilateral solver')
parser.add_argument('--sigma-luma', type=float, default=16, help='sigma luma in the bilateral solver')
parser.add_argument('--sigma-chroma', type=float, default=8, help='sigma chroma in the bilateral solver')
parser.add_argument('--dataset', type=str, default=None, choices=['ECSSD', 'DUTS', 'DUT', None], help='which dataset?')
parser.add_argument('--nb-vis', type=int, default=100, choices=[1, 200], help='nb of visualization')

ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag'))

if __name__ == '__main__':
    args = parser.parse_args()

    # The DINO ViT-Base/16 checkpoint is hard-coded; --vit-arch and --patch-size are overridden below.
    url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    feat_dim = 768
    args.patch_size = 16
    args.vit_arch = 'base'

    backbone = dino.ViTFeat(url, feat_dim, args.vit_arch, args.vit_feat, args.patch_size)
    msg = 'Load {} pre-trained feature...'.format(args.vit_arch)
    print(msg)
    backbone.eval()
    backbone.cuda()

    with torch.no_grad():
        img_pth = args.img_path
        img = Image.open(img_pth).convert('RGB')

        # Coarse TokenCut mask, then refinement with the bilateral solver
        bipartition, eigvec = get_tokencut_binary_map(img_pth, backbone, args.patch_size, args.tau, img.size)
        output_solver, binary_solver = bilateral_solver.bilateral_solver_output(
            img_pth, bipartition,
            sigma_spatial=args.sigma_spatial,
            sigma_luma=args.sigma_luma,
            sigma_chroma=args.sigma_chroma,
            resize_size=img.size)

        # Invert the refined mask if it disagrees with the TokenCut bipartition
        mask1 = torch.from_numpy(bipartition).cuda()
        mask2 = torch.from_numpy(binary_solver).cuda()
        if metric.IoU(mask1, mask2) < 0.5:
            binary_solver = binary_solver * -1

        # Output the segmented images and the raw masks
        img_name = img_pth.split("/")[-1]
        out_name = os.path.join(args.out_dir, img_name)
        out_lost = os.path.join(args.out_dir, img_name.replace('.JPEG', '_tokencut.JPEG'))
        out_bfs = os.path.join(args.out_dir, img_name.replace('.JPEG', '_tokencut_bfs.JPEG'))
        out_gt = os.path.join(args.out_dir, img_name.replace('.JPEG', '_gt.JPEG'))

        org = Image.open(img_pth).convert('RGB')
        # plt.imsave(fname=out_eigvec, arr=eigvec, cmap='cividis')
        mask_color_compose(org, bipartition).save(out_lost)
        mask_color_compose(org, binary_solver).save(out_bfs)
        # mask_color_compose(org, seg_map).save(out_gt)

        torch.save(bipartition, os.path.join(args.out_dir, img_name.replace('.JPEG', '_tokencut.pt')))
        torch.save(binary_solver, os.path.join(args.out_dir, img_name.replace('.JPEG', '_tokencut_bfs.pt')))
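
# --- Example invocation (sketch) ---
# The script file name and input paths below are placeholders (the actual file
# name is not given in the source); the flags match the argparse definitions above.
#
#   python generate_tokencut_map.py --img_path /path/to/image.JPEG --out_dir ./outputs
#
# Note that --out_dir must already exist; the script does not create it. For an
# input named image.JPEG, four files are written to --out_dir:
#   image_tokencut.JPEG      overlay of the raw TokenCut bipartition
#   image_tokencut_bfs.JPEG  overlay of the bilateral-solver-refined mask
#   image_tokencut.pt        raw bipartition mask (saved with torch.save)
#   image_tokencut_bfs.pt    refined binary mask (saved with torch.save)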