""" | |
Spectral Normalization from https://arxiv.org/abs/1802.05957 | |
""" | |
import torch | |
from torch.nn.functional import normalize | |
class SpectralNorm(object):
    # Invariant before and after each forward call:
    #   u = normalize(W @ v)
    # NB: At initialization, this invariant is not enforced

    _version = 1
    # At version 1:
    #   made `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                'Expected n_power_iterations to be positive, but '
                'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps
    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim,
                *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)
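
    # Illustration (assumed shapes, not part of the original code): a Conv2d
    # weight of shape (C_out, C_in, kH, kW) with the default dim=0 is flattened
    # to a (C_out, C_in * kH * kW) matrix before power iteration.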
    def compute_weight(self, module, do_power_iteration):
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #     1. `DataParallel` doesn't clone storage if the broadcast tensor
        #        is already on the correct device; and it makes sure that the
        #        parallelized module is already on `device[0]`.
        #     2. If the out tensor in the `out=` kwarg has the correct shape, it
        #        will just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #     However, after we update `u` and `v` in-place, we need to **clone**
        #     them before using them to normalize the weight. This is to support
        #     backproping through two forward passes, e.g., the common pattern in
        #     GAN training: loss = D(real) - D(fake). Otherwise, the engine will
        #     complain that variables needed to do backward for the first forward
        #     (i.e., the `u` and `v` vectors) are changed in the second forward.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u),
                                  dim=0, eps=self.eps, out=v)
                    u = normalize(torch.mv(weight_mat, v),
                                  dim=0, eps=self.eps, out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone()
                    v = v.clone()

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight
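
    # Illustrative sketch (not part of the original code): the clone above is
    # what keeps the common GAN pattern below safe, because autograd saves the
    # cloned `u`/`v` from the first forward instead of the buffers that the
    # second forward mutates in-place. Here `D` is assumed to be a
    # spectral-normalized discriminator:
    #
    #     d_real = D(real)                  # 1st forward: power iteration runs
    #     d_fake = D(fake)                  # 2nd forward: buffers updated again
    #     loss = d_real.mean() - d_fake.mean()
    #     loss.backward()                   # backprops through both forwards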
    def remove(self, module):
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name,
                                  torch.nn.Parameter(weight.detach()))
    def __call__(self, module, inputs):
        setattr(module, self.name,
                self.compute_weight(module, do_power_iteration=module.training))
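
    # Note (added for clarity): `__call__` is registered as a forward pre-hook
    # in `apply` below, so in training mode every forward pass runs power
    # iteration and recomputes `weight`; in eval mode `weight` is recomputed
    # from the stored `u` and `v` without updating them.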
    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
        # (the invariant at top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
        v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
                               weight_mat.t(), u.unsqueeze(1)).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a plain
        # attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)

        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(
            SpectralNormLoadStateDictPreHook(fn))
        return fn
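
# For illustration (summarizing the registration above, not original code):
# after `SpectralNorm.apply(m, 'weight', ...)` the module `m` carries
#   m.weight_orig : nn.Parameter, the unnormalized weight seen by the optimizer
#   m.weight      : plain Tensor, recomputed as weight_orig / sigma by the pre-hook
#   m.weight_u, m.weight_v : buffers holding the power-iteration vectors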
# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook(object):
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn):
        self.fn = fn

    # For state_dict with version None (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W / W_orig).
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        fn = self.fn
        version = local_metadata.get('spectral_norm',
                                     {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            with torch.no_grad():
                weight_orig = state_dict[prefix + fn.name + '_orig']
                weight = state_dict.pop(prefix + fn.name)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[prefix + fn.name + '_u']
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[prefix + fn.name + '_v'] = v
# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook(object):
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        if 'spectral_norm' not in local_metadata:
            local_metadata['spectral_norm'] = {}
        key = self.fn.name + '.version'
        if key in local_metadata['spectral_norm']:
            raise RuntimeError(
                "Unexpected key in metadata['spectral_norm']: {}".format(key))
        local_metadata['spectral_norm'][key] = self.fn._version
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with spectral norm :math:`\sigma` of the weight matrix calculated using
    the power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in the power iteration method to get the
    spectral norm. This is implemented via a hook that calculates the spectral
    norm and rescales the weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        if isinstance(module, (torch.nn.ConvTranspose1d,
                               torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
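
# Illustration (assumed weight layout, not part of the original docstring):
# nn.ConvTranspose2d stores its weight as (C_in, C_out, kH, kW), so the output
# dimension is index 1, which is why `dim` defaults to 1 for the
# ConvTranspose{1,2,3}d modules above and to 0 otherwise.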
def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))
def use_spectral_norm(module, use_sn=False):
    # Convenience wrapper: optionally apply spectral norm based on a config flag.
    if use_sn:
        return spectral_norm(module)
    return module
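
# --- Usage sketch (illustrative only, added for clarity; assumes this file is
# run or imported as a module) ---
if __name__ == '__main__':
    import torch.nn as nn

    # Wrap a linear layer: `weight` is reparameterized as weight_orig / sigma
    # by the forward pre-hook on every call.
    layer = use_spectral_norm(nn.Linear(20, 40), use_sn=True)
    y = layer(torch.randn(8, 20))
    print(y.shape, layer.weight_u.shape, layer.weight_v.shape)

    # The reparameterization can be undone, restoring a plain weight Parameter.
    layer = remove_spectral_norm(layer)
    print(type(layer.weight))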