|
import torch |
|
import torchvision |
|
from torch import nn |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
|
|
|
|
|
|
class ResBlock(nn.Module): |
|
def __init__(self, channels): |
|
super(ResBlock, self).__init__() |
|
|
|
self.resblock = nn.Sequential( |
|
nn.ReflectionPad2d(1), |
|
nn.Conv2d(channels, channels, kernel_size=3), |
|
nn.InstanceNorm2d(channels, affine=True), |
|
nn.ReLU(), |
|
nn.ReflectionPad2d(1), |
|
nn.Conv2d(channels, channels, kernel_size=3), |
|
nn.InstanceNorm2d(channels, affine=True), |
|
) |
|
|
|
def forward(self, x): |
|
out = self.resblock(x) |
|
return out + x |
|
|
|
|
|
class Upsample2d(nn.Module): |
|
def __init__(self, scale_factor): |
|
super(Upsample2d, self).__init__() |
|
|
|
self.interp = nn.functional.interpolate |
|
self.scale_factor = scale_factor |
|
|
|
def forward(self, x): |
|
x = self.interp(x, scale_factor=self.scale_factor, mode='nearest') |
|
return x |
|
|
|
|
|
class MicroResNet(nn.Module): |
|
def __init__(self): |
|
super(MicroResNet, self).__init__() |
|
|
|
self.downsampler = nn.Sequential( |
|
nn.ReflectionPad2d(4), |
|
nn.Conv2d(3, 8, kernel_size=9, stride=4), |
|
nn.InstanceNorm2d(8, affine=True), |
|
nn.ReLU(), |
|
nn.ReflectionPad2d(1), |
|
nn.Conv2d(8, 16, kernel_size=3, stride=2), |
|
nn.InstanceNorm2d(16, affine=True), |
|
nn.ReLU(), |
|
nn.ReflectionPad2d(1), |
|
nn.Conv2d(16, 32, kernel_size=3, stride=2), |
|
nn.InstanceNorm2d(32, affine=True), |
|
nn.ReLU(), |
|
) |
|
|
|
self.residual = nn.Sequential( |
|
ResBlock(32), |
|
nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32), |
|
ResBlock(64), |
|
) |
|
|
|
self.segmentator = nn.Sequential( |
|
nn.ReflectionPad2d(1), |
|
nn.Conv2d(64, 16, kernel_size=3), |
|
nn.InstanceNorm2d(16, affine=True), |
|
nn.ReLU(), |
|
Upsample2d(scale_factor=2), |
|
nn.ReflectionPad2d(4), |
|
nn.Conv2d(16, 1, kernel_size=9), |
|
nn.Sigmoid() |
|
) |
|
|
|
def forward(self, x): |
|
out = self.downsampler(x) |
|
out = self.residual(out) |
|
out = self.segmentator(out) |
|
return out |
|
|