Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from . import encoder, decoder | |
class Generator(nn.Module):
    """Image generator built from a content encoder, a style encoder and a decoder.

    The concrete sub-network classes are looked up by name on the ``encoder``
    and ``decoder`` modules, using the type strings stored in ``hp``.
    """

    def __init__(self, hp, in_channels=1):
        """Instantiate the three sub-networks.

        Args:
            hp: configuration object. ``encoder.content.type``,
                ``encoder.style.type`` and ``decoder.type`` name the classes to
                build, and ``hp`` itself is passed to every constructor.
            in_channels: image channel count; used as each encoder's input
                width and as the decoder's output width.
        """
        super().__init__()
        self.hp = hp
        feature_channels = 4 * 64  # 4 x base filter count (ngf = 64)

        # Construction/registration order (content, style, decoder) is kept:
        # nn.Module records submodules in assignment order, which fixes the
        # state_dict / parameters() ordering (and presumably the RNG draws made
        # during parameter initialisation).
        content_cls = getattr(encoder, hp.encoder.content.type)
        self.content_encoder = content_cls(hp, in_channels, feature_channels)

        style_cls = getattr(encoder, hp.encoder.style.type)
        self.style_encoder = style_cls(hp, in_channels, feature_channels)

        # The decoder consumes content and style features stacked on dim 1.
        decoder_cls = getattr(decoder, hp.decoder.type)
        self.decoder = decoder_cls(hp, 2 * feature_channels, in_channels)

    def forward(self, images):
        """Decode an image from a ``(content_images, style_images)`` pair.

        Content images go to the content encoder unchanged. Style images are
        rescaled with ``x * 2 - 1`` first (to [-1, 1], assuming inputs lie in
        [0, 1] -- confirm against the data pipeline).

        Returns:
            Whatever the decoder produces for the channel-wise concatenation
            of the content feature and the spatially tiled style feature.
        """
        content_images, style_images = images
        content_feat = self.content_encoder(content_images)

        # K-shot: the reference style images are stacked along the batch axis.
        style_feat = self.style_encoder(2 * style_images - 1)

        # Unpacking four names deliberately rejects non-4-D content features.
        _, _, height, width = content_feat.shape

        # Tensor.expand only broadcasts singleton dims, so the style feature
        # must be spatially 1x1 (or already height x width); torch.cat also
        # requires its batch size to match the content feature's.
        tiled_style = style_feat.expand(-1, -1, height, width)
        fused = torch.cat((content_feat, tiled_style), dim=1)
        return self.decoder(fused)