Spaces:
Running
on
T4
Running
on
T4
File size: 5,504 Bytes
4e3dd77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
from argparse import Namespace
from tqdm import tqdm
import time
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
import sys
sys.path.append(".")
sys.path.append("..")
from configs import data_configs
from datasets.inference_dataset import InferenceDataset
from utils.common import tensor2im, log_input_image
from options.test_options import TestOptions
from models.psp import pSp
def _make_output_dirs(test_opts):
    """Create and return the (results, coupled) output directories.

    Both directories live under ``test_opts.exp_dir``. When
    ``test_opts.resize_factors`` is set, a ``downsampling_<factors>``
    sub-directory is appended to each of them.
    """
    sub_dirs = []
    if test_opts.resize_factors is not None:
        assert len(
            test_opts.resize_factors.split(',')) == 1, "When running inference, provide a single downsampling factor!"
        sub_dirs.append('downsampling_{}'.format(test_opts.resize_factors))

    out_path_results = os.path.join(test_opts.exp_dir, 'inference_results', *sub_dirs)
    out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled', *sub_dirs)
    os.makedirs(out_path_results, exist_ok=True)
    os.makedirs(out_path_coupled, exist_ok=True)
    return out_path_results, out_path_coupled


def _build_coupled_array(input_im, result, im_path, opts):
    """Return the input and output images resized and concatenated along axis 1.

    When ``opts.resize_factors`` is set, the image file at ``im_path`` is
    prepended as a third panel.
    """
    resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)
    panels = []
    if opts.resize_factors is not None:
        # for super resolution, save the original, down-sampled, and output
        # The context manager closes the file handle once the resized copy exists.
        with Image.open(im_path) as source:
            panels.append(np.array(source.resize(resize_amount)))
        panels.append(np.array(input_im.resize(resize_amount, resample=Image.NEAREST)))
    else:
        # otherwise, save the original and output
        panels.append(np.array(input_im.resize(resize_amount)))
    panels.append(np.array(result.resize(resize_amount)))
    return np.concatenate(panels, axis=1)


def run():
    """Run inference on the dataset at ``opts.data_path`` and save the outputs.

    The parsed test options are merged over the options stored in the
    checkpoint (test options win). Every output image is written to the
    results directory; a coupled input/output image is additionally written
    when ``opts.couple_outputs`` is set or for every 100th image. At most
    ``opts.n_images`` images are saved (all of them when it is ``None``).
    The mean and standard deviation of the per-batch runtime are printed and
    written to ``stats.txt`` in ``opts.exp_dir``.
    """
    test_opts = TestOptions().parse()
    out_path_results, out_path_coupled = _make_output_dirs(test_opts)

    # update test options with options used during training
    # NOTE(security): torch.load unpickles the file; only load checkpoints from a trusted source.
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts.setdefault('learn_in_w', False)
    opts.setdefault('output_size', 1024)
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    print('Loading dataset for {}'.format(opts.dataset_type))
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(root=opts.data_path,
                               transform=transforms_dict['transform_inference'],
                               opts=opts)
    # drop_last=False so a trailing partial batch is still processed instead of
    # silently skipping up to (test_batch_size - 1) images.
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            input_cuda = input_batch.cuda().float()
            # CUDA kernels launch asynchronously; synchronize on both sides of
            # the call so the wall-clock interval covers the actual GPU work.
            torch.cuda.synchronize()
            tic = time.time()
            result_batch = run_on_batch(input_cuda, net, opts)
            torch.cuda.synchronize()
            toc = time.time()
            global_time.append(toc - tic)

        # Use the real batch length (the last batch may be short) and never
        # save more than the requested number of images.
        n_to_save = min(len(result_batch), opts.n_images - global_i)
        for i in range(n_to_save):
            result = tensor2im(result_batch[i])
            # shuffle=False above, so the running index lines up with dataset.paths
            im_path = dataset.paths[global_i]
            im_name = os.path.basename(im_path)

            if opts.couple_outputs or global_i % 100 == 0:
                input_im = log_input_image(input_batch[i], opts)
                res = _build_coupled_array(input_im, result, im_path, opts)
                Image.fromarray(res).save(os.path.join(out_path_coupled, im_name))

            im_save_path = os.path.join(out_path_results, im_name)
            Image.fromarray(np.array(result)).save(im_save_path)

            global_i += 1

    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    if global_time:
        runtime_mean, runtime_std = np.mean(global_time), np.std(global_time)
    else:
        # No batch was processed; report NaN without NumPy's empty-slice warnings.
        runtime_mean = runtime_std = float('nan')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(runtime_mean, runtime_std)
    print(result_str)

    with open(stats_path, 'w') as f:
        f.write(result_str)
def _infer_device(net, fallback):
    """Return the device of ``net``'s first parameter, or ``fallback`` if it exposes none."""
    parameters = getattr(net, 'parameters', None)
    if callable(parameters):
        first_param = next(iter(parameters()), None)
        if first_param is not None:
            return first_param.device
    return fallback


def run_on_batch(inputs, net, opts):
    """Run ``net`` on a batch, optionally mixing in a randomly sampled latent.

    Args:
        inputs: Batch of input images; iterated along its first dimension
            when a latent mask is given.
        net: Callable network. With ``input_code=True, return_latents=True``
            it must return a 2-tuple whose second item is the latent to inject.
        opts: Options object providing ``latent_mask`` (``None`` or a
            comma-separated string of integers), ``resize_outputs`` and,
            when a mask is given, ``mix_alpha``.

    Returns:
        The output of ``net`` for the whole batch. With a latent mask, the
        per-image outputs concatenated along dimension 0.
    """
    if opts.latent_mask is None:
        return net(inputs, randomize_noise=False, resize=opts.resize_outputs)

    latent_mask = [int(index) for index in opts.latent_mask.split(",")]
    # Follow the device the network lives on instead of hard-coding "cuda";
    # this is the same device whenever the network sits on the default GPU.
    device = _infer_device(net, getattr(inputs, 'device', 'cuda'))

    result_batch = []
    for input_image in inputs:
        # get latent vector to inject into our input image
        vec_to_inject = np.random.randn(1, 512).astype('float32')
        _, latent_to_inject = net(torch.from_numpy(vec_to_inject).to(device),
                                  input_code=True,
                                  return_latents=True)
        # get output image with injected style vector
        res = net(input_image.unsqueeze(0).to(device).float(),
                  latent_mask=latent_mask,
                  inject_latent=latent_to_inject,
                  alpha=opts.mix_alpha,
                  resize=opts.resize_outputs)
        result_batch.append(res)
    return torch.cat(result_batch, dim=0)
# Script entry point: start inference only when this file is executed directly.
if __name__ == "__main__":
    run()
|