Spaces:
Running
Running
File size: 2,619 Bytes
9390e2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
from torch import nn
class Conv(nn.Module):
    """Bias-free 2-D convolution, optionally followed by BatchNorm2d and ReLU.

    Padding is ``(kernel_size - 1) // 2``, so with an odd ``kernel_size`` and
    ``stride=1`` the spatial size of the input is preserved.

    Args:
        inp_dim: expected number of input channels (checked in ``forward``).
        out_dim: number of output channels.
        kernel_size: square kernel size of the convolution.
        stride: convolution stride.
        bn: if true, apply ``BatchNorm2d(out_dim)`` after the convolution.
        relu: if true, apply ``ReLU`` last (after BN when both are enabled).
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True):
        super().__init__()
        self.inp_dim = inp_dim
        same_pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=same_pad, bias=False)
        # Disabled stages are stored as plain ``None`` attributes, not submodules.
        self.relu = nn.ReLU() if relu else None
        self.bn = nn.BatchNorm2d(out_dim) if bn else None

    def forward(self, x):
        """Apply conv -> (BN) -> (ReLU) to ``x``.

        ``x.shape[1]`` (the channel dimension) must equal ``inp_dim``; a
        mismatch raises AssertionError (only while asserts are enabled).
        """
        assert x.shape[1] == self.inp_dim, "{} {}".format(x.shape[1], self.inp_dim)
        y = self.conv(x)
        for post in (self.bn, self.relu):
            if post is not None:
                y = post(y)
        return y
class Deconv(nn.Module):
    """Bias-free transposed 2-D convolution, optionally followed by BatchNorm2d and ReLU.

    No padding is used, so (per PyTorch's ConvTranspose2d shape formula with
    default padding/dilation/output_padding) an input of spatial size ``H``
    becomes ``(H - 1) * stride + kernel_size``.

    Args:
        inp_dim: expected number of input channels (checked in ``forward``).
        out_dim: number of output channels.
        kernel_size: square kernel size of the transposed convolution.
        stride: transposed-convolution stride (upsampling factor).
        bn: if true, apply ``BatchNorm2d(out_dim)`` after the deconvolution.
        relu: if true, apply ``ReLU`` last (after BN when both are enabled).
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True):
        super().__init__()
        self.inp_dim = inp_dim
        self.deconv = nn.ConvTranspose2d(
            inp_dim, out_dim, kernel_size=kernel_size, stride=stride, bias=False
        )
        # Disabled stages are stored as plain ``None`` attributes, not submodules.
        self.relu = nn.ReLU() if relu else None
        self.bn = nn.BatchNorm2d(out_dim) if bn else None

    def forward(self, x):
        """Apply deconv -> (BN) -> (ReLU) to ``x``.

        ``x.shape[1]`` (the channel dimension) must equal ``inp_dim``; a
        mismatch raises AssertionError (only while asserts are enabled).
        """
        assert x.shape[1] == self.inp_dim, "{} {}".format(x.shape[1], self.inp_dim)
        h = self.deconv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                h = stage(h)
        return h
class Residual(nn.Module):
    """Pre-activation bottleneck residual block.

    Main path::

        BN -> ReLU -> 1x1 Conv (inp_dim -> out_dim // 2)
        -> BN -> ReLU -> kxk Conv (out_dim // 2 -> out_dim // 2)
        -> BN -> ReLU -> 1x1 Conv (out_dim // 2 -> out_dim)

    The main-path result is added to the input. When ``inp_dim != out_dim``,
    the input is first projected with a 1x1 ``Conv``.

    Args:
        inp_dim: number of input channels.
        out_dim: number of output channels; the bottleneck width is ``out_dim // 2``.
        kernel: kernel size of the middle convolution. It must be odd, because
            ``Conv`` pads by ``(kernel - 1) // 2``. An even kernel would shrink H
            and W by one, and the residual addition could never match shapes.

    Raises:
        ValueError: if ``kernel`` is even.
    """

    def __init__(self, inp_dim, out_dim, kernel=3):
        super(Residual, self).__init__()
        # Fail at construction instead of with a shape-mismatch error in forward.
        if kernel % 2 == 0:
            raise ValueError(
                "Residual requires an odd kernel size, got {}".format(kernel)
            )
        mid_dim = out_dim // 2
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = Conv(inp_dim, mid_dim, 1, relu=False)
        self.bn2 = nn.BatchNorm2d(mid_dim)
        self.conv2 = Conv(mid_dim, mid_dim, kernel, relu=False)
        self.bn3 = nn.BatchNorm2d(mid_dim)
        self.conv3 = Conv(mid_dim, out_dim, 1, relu=False)
        # Built unconditionally, even when unused (inp_dim == out_dim): dropping it
        # would change this module's parameter/state_dict keys.
        # NOTE(review): the unused parameters may require
        # find_unused_parameters=True under DistributedDataParallel -- confirm.
        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
        self.need_skip = inp_dim != out_dim

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``.

        ``x`` is expected to be 4-D ``(N, inp_dim, H, W)`` (required by
        ``BatchNorm2d``). The output is ``(N, out_dim, H, W)``: every conv is
        stride 1 with "same" padding for the odd kernels allowed here.
        """
        residual = self.skip_layer(x) if self.need_skip else x
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        # In-place add is safe: ``out`` is a fresh tensor produced by conv3.
        out += residual
        return out
|