"""Generate per-frame foreground masks for NeRSemble multi-view captures
using a pretrained background matting model (MattingBase / MattingRefine)."""

import argparse
import os

import cv2
import imageio
import numpy as np
import torch
from tqdm import tqdm

from model import MattingBase, MattingRefine


def preprocess_nersemble(args, data_folder, camera_ids):
    device = torch.device(args.device)

    # Build the matting model and load the pretrained checkpoint.
    if args.model_type == 'mattingbase':
        model = MattingBase(args.model_backbone)
    elif args.model_type == 'mattingrefine':
        model = MattingRefine(
            args.model_backbone,
            args.model_backbone_scale,
            args.model_refine_mode,
            args.model_refine_sample_pixels,
            args.model_refine_threshold,
            args.model_refine_kernel_size)
    model = model.to(device).eval()
    model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)

    fids = sorted(os.listdir(os.path.join(data_folder, 'images')))
    for cam_id in camera_ids:
        background_path = os.path.join(data_folder, 'background', 'image_%s.jpg' % cam_id)
        for fid in tqdm(fids):
            image_path = os.path.join(data_folder, 'images', fid, 'image_%s.jpg' % cam_id)
            if not os.path.exists(image_path):
                continue

            # Source frame as a normalized NCHW tensor.
            image = imageio.imread(image_path)
            src = (torch.from_numpy(image).float() / 255).permute(2, 0, 1)[None].to(device, non_blocking=True)

            # Use the captured clean-plate background if available; otherwise fall back to a black background.
            if os.path.exists(background_path):
                background = imageio.imread(background_path)
                bgr = (torch.from_numpy(background).float() / 255).permute(2, 0, 1)[None].to(device, non_blocking=True)
            else:
                bgr = src * 0.0

            # Predict the alpha matte.
            with torch.no_grad():
                if args.model_type == 'mattingbase':
                    pha, fgr, err, _ = model(src, bgr)
                elif args.model_type == 'mattingrefine':
                    pha, fgr, _, _, err, ref = model(src, bgr)

            # Save the alpha matte as a 3-channel uint8 mask, plus a 256x256 low-res copy.
            mask = (pha[0].repeat([3, 1, 1]) * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            mask_lowres = cv2.resize(mask, (256, 256))
            mask_path = os.path.join(data_folder, 'images', fid, 'mask_%s.jpg' % cam_id)
            imageio.imsave(mask_path, mask)
            mask_lowres_path = os.path.join(data_folder, 'images', fid, 'mask_lowres_%s.jpg' % cam_id)
            imageio.imsave(mask_lowres_path, mask_lowres)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Inference images')
    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
    parser.add_argument('--model-type', type=str, default='mattingrefine', choices=['mattingbase', 'mattingrefine'])
    parser.add_argument('--model-backbone', type=str, default='resnet101', choices=['resnet101', 'resnet50', 'mobilenetv2'])
    parser.add_argument('--model-backbone-scale', type=float, default=0.25)
    parser.add_argument('--model-checkpoint', type=str, default='assets/pytorch_resnet101.pth')
    parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
    parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
    parser.add_argument('--model-refine-threshold', type=float, default=0.7)
    parser.add_argument('--model-refine-kernel-size', type=int, default=3)
    args = parser.parse_args()

    DATA_SOURCE = '../NeRSemble'
    CAMERA_IDS = ['220700191', '221501007', '222200036', '222200037', '222200038', '222200039',
                  '222200040', '222200041', '222200042', '222200043', '222200044', '222200045',
                  '222200046', '222200047', '222200048', '222200049']
    ids = sorted(os.listdir(DATA_SOURCE))

    # Run mask extraction for every capture folder under DATA_SOURCE.
    for capture_id in ids:
        data_folder = os.path.join(DATA_SOURCE, capture_id)
        preprocess_nersemble(args, data_folder, CAMERA_IDS)