Spaces:
Running
Running
# Copyright 2020 Erik Härkönen. All rights reserved. | |
# This file is licensed to you under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. You may obtain a copy | |
# of the License at http://www.apache.org/licenses/LICENSE-2.0 | |
# Unless required by applicable law or agreed to in writing, software distributed under | |
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS | |
# OF ANY KIND, either express or implied. See the License for the specific language | |
# governing permissions and limitations under the License. | |
import torch | |
import numpy as np | |
from os import makedirs | |
from PIL import Image | |
import sys | |
from pathlib import Path | |
sys.path.insert(0, str(Path(__file__).parent.parent)) | |
from utils import prettify_name, pad_frames | |
# Apply edit to given latents, return strip of images
def create_strip(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, sigma, layer_start, layer_end, num_frames=5):
    """Apply an edit to each latent and return one strip of images per latent.

    The samples are not centered along the component first, so no means are
    forwarded to the implementation (they are only read when centering).

    Args:
        inst: Instrumented model wrapper used for sampling and layer edits.
        mode: 'latent', 'activation' or 'both'; where the offset is applied.
        layer: Layer name edited in activation mode.
        latents: Latents to edit (converted to a list if needed).
        x_comp, z_comp: Activation-space and latent-space directions.
        act_stdev, lat_stdev: Scales applied to the activation / latent offsets.
        sigma: The sweep runs from -sigma to +sigma.
        layer_start, layer_end: Range of latent layers edited in latent mode.
        num_frames: Number of frames per strip.

    Returns:
        One list of ``num_frames`` images per latent.
    """
    no_mean = None  # act_mean / lat_mean are not needed without centering
    return _create_strip_impl(
        inst, mode, layer, latents,
        x_comp, z_comp,
        act_stdev, lat_stdev,
        no_mean, no_mean,
        sigma,
        layer_start, layer_end,
        num_frames,
        center=False,
    )
# Strip where the sample is centered along the component before manipulation
def create_strip_centered(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames=5):
    """Return one edit strip per latent, centering each sample on the mean first.

    Before the offset is applied, each sample is shifted so that its coordinate
    along the (normalized) component matches that of ``act_mean`` (activation
    mode) or ``lat_mean`` (other modes).

    Args:
        inst: Instrumented model wrapper used for sampling and layer edits.
        mode: 'latent', 'activation' or 'both'; where the offset is applied.
        layer: Layer name edited / retained in activation mode.
        latents: Latents to edit (converted to a list if needed).
        x_comp, z_comp: Activation-space and latent-space directions.
        act_stdev, lat_stdev: Scales applied to the activation / latent offsets.
        act_mean, lat_mean: Means the samples are centered on.
        sigma: The sweep runs from -sigma to +sigma.
        layer_start, layer_end: Range of latent layers edited in latent mode.
        num_frames: Number of frames per strip.

    Returns:
        One list of ``num_frames`` images per latent.
    """
    return _create_strip_impl(
        inst, mode, layer, latents,
        x_comp, z_comp,
        act_stdev, lat_stdev,
        act_mean, lat_mean,
        sigma,
        layer_start, layer_end,
        num_frames,
        center=True,
    )
def _create_strip_impl(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center):
    """Normalize the layer range, then render strips with the better batching axis.

    ``latents`` is converted to a list if it is not one already. A
    ``layer_end`` that is negative or exceeds ``inst.model.get_max_latents()``
    is replaced by that maximum, and ``layer_start`` is clipped into
    ``[0, layer_end]``. If there are more latents than frames, rendering is
    batched over latents; otherwise it is batched over the frames of a strip.

    Returns:
        One list of ``num_frames`` images per latent.
    """
    if not isinstance(latents, list):
        latents = list(latents)

    n_layers = inst.model.get_max_latents()
    # Out-of-range (or negative) end index means "up to the last layer"
    if not (0 <= layer_end <= n_layers):
        layer_end = n_layers
    layer_start = np.clip(layer_start, 0, layer_end)

    # Batch over latents when they outnumber the frames, else over strip frames
    renderer = _create_strip_batch_lat if len(latents) > num_frames else _create_strip_batch_sigma
    return renderer(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev,
                    act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center)
# Batch over frames if there are more frames in strip than latents
def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center):
    """Render strips by batching over sigma values (frames), one latent at a time.

    The sweep ``np.linspace(-sigma, sigma, num_frames)`` is zero-padded to a
    multiple of the batch size ``B = min(num_frames, 5)`` and processed in
    chunks of ``B`` sigmas. For each chunk, every latent is repeated ``B``
    times along dim 0 and rendered with the chunk's offsets; images generated
    for the padding sigmas are dropped.

    Args:
        inst: Instrumented model wrapper; this function calls ``close``,
            ``retain_layer``, ``retained_features``, ``edit_layer`` and uses
            ``inst.model`` (``device``, ``get_max_latents``, ``sample_np``).
        mode: 'latent', 'activation' or 'both'; selects where offsets go.
        layer: Layer name edited (and retained when centering) in activation mode.
        latents: List of latent tensors. Each is repeated along dim 0, so it
            presumably has a leading batch dimension of size 1 (TODO confirm).
        x_comp: Activation-space direction.
        z_comp: Latent-space direction.
        act_stdev, lat_stdev: Multipliers applied to the activation / latent offsets.
        act_mean, lat_mean: Means read only when ``center`` is true.
        sigma: The sweep runs from -sigma to +sigma.
        layer_start, layer_end: Range of per-layer latents edited in latent mode.
        num_frames: Number of frames per strip.
        center: If true, first shift each sample so its coordinate along the
            normalized component equals that of the mean. Activation mode
            centers activations; any other mode (including 'both') centers
            only the latent.

    Returns:
        List with one list of ``num_frames`` images per latent, as produced by
        ``inst.model.sample_np``.
    """
    inst.close()
    batch_frames = [[] for _ in range(len(latents))]
    B = min(num_frames, 5)
    # Pad the number of sigmas up to a multiple of B so every chunk is full
    lep_padded = ((num_frames - 1) // B + 1) * B
    sigma_range = np.linspace(-sigma, sigma, num_frames)
    sigma_range = np.concatenate([sigma_range, np.zeros((lep_padded - num_frames))])
    sigma_range = torch.from_numpy(sigma_range).float().to(inst.model.device)
    # Scale vectors to unit L2 norm along the last dimension (eps avoids /0)
    normalize = lambda v : v / torch.sqrt(torch.sum(v**2, dim=-1, keepdim=True) + 1e-8)
    for i_batch in range(lep_padded // B):
        sigmas = sigma_range[i_batch*B:(i_batch+1)*B]
        for i_lat in range(len(latents)):
            z_single = latents[i_lat]
            # One copy of the latent per sigma in this chunk
            z_batch = z_single.repeat_interleave(B, axis=0)
            zeroing_offset_act = 0
            zeroing_offset_lat = 0
            if center:
                if mode == 'activation':
                    # Center along activation before applying offset
                    # NOTE(review): inst.close() is only called once above, so the
                    # edit set by inst.edit_layer in an earlier iteration is still
                    # active during this forward pass (_create_strip_batch_lat
                    # closes before centering). Confirm the retained features are
                    # unaffected by that edit.
                    inst.retain_layer(layer)
                    _ = inst.model.sample_np(z_single)
                    value = inst.retained_features()[layer].clone()
                    dotp = torch.sum((value - act_mean)*normalize(x_comp), dim=-1, keepdim=True)
                    zeroing_offset_act = normalize(x_comp)*dotp # offset that sets coordinate to zero
                else:
                    # Shift latent to lie on mean along given component
                    dotp = torch.sum((z_single - lat_mean)*normalize(z_comp), dim=-1, keepdim=True)
                    zeroing_offset_lat = dotp*normalize(z_comp)
            with torch.no_grad():
                z = z_batch
                if mode in ['latent', 'both']:
                    # One latent per layer; only layers in [layer_start, layer_end) are edited.
                    # Entries are replaced (not modified in place), so aliasing in the list is harmless.
                    z = [z]*inst.model.get_max_latents()
                    # Reshape sigmas to (B, 1, ..., 1) so they broadcast over the batch dim
                    delta = z_comp * sigmas.reshape([-1] + [1]*(z_comp.ndim - 1)) * lat_stdev
                    for i in range(layer_start, layer_end):
                        z[i] = z[i] - zeroing_offset_lat + delta
                if mode in ['activation', 'both']:
                    comp_batch = x_comp.repeat_interleave(B, axis=0)
                    delta = comp_batch * sigmas.reshape([-1] + [1]*(comp_batch.ndim - 1))
                    inst.edit_layer(layer, offset=delta * act_stdev - zeroing_offset_act)
                img_batch = inst.model.sample_np(z)
                # A single image comes back without a batch dimension; add one
                if img_batch.ndim == 3:
                    img_batch = np.expand_dims(img_batch, axis=0)
                for j, img in enumerate(img_batch):
                    idx = i_batch*B + j
                    # Skip images rendered for the zero-padding sigmas
                    if idx < num_frames:
                        batch_frames[i_lat].append(img)
    return batch_frames
# Batch over latents if there are more latents than frames in strip | |
def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center): | |
n_lat = len(latents) | |
B = min(n_lat, 5) | |
max_lat = inst.model.get_max_latents() | |
if layer_end < 0 or layer_end > max_lat: | |
layer_end = max_lat | |
layer_start = np.clip(layer_start, 0, layer_end) | |
len_padded = ((n_lat - 1) // B + 1) * B | |
batch_frames = [[] for _ in range(n_lat)] | |
for i_batch in range(len_padded // B): | |
zs = latents[i_batch*B:(i_batch+1)*B] | |
if len(zs) == 0: | |
continue | |
z_batch_single = torch.cat(zs, 0) | |
inst.close() # don't retain, remove edits | |
sigma_range = np.linspace(-sigma, sigma, num_frames, dtype=np.float32) | |
normalize = lambda v : v / torch.sqrt(torch.sum(v**2, dim=-1, keepdim=True) + 1e-8) | |
zeroing_offset_act = 0 | |
zeroing_offset_lat = 0 | |
if center: | |
if mode == 'activation': | |
# Center along activation before applying offset | |
inst.retain_layer(layer) | |
_ = inst.model.sample_np(z_batch_single) | |
value = inst.retained_features()[layer].clone() | |
dotp = torch.sum((value - act_mean)*normalize(x_comp), dim=-1, keepdim=True) | |
zeroing_offset_act = normalize(x_comp)*dotp # offset that sets coordinate to zero | |
else: | |
# Shift latent to lie on mean along given component | |
dotp = torch.sum((z_batch_single - lat_mean)*normalize(z_comp), dim=-1, keepdim=True) | |
zeroing_offset_lat = dotp*normalize(z_comp) | |
for i in range(len(sigma_range)): | |
s = sigma_range[i] | |
with torch.no_grad(): | |
z = [z_batch_single]*inst.model.get_max_latents() # one per layer | |
if mode in ['latent', 'both']: | |
delta = z_comp*s*lat_stdev | |
for i in range(layer_start, layer_end): | |
z[i] = z[i] - zeroing_offset_lat + delta | |
if mode in ['activation', 'both']: | |
act_delta = x_comp*s*act_stdev | |
inst.edit_layer(layer, offset=act_delta - zeroing_offset_act) | |
img_batch = inst.model.sample_np(z) | |
if img_batch.ndim == 3: | |
img_batch = np.expand_dims(img_batch, axis=0) | |
for j, img in enumerate(img_batch): | |
img_idx = i_batch*B + j | |
if img_idx < n_lat: | |
batch_frames[img_idx].append(img) | |
return batch_frames | |
def save_frames(title, model_name, rootdir, frames, strip_width=10):
    """Save strips of frames as PNG files under ``{rootdir}/{model_name}/{name}``.

    ``name`` is ``prettify_name(title)``. If at least ``strip_width`` strips
    are given (and ``strip_width >= 2``), the first ``strip_width`` strips are
    laid out in two equal-height columns and saved as ``{name}_all.png``. The
    columns are separated by a 30-pixel gap filled with ones, which is white
    after scaling. For odd ``strip_width`` the last of those strips is left out
    of the grid. Otherwise a message is printed instead of writing the grid.
    Each of the first ``strip_width`` strips is always saved individually, via
    ``pad_frames``, as ``{name}_{index}.png``. Every image is downscaled by the
    factor that brings a single frame's height to at most 512 pixels.

    Args:
        title: Title of the edit; turned into a path component by ``prettify_name``.
        model_name: Sub-directory of ``rootdir`` to write into.
        rootdir: Root output directory; created (with parents) if missing.
        frames: Sequence of strips, each a sequence of image arrays whose first
            axis is the height. Values are scaled by 255, so presumably floats
            in [0, 1] (TODO confirm); out-of-range values are clipped.
        strip_width: Maximum number of strips to save.
    """
    test_name = prettify_name(title)
    outdir = f'{rootdir}/{model_name}/{test_name}'
    makedirs(outdir, exist_ok=True)

    # Limit maximum resolution
    max_H = 512
    real_H = frames[0][0].shape[0]
    ratio = min(1.0, max_H / real_H)

    def _save_image(arr, path):
        # Convert to 8 bit, resize by `ratio` and write to `path`.
        # Clip first so values slightly outside [0, 1] saturate instead of wrapping.
        im = Image.fromarray(np.clip(255 * arr, 0, 255).astype(np.uint8))
        # LANCZOS is the filter formerly aliased as ANTIALIAS (removed in Pillow 10)
        im = im.resize((int(ratio*im.size[0]), int(ratio*im.size[1])), Image.LANCZOS)
        im.save(path)

    # Combine the first strip_width strips into a two-column grid
    half = strip_width // 2
    if half > 0 and len(frames) >= strip_width:
        strips = [np.hstack(strip) for strip in frames[:strip_width]]
        left_col = np.vstack(strips[:half])
        # Same number of strips per column so both columns have equal height
        right_col = np.vstack(strips[half:2*half])
        grid = np.hstack([left_col, np.ones_like(left_col[:, :30]), right_col])
        _save_image(grid, f'{outdir}/{test_name}_all.png')
    else:
        print('Too few strips to create grid, creating just strips!')

    for ex_num, strip in enumerate(frames[:strip_width]):
        _save_image(np.hstack(pad_frames(strip)), f'{outdir}/{test_name}_{ex_num}.png')