File size: 4,934 Bytes
1086a9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
# From https://github.com/Koushik0901/Swift-SRGAN/blob/master/swift-srgan/models.py
import torch
from torch import nn
class SeperableConv2d(nn.Module):
    """Depthwise-separable 2D convolution.

    A grouped convolution with one group per input channel is applied first,
    followed by a 1x1 convolution that maps ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the 1x1 convolution.
        kernel_size: kernel size of the grouped (depthwise) convolution.
        stride: stride of the depthwise convolution.
        padding: padding of the depthwise convolution.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True
    ):
        super().__init__()
        # groups == in_channels: every input channel gets its own filter.
        depthwise_cfg = dict(
            kernel_size=kernel_size,
            stride=stride,
            groups=in_channels,
            bias=bias,
            padding=padding,
        )
        self.depthwise = nn.Conv2d(in_channels, in_channels, **depthwise_cfg)
        # The 1x1 convolution is the only place where channels are mixed.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise one."""
        per_channel = self.depthwise(x)
        return self.pointwise(per_channel)
class ConvBlock(nn.Module):
    """Separable convolution, optional batch norm, optional activation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        use_act: when false, ``forward`` skips the activation.
        use_bn: when true, batch norm follows the convolution and the
            convolution is built without bias; otherwise an identity is used.
        discriminator: selects ``LeakyReLU(0.2)`` instead of a per-channel
            ``PReLU`` as the activation.
        **kwargs: forwarded to ``SeperableConv2d`` (kernel_size, stride,
            padding).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        use_act=True,
        use_bn=True,
        discriminator=False,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act
        self.cnn = SeperableConv2d(in_channels, out_channels, **kwargs, bias=not use_bn)

        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()

        # The activation module is created even when use_act is false, so its
        # parameters are always registered on the block.
        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        """Run conv -> norm, then the activation if it is enabled."""
        normed = self.bn(self.cnn(x))
        if self.use_act:
            return self.act(normed)
        return normed
class UpsampleBlock(nn.Module):
    """Upsample by ``scale_factor`` with a separable conv and pixel shuffle.

    Args:
        in_channels: channels of the input; the output has the same number.
        scale_factor: upscale factor passed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        # The conv widens the channels by scale_factor**2 so the pixel
        # shuffle can fold them back into the spatial dimensions.
        widened_channels = in_channels * scale_factor**2
        self.conv = SeperableConv2d(
            in_channels, widened_channels, kernel_size=3, stride=1, padding=1
        )
        # e.g. scale 2: (in_channels * 4, H, W) -> (in_channels, H*2, W*2)
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_channels)

    def forward(self, x):
        """Widen channels, shuffle them into space, then activate."""
        widened = self.conv(x)
        shuffled = self.ps(widened)
        return self.act(shuffled)
class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock``s with a skip connection around them.

    Args:
        in_channels: channel count, unchanged by the block.
    """

    def __init__(self, in_channels):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **conv_cfg)
        # The second block skips its activation before the residual sum.
        self.block2 = ConvBlock(in_channels, in_channels, use_act=False, **conv_cfg)

    def forward(self, x):
        """Return ``block2(block1(x)) + x``."""
        return self.block2(self.block1(x)) + x
class Generator(nn.Module):
    """Swift-SRGAN generator built directly from a checkpoint.

    The hyper-parameters are not passed in. They are read from the tensor
    shapes and key names of the supplied state dict, the network is built to
    match, and the weights are then loaded into it.

    Args:
        state_dict (dict): generator weights keyed by parameter name. If it
            contains a "model" key, the weights are taken from that entry.

    Attributes:
        in_nc (int): number of input image channels.
        out_nc (int): number of output image channels.
        num_filters (int): number of hidden channels.
        num_blocks (int): number of residual blocks.
        scale (int): overall upscale factor, ``2 ** (number of upsample
            blocks)`` since every upsample block doubles the resolution.

    Raises:
        KeyError: if one of the keys used to infer the architecture is
            missing from the state dict.

    Returns:
        torch.Tensor: super resolution image with values in [0, 1].
    """

    def __init__(
        self,
        state_dict,
    ):
        super(Generator, self).__init__()
        self.model_arch = "Swift-SRGAN"
        self.sub_type = "SR"
        self.state = state_dict
        if "model" in self.state:
            self.state = self.state["model"]

        # The depthwise weight is (in_channels, 1, k, k), so dim 0 is the
        # number of input channels; pointwise weights are (out, in, 1, 1).
        self.in_nc: int = self.state["initial.cnn.depthwise.weight"].shape[0]
        self.out_nc: int = self.state["final_conv.pointwise.weight"].shape[0]
        self.num_filters: int = self.state["initial.cnn.pointwise.weight"].shape[0]

        # Keys look like "residual.<idx>...." / "upsampler.<idx>....": count
        # the distinct indices to recover how many blocks each stack has.
        self.num_blocks = len(
            {key.split(".")[1] for key in self.state if key.startswith("residual.")}
        )
        num_upsample_blocks = len(
            {key.split(".")[1] for key in self.state if key.startswith("upsampler.")}
        )
        self.scale: int = 2**num_upsample_blocks

        self.supports_fp16 = True
        self.supports_bfp16 = True
        self.min_size_restriction = None

        self.initial = ConvBlock(
            self.in_nc,
            self.num_filters,
            kernel_size=9,
            stride=1,
            padding=4,
            use_bn=False,
        )
        self.residual = nn.Sequential(
            *[ResidualBlock(self.num_filters) for _ in range(self.num_blocks)]
        )
        self.convblock = ConvBlock(
            self.num_filters,
            self.num_filters,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=False,
        )
        # Build exactly as many 2x blocks as the checkpoint contains. Using
        # scale // 2 here only matches for 2x and 4x; for larger factors it
        # creates extra blocks that never receive weights.
        self.upsampler = nn.Sequential(
            *[
                UpsampleBlock(self.num_filters, scale_factor=2)
                for _ in range(num_upsample_blocks)
            ]
        )
        # Output width comes from the checkpoint rather than being assumed
        # equal to the input width.
        self.final_conv = SeperableConv2d(
            self.num_filters, self.out_nc, kernel_size=9, stride=1, padding=4
        )

        # strict=False tolerates missing/unexpected keys in the checkpoint.
        self.load_state_dict(self.state, strict=False)

    def forward(self, x):
        """Upscale ``x`` and return the result rescaled to [0, 1]."""
        initial = self.initial(x)
        x = self.residual(initial)
        # Long skip connection around the whole residual stack.
        x = self.convblock(x) + initial
        x = self.upsampler(x)
        # tanh yields [-1, 1]; shift and halve to map it onto [0, 1].
        return (torch.tanh(self.final_conv(x)) + 1) / 2
|