full_gaussian_avatar / GHA /lib /module /SuperResolutionModule.py
pengc02's picture
all
ec9a6bc
import torch
from torch import nn
from einops import rearrange
from GHA.lib.network.Upsampler import Upsampler
class SuperResolutionModule(nn.Module):
    """``nn.Module`` wrapper that owns a single ``Upsampler`` and delegates to it.

    Args:
        cfg: Configuration object. Its ``input_dim``, ``output_dim`` and
            ``network_capacity`` attributes are read once here and passed
            positionally (in that order) to ``Upsampler``.
    """

    def __init__(self, cfg):
        super().__init__()
        # Read the three hyper-parameters in the same order the network
        # constructor expects them.
        in_dim, out_dim, capacity = (
            cfg.input_dim,
            cfg.output_dim,
            cfg.network_capacity,
        )
        self.upsampler = Upsampler(in_dim, out_dim, capacity)

    def forward(self, input):
        """Return the wrapped upsampler's output for ``input``.

        NOTE(review): the expected shape/dtype of ``input`` is defined by
        ``Upsampler``, which is not visible here — confirm against that class.
        The parameter keeps the name ``input`` (shadowing the builtin) so that
        existing keyword callers keep working.
        """
        return self.upsampler(input)