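"""Generate TokenCut segmentation maps for a single image.

Pipeline (a summary of the code below): extract DINO ViT features, compute the
TokenCut normalized-cut bipartition, refine the coarse mask with a bilateral
solver, then save mask overlays and the raw masks as .pt files.
"""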
import argparse
import collections
import os
import sys

import numpy as np
import PIL.Image as Image
import torch
from torchvision import transforms
import matplotlib.pyplot as plt  # only needed for the optional eigvec visualization below

# TokenCut keeps its model code and saliency utilities in these sub-folders.
sys.path.append('./TokenCut/model')
sys.path.append('./TokenCut/unsupervised_saliency_detection')

import dino  # DINO ViT backbone wrapper from TokenCut
import object_discovery as tokencut
import utils
import bilateral_solver
import metric

basewidth = 224

def mask_color_compose(org, mask, mask_color=[173, 216, 230]):
    """Overlay a light-blue tint on the foreground pixels selected by `mask`."""
    mask_fg = mask > 0.5
    rgb = np.copy(org)
    rgb[mask_fg] = (rgb[mask_fg] * 0.3 + np.array(mask_color) * 0.7).astype(np.uint8)
    return Image.fromarray(rgb)

# Image transformation applied to all images: ImageNet normalization.
ToTensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225)),
])

def get_tokencut_binary_map(img_pth, backbone, patch_size, tau, resize_size):
    """Run the TokenCut normalized cut on DINO features; return the binary map and eigenvector."""
    I = Image.open(img_pth).convert('RGB')
    I = I.resize(resize_size)
    I_resize, w, h, feat_w, feat_h = utils.resize_pil(I, patch_size)
    feat = backbone(ToTensor(I_resize).unsqueeze(0).cuda())[0]
    seed, bipartition, eigvec = tokencut.ncut(feat, [feat_h, feat_w], [patch_size, patch_size], [h, w], tau)
    return bipartition, eigvec

parser = argparse.ArgumentParser(description='Generate segmentation maps with TokenCut')
parser.add_argument('--img_path', metavar='path', help='path to input image')
parser.add_argument('--out_dir', type=str, help='output directory')
parser.add_argument('--vit-arch', type=str, default='base', choices=['base', 'small'], help='which ViT architecture')
parser.add_argument('--vit-feat', type=str, default='k', choices=['k', 'q', 'v', 'kqv'], help='which features')
parser.add_argument('--patch-size', type=int, default=16, choices=[16, 8], help='patch size')
parser.add_argument('--tau', type=float, default=0.2, help='tau for thresholding the graph')
parser.add_argument('--sigma-spatial', type=float, default=16, help='sigma spatial in the bilateral solver')
parser.add_argument('--sigma-luma', type=float, default=16, help='sigma luma in the bilateral solver')
parser.add_argument('--sigma-chroma', type=float, default=8, help='sigma chroma in the bilateral solver')
parser.add_argument('--dataset', type=str, default=None, choices=['ECSSD', 'DUTS', 'DUT', None], help='which dataset?')
parser.add_argument('--nb-vis', type=int, default=100, help='number of visualizations')
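
# Example invocation (script and file names are hypothetical):
#   python tokencut_single_image.py --img_path ./examples/dog.JPEG --out_dir ./output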
ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag'))
if __name__ == '__main__':
    args = parser.parse_args()

    # The pretrained checkpoint below is ViT-Base/16, so force a matching
    # configuration regardless of the CLI flags.
    url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    feat_dim = 768
    args.patch_size = 16
    args.vit_arch = 'base'

    backbone = dino.ViTFeat(url, feat_dim, args.vit_arch, args.vit_feat, args.patch_size)
    print('Load {} pre-trained feature...'.format(args.vit_arch))
    backbone.eval()
    backbone.cuda()

    # Make sure the output directory exists before writing results.
    os.makedirs(args.out_dir, exist_ok=True)

    with torch.no_grad():
        img_pth = args.img_path
        img = Image.open(img_pth).convert('RGB')

        # Coarse TokenCut mask from the normalized-cut bipartition of the DINO features.
        bipartition, eigvec = get_tokencut_binary_map(img_pth, backbone, args.patch_size, args.tau, img.size)

        # Refine the coarse mask with the bilateral solver.
        output_solver, binary_solver = bilateral_solver.bilateral_solver_output(
            img_pth, bipartition,
            sigma_spatial=args.sigma_spatial,
            sigma_luma=args.sigma_luma,
            sigma_chroma=args.sigma_chroma,
            resize_size=img.size)

        # If the refined mask disagrees with the bipartition, the solver flipped
        # foreground and background, so invert it back.
        mask1 = torch.from_numpy(bipartition).cuda()
        mask2 = torch.from_numpy(binary_solver).cuda()
        if metric.IoU(mask1, mask2) < 0.5:
            binary_solver = ~binary_solver
        # Output the segmented images and raw masks. Use splitext so inputs
        # other than .JPEG also get well-formed output names.
        img_name = os.path.basename(img_pth)
        stem, ext = os.path.splitext(img_name)
        out_lost = os.path.join(args.out_dir, stem + '_tokencut' + ext)
        out_bfs = os.path.join(args.out_dir, stem + '_tokencut_bfs' + ext)
        out_gt = os.path.join(args.out_dir, stem + '_gt' + ext)

        org = Image.open(img_pth).convert('RGB')
        # plt.imsave(fname=out_eigvec, arr=eigvec, cmap='cividis')
        mask_color_compose(org, bipartition).save(out_lost)
        mask_color_compose(org, binary_solver).save(out_bfs)
        # mask_color_compose(org, seg_map).save(out_gt)

        torch.save(bipartition, os.path.join(args.out_dir, stem + '_tokencut.pt'))
        torch.save(binary_solver, os.path.join(args.out_dir, stem + '_tokencut_bfs.pt'))
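        # The saved masks can be reloaded for downstream evaluation, e.g.
        # (hypothetical path): mask = torch.load('./output/dog_tokencut.pt')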