Spaces:
Runtime error
Runtime error
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |
# | |
# This work is licensed under the Creative Commons Attribution-NonCommercial | |
# 4.0 International License. To view a copy of this license, visit | |
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | |
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | |
import os | |
import time | |
import re | |
import bisect | |
from collections import OrderedDict | |
import numpy as np | |
import tensorflow as tf | |
import scipy.ndimage | |
import scipy.misc | |
import config | |
import misc | |
import tfutil | |
import train | |
import dataset | |
#---------------------------------------------------------------------------- | |
# Generate random images or image grids using a previously trained network. | |
# To run, uncomment the appropriate line in config.py and launch train.py. | |
def generate_fake_images(run_id, snapshot=None, grid_size=[1,1], num_pngs=1, image_shrink=1, png_prefix=None, random_seed=1000, minibatch_size=8): | |
network_pkl = misc.locate_network_pkl(run_id, snapshot) | |
if png_prefix is None: | |
png_prefix = misc.get_id_string_for_network_pkl(network_pkl) + '-' | |
random_state = np.random.RandomState(random_seed) | |
print('Loading network from "%s"...' % network_pkl) | |
G, D, Gs = misc.load_network_pkl(run_id, snapshot) | |
result_subdir = misc.create_result_subdir(config.result_dir, config.desc) | |
for png_idx in range(num_pngs): | |
print('Generating png %d / %d...' % (png_idx, num_pngs)) | |
latents = misc.random_latents(np.prod(grid_size), Gs, random_state=random_state) | |
labels = np.zeros([latents.shape[0], 0], np.float32) | |
images = Gs.run(latents, labels, minibatch_size=minibatch_size, num_gpus=config.num_gpus, out_mul=127.5, out_add=127.5, out_shrink=image_shrink, out_dtype=np.uint8) | |
misc.save_image_grid(images, os.path.join(result_subdir, '%s%06d.png' % (png_prefix, png_idx)), [0,255], grid_size) | |
open(os.path.join(result_subdir, '_done.txt'), 'wt').close() | |
#---------------------------------------------------------------------------- | |
# Generate MP4 video of random interpolations using a previously trained network. | |
# To run, uncomment the appropriate line in config.py and launch train.py. | |
def generate_interpolation_video(run_id, snapshot=None, grid_size=[1,1], image_shrink=1, image_zoom=1, duration_sec=60.0, smoothing_sec=1.0, mp4=None, mp4_fps=30, mp4_codec='libx265', mp4_bitrate='16M', random_seed=1000, minibatch_size=8): | |
network_pkl = misc.locate_network_pkl(run_id, snapshot) | |
if mp4 is None: | |
mp4 = misc.get_id_string_for_network_pkl(network_pkl) + '-lerp.mp4' | |
num_frames = int(np.rint(duration_sec * mp4_fps)) | |
random_state = np.random.RandomState(random_seed) | |
print('Loading network from "%s"...' % network_pkl) | |
G, D, Gs = misc.load_network_pkl(run_id, snapshot) | |
print('Generating latent vectors...') | |
shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:] # [frame, image, channel, component] | |
all_latents = random_state.randn(*shape).astype(np.float32) | |
all_latents = scipy.ndimage.gaussian_filter(all_latents, [smoothing_sec * mp4_fps] + [0] * len(Gs.input_shape), mode='wrap') | |
all_latents /= np.sqrt(np.mean(np.square(all_latents))) | |
# Frame generation func for moviepy. | |
def make_frame(t): | |
frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1)) | |
latents = all_latents[frame_idx] | |
labels = np.zeros([latents.shape[0], 0], np.float32) | |
images = Gs.run(latents, labels, minibatch_size=minibatch_size, num_gpus=config.num_gpus, out_mul=127.5, out_add=127.5, out_shrink=image_shrink, out_dtype=np.uint8) | |
grid = misc.create_image_grid(images, grid_size).transpose(1, 2, 0) # HWC | |
if image_zoom > 1: | |
grid = scipy.ndimage.zoom(grid, [image_zoom, image_zoom, 1], order=0) | |
if grid.shape[2] == 1: | |
grid = grid.repeat(3, 2) # grayscale => RGB | |
return grid | |
# Generate video. | |
import moviepy.editor # pip install moviepy | |
result_subdir = misc.create_result_subdir(config.result_dir, config.desc) | |
moviepy.editor.VideoClip(make_frame, duration=duration_sec).write_videofile(os.path.join(result_subdir, mp4), fps=mp4_fps, codec='libx264', bitrate=mp4_bitrate) | |
open(os.path.join(result_subdir, '_done.txt'), 'wt').close() | |
#---------------------------------------------------------------------------- | |
# Generate MP4 video of training progress for a previous training run. | |
# To run, uncomment the appropriate line in config.py and launch train.py. | |
def generate_training_video(run_id, duration_sec=20.0, time_warp=1.5, mp4=None, mp4_fps=30, mp4_codec='libx265', mp4_bitrate='16M'): | |
src_result_subdir = misc.locate_result_subdir(run_id) | |
if mp4 is None: | |
mp4 = os.path.basename(src_result_subdir) + '-train.mp4' | |
# Parse log. | |
times = [] | |
snaps = [] # [(png, kimg, lod), ...] | |
with open(os.path.join(src_result_subdir, 'log.txt'), 'rt') as log: | |
for line in log: | |
k = re.search(r'kimg ([\d\.]+) ', line) | |
l = re.search(r'lod ([\d\.]+) ', line) | |
t = re.search(r'time (\d+d)? *(\d+h)? *(\d+m)? *(\d+s)? ', line) | |
if k and l and t: | |
k = float(k.group(1)) | |
l = float(l.group(1)) | |
t = [int(t.group(i)[:-1]) if t.group(i) else 0 for i in range(1, 5)] | |
t = t[0] * 24*60*60 + t[1] * 60*60 + t[2] * 60 + t[3] | |
png = os.path.join(src_result_subdir, 'fakes%06d.png' % int(np.floor(k))) | |
if os.path.isfile(png): | |
times.append(t) | |
snaps.append((png, k, l)) | |
assert len(times) | |
# Frame generation func for moviepy. | |
png_cache = [None, None] # [png, img] | |
def make_frame(t): | |
wallclock = ((t / duration_sec) ** time_warp) * times[-1] | |
png, kimg, lod = snaps[max(bisect.bisect(times, wallclock) - 1, 0)] | |
if png_cache[0] == png: | |
img = png_cache[1] | |
else: | |
img = scipy.misc.imread(png) | |
while img.shape[1] > 1920 or img.shape[0] > 1080: | |
img = img.astype(np.float32).reshape(img.shape[0]//2, 2, img.shape[1]//2, 2, -1).mean(axis=(1,3)) | |
png_cache[:] = [png, img] | |
img = misc.draw_text_label(img, 'lod %.2f' % lod, 16, img.shape[0]-4, alignx=0.0, aligny=1.0) | |
img = misc.draw_text_label(img, misc.format_time(int(np.rint(wallclock))), img.shape[1]//2, img.shape[0]-4, alignx=0.5, aligny=1.0) | |
img = misc.draw_text_label(img, '%.0f kimg' % kimg, img.shape[1]-16, img.shape[0]-4, alignx=1.0, aligny=1.0) | |
return img | |
# Generate video. | |
import moviepy.editor # pip install moviepy | |
result_subdir = misc.create_result_subdir(config.result_dir, config.desc) | |
moviepy.editor.VideoClip(make_frame, duration=duration_sec).write_videofile(os.path.join(result_subdir, mp4), fps=mp4_fps, codec='libx264', bitrate=mp4_bitrate) | |
open(os.path.join(result_subdir, '_done.txt'), 'wt').close() | |
#---------------------------------------------------------------------------- | |
# Evaluate one or more metrics for a previous training run. | |
# To run, uncomment one of the appropriate lines in config.py and launch train.py. | |
def evaluate_metrics(run_id, log, metrics, num_images, real_passes, minibatch_size=None): | |
metric_class_names = { | |
'swd': 'metrics.sliced_wasserstein.API', | |
'fid': 'metrics.frechet_inception_distance.API', | |
'is': 'metrics.inception_score.API', | |
'msssim': 'metrics.ms_ssim.API', | |
} | |
# Locate training run and initialize logging. | |
result_subdir = misc.locate_result_subdir(run_id) | |
snapshot_pkls = misc.list_network_pkls(result_subdir, include_final=False) | |
assert len(snapshot_pkls) >= 1 | |
log_file = os.path.join(result_subdir, log) | |
print('Logging output to', log_file) | |
misc.set_output_log_file(log_file) | |
# Initialize dataset and select minibatch size. | |
dataset_obj, mirror_augment = misc.load_dataset_for_previous_run(result_subdir, verbose=True, shuffle_mb=0) | |
if minibatch_size is None: | |
minibatch_size = np.clip(8192 // dataset_obj.shape[1], 4, 256) | |
# Initialize metrics. | |
metric_objs = [] | |
for name in metrics: | |
class_name = metric_class_names.get(name, name) | |
print('Initializing %s...' % class_name) | |
class_def = tfutil.import_obj(class_name) | |
image_shape = [3] + dataset_obj.shape[1:] | |
obj = class_def(num_images=num_images, image_shape=image_shape, image_dtype=np.uint8, minibatch_size=minibatch_size) | |
tfutil.init_uninited_vars() | |
mode = 'warmup' | |
obj.begin(mode) | |
for idx in range(10): | |
obj.feed(mode, np.random.randint(0, 256, size=[minibatch_size]+image_shape, dtype=np.uint8)) | |
obj.end(mode) | |
metric_objs.append(obj) | |
# Print table header. | |
print() | |
print('%-10s%-12s' % ('Snapshot', 'Time_eval'), end='') | |
for obj in metric_objs: | |
for name, fmt in zip(obj.get_metric_names(), obj.get_metric_formatting()): | |
print('%-*s' % (len(fmt % 0), name), end='') | |
print() | |
print('%-10s%-12s' % ('---', '---'), end='') | |
for obj in metric_objs: | |
for fmt in obj.get_metric_formatting(): | |
print('%-*s' % (len(fmt % 0), '---'), end='') | |
print() | |
# Feed in reals. | |
for title, mode in [('Reals', 'reals'), ('Reals2', 'fakes')][:real_passes]: | |
print('%-10s' % title, end='') | |
time_begin = time.time() | |
labels = np.zeros([num_images, dataset_obj.label_size], dtype=np.float32) | |
[obj.begin(mode) for obj in metric_objs] | |
for begin in range(0, num_images, minibatch_size): | |
end = min(begin + minibatch_size, num_images) | |
images, labels[begin:end] = dataset_obj.get_minibatch_np(end - begin) | |
if mirror_augment: | |
images = misc.apply_mirror_augment(images) | |
if images.shape[1] == 1: | |
images = np.tile(images, [1, 3, 1, 1]) # grayscale => RGB | |
[obj.feed(mode, images) for obj in metric_objs] | |
results = [obj.end(mode) for obj in metric_objs] | |
print('%-12s' % misc.format_time(time.time() - time_begin), end='') | |
for obj, vals in zip(metric_objs, results): | |
for val, fmt in zip(vals, obj.get_metric_formatting()): | |
print(fmt % val, end='') | |
print() | |
# Evaluate each network snapshot. | |
for snapshot_idx, snapshot_pkl in enumerate(reversed(snapshot_pkls)): | |
prefix = 'network-snapshot-'; postfix = '.pkl' | |
snapshot_name = os.path.basename(snapshot_pkl) | |
assert snapshot_name.startswith(prefix) and snapshot_name.endswith(postfix) | |
snapshot_kimg = int(snapshot_name[len(prefix) : -len(postfix)]) | |
print('%-10d' % snapshot_kimg, end='') | |
mode ='fakes' | |
[obj.begin(mode) for obj in metric_objs] | |
time_begin = time.time() | |
with tf.Graph().as_default(), tfutil.create_session(config.tf_config).as_default(): | |
G, D, Gs = misc.load_pkl(snapshot_pkl) | |
for begin in range(0, num_images, minibatch_size): | |
end = min(begin + minibatch_size, num_images) | |
latents = misc.random_latents(end - begin, Gs) | |
images = Gs.run(latents, labels[begin:end], num_gpus=config.num_gpus, out_mul=127.5, out_add=127.5, out_dtype=np.uint8) | |
if images.shape[1] == 1: | |
images = np.tile(images, [1, 3, 1, 1]) # grayscale => RGB | |
[obj.feed(mode, images) for obj in metric_objs] | |
results = [obj.end(mode) for obj in metric_objs] | |
print('%-12s' % misc.format_time(time.time() - time_begin), end='') | |
for obj, vals in zip(metric_objs, results): | |
for val, fmt in zip(vals, obj.get_metric_formatting()): | |
print(fmt % val, end='') | |
print() | |
print() | |
#---------------------------------------------------------------------------- | |