Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
class AdaIN(nn.Module):
    """Adaptive instance normalization driven by a style vector.

    Each channel of ``x`` is normalized by its own mean and standard
    deviation taken over dims 2 and 3, then scaled and shifted by values
    read from ``y``.

    Args:
        eps: Small constant added to the per-channel standard deviation
            before dividing. Without it a channel that is constant over
            dims 2 and 3 has zero std and the division produces NaN.
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        """Normalize ``x`` per channel and apply scale/shift taken from ``y``.

        Args:
            x: Tensor with at least four dims; statistics are reduced over
                dims 2 and 3. Presumably ``(N, C, H, W)`` -- confirm with
                callers.
            y: Tensor whose dim 1 has even size ``ch``. The first ``ch // 2``
                entries along dim 1 are used as the scale and the last
                ``ch // 2`` as the shift. Assumes ``y`` is ``(N, 2 * C)`` so
                it broadcasts against ``x`` -- TODO confirm.

        Returns:
            ``scale * (x - mean) / (std + eps) + shift``.
        """
        ch = y.size(1)
        # Append two trailing singleton dims so the halves broadcast over
        # dims 2 and 3 of ``x``.
        style = y.unsqueeze(-1).unsqueeze(-1)
        sigma, mu = torch.split(style, [ch // 2, ch // 2], dim=1)

        x_mu = x.mean(dim=[2, 3], keepdim=True)
        x_sigma = x.std(dim=[2, 3], keepdim=True)

        # eps guards against 0/0 when a channel has no spatial variation.
        normalized = (x - x_mu) / (x_sigma + self.eps)
        return sigma * normalized + mu