File size: 3,639 Bytes
4e3dd77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
from argparse import Namespace

from tqdm import tqdm
import numpy as np
from PIL import Image
import torch
from torch.utils.data import DataLoader
import sys

sys.path.append(".")
sys.path.append("..")

from configs import data_configs
from datasets.inference_dataset import InferenceDataset
from utils.common import tensor2im, log_input_image
from options.test_options import TestOptions
from models.psp import pSp


def _get_output_dir(test_opts):
	"""Return (after creating it) the directory where style-mixing grids are written.

	Results go to ``<exp_dir>/style_mixing``; when ``resize_factors`` is set they
	go to a ``downsampling_<factor>`` sub-directory of it instead.

	Raises:
		ValueError: if more than one comma-separated resize factor is given.
	"""
	if test_opts.resize_factors is not None:
		factors = test_opts.resize_factors.split(',')
		# Raise instead of assert so the check survives `python -O`.
		if len(factors) != 1:
			raise ValueError("When running inference, please provide a single downsampling factor!")
		out_dir = os.path.join(test_opts.exp_dir, 'style_mixing',
		                       'downsampling_{}'.format(test_opts.resize_factors))
	else:
		out_dir = os.path.join(test_opts.exp_dir, 'style_mixing')
	os.makedirs(out_dir, exist_ok=True)
	return out_dir


def _load_opts(test_opts):
	"""Merge the options stored in the checkpoint with the test-time options.

	Test-time options override checkpoint options of the same name. Defaults are
	filled in for ``learn_in_w`` / ``output_size`` when the checkpoint lacks them.

	Returns:
		argparse.Namespace with the merged options.
	"""
	# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
	ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
	opts = ckpt['opts']
	opts.update(vars(test_opts))
	opts.setdefault('learn_in_w', False)
	opts.setdefault('output_size', 1024)
	# Returning only the options lets the (possibly large) checkpoint dict be freed.
	return Namespace(**opts)


def _style_mix(net, input_image, latent_mask, opts):
	"""Return one network output per random style vector injected into *input_image*.

	``opts.n_outputs_to_generate`` random 512-d vectors are drawn. Each is first
	passed through ``net`` with ``input_code=True`` to obtain a latent. That
	latent is then injected into the input image's latent code at the layers in
	*latent_mask*, blended with ``opts.mix_alpha``.
	"""
	vecs_to_inject = np.random.randn(opts.n_outputs_to_generate, 512).astype('float32')
	outputs = []
	for vec_to_inject in vecs_to_inject:
		cur_vec = torch.from_numpy(vec_to_inject).unsqueeze(0).to("cuda")
		# Latent code corresponding to the random style vector.
		_, latent_to_inject = net(cur_vec,
		                          input_code=True,
		                          return_latents=True)
		# Output image for the input with the style latent injected.
		mixed = net(input_image.unsqueeze(0).to("cuda").float(),
		            latent_mask=latent_mask,
		            inject_latent=latent_to_inject,
		            alpha=opts.mix_alpha,
		            resize=opts.resize_outputs)
		# Batch of size 1 went in; keep its single element.
		outputs.append(mixed[0])
	return outputs


def _save_mixing_grid(input_tensor, outputs, opts, save_path):
	"""Save the input image followed by every mixed output, side by side, to *save_path*.

	Every tile is resized to 256x256 when ``opts.resize_outputs`` is set, and to
	``opts.output_size`` squared otherwise.
	"""
	tile_size = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)
	tiles = [np.array(log_input_image(input_tensor, opts).resize(tile_size))]
	tiles.extend(np.array(tensor2im(output).resize(tile_size)) for output in outputs)
	Image.fromarray(np.concatenate(tiles, axis=1)).save(save_path)


def run():
	"""Run style-mixing inference and save one image grid per input image.

	The options are parsed and merged with those stored in the checkpoint, and
	the pSp network is built on the GPU. Then, for up to ``opts.n_images`` inputs
	from ``opts.data_path``, the input and ``opts.n_outputs_to_generate``
	style-mixed outputs are written as a single horizontal grid. Each grid is
	named after its input file.

	Raises:
		ValueError: if ``latent_mask`` is missing or several resize factors are given.
	"""
	test_opts = TestOptions().parse()
	mixed_path_results = _get_output_dir(test_opts)
	opts = _load_opts(test_opts)

	net = pSp(opts)
	net.eval()
	net.cuda()

	print('Loading dataset for {}'.format(opts.dataset_type))
	dataset_args = data_configs.DATASETS[opts.dataset_type]
	transforms_dict = dataset_args['transforms'](opts).get_transforms()
	dataset = InferenceDataset(root=opts.data_path,
	                           transform=transforms_dict['transform_inference'],
	                           opts=opts)
	# shuffle=False keeps the running index aligned with dataset.paths.
	# drop_last=False: otherwise the final partial batch would be silently skipped.
	dataloader = DataLoader(dataset,
	                        batch_size=opts.test_batch_size,
	                        shuffle=False,
	                        num_workers=int(opts.test_workers),
	                        drop_last=False)

	latent_mask_str = getattr(opts, 'latent_mask', None)
	if not latent_mask_str:
		raise ValueError("Style mixing requires latent_mask: a comma-separated list of layer indices.")
	latent_mask = [int(idx) for idx in latent_mask_str.split(",")]
	if opts.n_images is None:
		opts.n_images = len(dataset)

	global_i = 0
	with torch.no_grad():
		for input_batch in tqdm(dataloader):
			if global_i >= opts.n_images:
				break
			input_batch = input_batch.cuda()
			for input_image in input_batch:
				# Per-image check so a partially used batch does not exceed n_images.
				if global_i >= opts.n_images:
					break
				outputs = _style_mix(net, input_image, latent_mask, opts)
				input_im_path = dataset.paths[global_i]
				save_path = os.path.join(mixed_path_results, os.path.basename(input_im_path))
				_save_mixing_grid(input_image, outputs, opts, save_path)
				global_i += 1


if __name__ == '__main__':
	run()