Spaces:
Runtime error
Runtime error
File size: 5,112 Bytes
0cb9530 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from fastai import *
from fastai.core import *
from fastai.torch_core import *
from fastai.callbacks import hook_outputs
import torchvision.models as models
class FeatureLoss(nn.Module):
    """Pixel L1 loss plus weighted L1 distances between frozen VGG16-BN features.

    Output hooks are attached to the layer immediately preceding the third,
    fourth and fifth ``nn.MaxPool2d`` of a pretrained ``vgg16_bn`` feature
    extractor (``blocks[2:5]``). Each call stores the individual loss terms in
    ``self.feat_losses`` and exposes them by name in ``self.metrics``, keyed by
    ``self.metric_names``.

    Args:
        layer_wgts: multipliers for the feature terms, paired in order with the
            hooked layers. Defaults to ``[20, 70, 10]``.
        device: device for the VGG extractor. Defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(self, layer_wgts=None, device=None):
        super().__init__()
        # None sentinel instead of a shared mutable list default.
        if layer_wgts is None:
            layer_wgts = [20, 70, 10]
        if device is None:
            # Fall back to CPU so the loss can still be built on GPU-less hosts.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Frozen extractor: evaluation mode, gradients disabled on its weights.
        self.m_feat = models.vgg16_bn(True).features.to(device).eval()
        requires_grad(self.m_feat, False)
        # Index of the layer just before every MaxPool2d in the extractor.
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: the stored activations are meant to stay attached to
        # the autograd graph so the loss can backpropagate into `input`.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        """Run `x` through the extractor and return the hooked activations.

        With ``clone=True`` each activation is copied before being returned
        (used for the target, which is processed before the input).
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the pixel L1 loss plus the weighted feature L1 losses.

        Side effects: sets ``self.feat_losses`` (pixel term first, then one
        weighted term per hooked layer) and ``self.metrics`` (name -> loss).
        """
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        self.feat_losses = [self.base_loss(input, target)]
        self.feat_losses += [
            self.base_loss(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # `hooks` is absent if __init__ failed part-way (e.g. while building the
        # extractor); don't raise a second error from the finalizer.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()
# Refactored code, originally from https://github.com/VinceMarron/style_transfer
class WassFeatureLoss(nn.Module):
    """Pixel + VGG feature L1 loss with an added L2-Wasserstein "style" term.

    Output hooks are attached to the layer immediately preceding the third,
    fourth and fifth ``nn.MaxPool2d`` of a pretrained ``vgg16_bn`` feature
    extractor. For every hooked layer the loss adds:

    * a weighted L1 distance between input and target activations, and
    * a weighted squared 2-Wasserstein distance between Gaussians fitted to the
      per-channel mean/covariance of the input and target activations.

    Each call stores the individual terms in ``self.feat_losses`` and exposes
    them by name in ``self.metrics``, keyed by ``self.metric_names``.

    Args:
        layer_wgts: multipliers for the feature L1 terms. Defaults to ``[5, 15, 2]``.
        wass_wgts: multipliers for the Wasserstein terms.
            Defaults to ``[3.0, 0.7, 0.01]``.
        device: device for the VGG extractor. Defaults to CUDA when available,
            otherwise CPU.
    """

    def __init__(self, layer_wgts=None, wass_wgts=None, device=None):
        super().__init__()
        # None sentinels instead of shared mutable list defaults.
        if layer_wgts is None:
            layer_wgts = [5, 15, 2]
        if wass_wgts is None:
            wass_wgts = [3.0, 0.7, 0.01]
        if device is None:
            # Fall back to CPU so the loss can still be built on GPU-less hosts.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Frozen extractor: evaluation mode, gradients disabled on its weights.
        self.m_feat = models.vgg16_bn(True).features.to(device).eval()
        requires_grad(self.m_feat, False)
        # Index of the layer just before every MaxPool2d in the extractor.
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: the stored activations are meant to stay attached to
        # the autograd graph so the loss can backpropagate into `input`.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.wass_wgts = wass_wgts
        self.metric_names = (
            ['pixel']
            + [f'feat_{i}' for i in range(len(layer_ids))]
            + [f'wass_{i}' for i in range(len(layer_ids))]
        )
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        """Run `x` through the extractor and return the hooked activations.

        With ``clone=True`` each activation is copied before being returned
        (used for the target, which is processed before the input).
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    @staticmethod
    def _eigh(mat):
        """Return ``(eigenvalues, eigenvectors)`` of the symmetric matrix `mat`.

        Uses ``torch.linalg.eigh`` when available and falls back to the
        deprecated ``torch.symeig``, which newer PyTorch releases removed.
        Eigenvalues are in ascending order in both cases.
        """
        eigh = getattr(getattr(torch, 'linalg', None), 'eigh', None)
        if eigh is not None:
            return eigh(mat)
        return torch.symeig(mat, eigenvectors=True)

    def _calc_2_moments(self, tensor):
        """Return per-channel ``(mean, covariance)`` of an activation tensor.

        Channels are read from dim 1. Every other element (batch and spatial
        positions) counts as one sample. ``mean`` has shape ``(1, C)``, and
        ``cov`` is the ``(C, C)`` population covariance (divided by N).
        Returns ``(None, None)`` for an empty tensor.
        """
        chans = tensor.shape[1]
        # Guard before reshaping: inferring a -1 dimension on a 0-element
        # tensor raises, and the covariance would otherwise divide by zero.
        if tensor.numel() == 0:
            return None, None
        # Move channels first before flattening. A plain view(1, C, -1) mixes
        # channels of different batch items whenever the batch size exceeds 1.
        samples = tensor.transpose(0, 1).reshape(chans, -1)
        n = samples.shape[1]
        mu = samples.mean(1)
        centered = samples - mu[:, None]
        cov = torch.mm(centered, centered.t()) / float(n)
        return mu[None], cov

    def _get_style_vals(self, tensor):
        """Return ``(mean, trace(cov), sqrt(cov))`` for a target activation.

        The matrix square root comes from the eigendecomposition of the
        symmetric covariance. Slightly negative eigenvalues caused by round-off
        are clamped to zero. Returns ``(None, None, None)`` for an empty tensor.
        """
        mean, cov = self._calc_2_moments(tensor)
        if mean is None:
            return None, None, None
        eigvals, eigvects = self._eigh(cov)
        eigvals = eigvals.clamp(min=0)
        eigroot_mat = torch.diag(torch.sqrt(eigvals))
        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())
        return mean, eigvals.sum(), root_cov

    def _calc_l2wass_dist(
        self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth
    ):
        """Squared L2-Wasserstein distance between two Gaussians.

        ``|mu_s - mu_t|^2 + tr(C_s) + tr(C_t) - 2 tr((R_s C_t R_s)^{1/2})``,
        where ``R_s = C_s^{1/2}`` is passed in precomputed as `root_cov_stl`.
        """
        tr_cov_synth = self._eigh(cov_synth)[0].clamp(min=0).sum()
        mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
        cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
        # The +1e-8 keeps the sqrt (and its gradient) finite at zero eigenvalues.
        var_overlap = torch.sqrt(
            self._eigh(cov_prod)[0].clamp(min=0) + 1e-8
        ).sum()
        return mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap

    def _single_wass_loss(self, pred, targ):
        """Distance between activation `pred` and precomputed target stats `targ`.

        `targ` is the ``(mean, trace, root_cov)`` triple from `_get_style_vals`.
        """
        mean_test, tr_cov_test, root_cov_test = targ
        mean_synth, cov_synth = self._calc_2_moments(pred)
        return self._calc_l2wass_dist(
            mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth
        )

    def forward(self, input, target):
        """Return pixel L1 + weighted feature L1 + weighted Wasserstein terms.

        Side effects: sets ``self.feat_losses`` (pixel term, then feature
        terms, then Wasserstein terms when computed) and ``self.metrics``.
        """
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        self.feat_losses = [self.base_loss(input, target)]
        self.feat_losses += [
            self.base_loss(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]
        styles = [self._get_style_vals(f) for f in out_feat]
        # Add the Wasserstein terms only when every layer produced statistics.
        # Checking just the first layer would let an empty deeper layer reach
        # _calc_l2wass_dist with None arguments.
        if all(style[0] is not None for style in styles):
            self.feat_losses += [
                self._single_wass_loss(f_pred, f_targ) * w
                for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)
            ]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # `hooks` is absent if __init__ failed part-way (e.g. while building the
        # extractor); don't raise a second error from the finalizer.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()
|