from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F


class EncoderBlock(nn.Module):
    """Down-sampling block: zero-padded 5x5 stride-2 convolution, batch norm, leaky ReLU."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=(2, 2))
        self.bn = nn.BatchNorm2d(
            num_features=out_channels,
            track_running_stats=True,
            eps=1e-3,
            momentum=0.01,
        )
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        # Asymmetric (1, 2) padding on each spatial dimension so the 5x5
        # stride-2 convolution halves height and width exactly.
        down = self.conv(F.pad(input, (1, 2, 1, 2), "constant", 0))
        # Return both the raw conv output (kept for the skip connections) and
        # the normalised, activated output (fed to the next encoder block).
        return down, self.relu(self.bn(down))


class DecoderBlock(nn.Module):
    """Up-sampling block: 5x5 stride-2 transposed convolution, ReLU, batch norm, optional dropout."""

    def __init__(
        self, in_channels: int, out_channels: int, dropout_prob: float = 0.0
    ) -> None:
        super().__init__()
        self.tconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=5, stride=2
        )
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(
            out_channels, track_running_stats=True, eps=1e-3, momentum=0.01
        )
        self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0 else nn.Identity()

    def forward(self, input: Tensor) -> Tensor:
        up = self.tconv(input)
        # Crop the transposed-conv output to undo the encoder's (1, 2) padding
        # on each spatial dimension (height first, then width).
        top, bottom, left, right = 1, 2, 1, 2
        up = up[:, :, top:-bottom, left:-right]
        return self.dropout(self.bn(self.relu(up)))
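

# Illustrative usage sketch (not part of the original Space): with the padding
# and cropping above, an EncoderBlock halves the spatial dimensions and a
# DecoderBlock restores them. The 512x128 (freq, time) patch size below is an
# assumed example value, not one taken from the original code.
def _shape_roundtrip_example() -> None:
    enc = EncoderBlock(in_channels=1, out_channels=16)
    dec = DecoderBlock(in_channels=16, out_channels=1)
    x = torch.randn(2, 1, 512, 128)
    pre_act, act = enc(x)  # both outputs: (2, 16, 256, 64)
    y = dec(act)           # cropped back to (2, 1, 512, 128)
    assert pre_act.shape == act.shape
    assert y.shape == x.shape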


class UNet(nn.Module):
    """U-Net that predicts a soft spectrogram mask, multiplied element-wise with the input."""

    def __init__(
        self,
        n_layers: int = 6,
        in_channels: int = 1,
    ) -> None:
        super().__init__()
        # Down-sampling path: channel widths in_channels -> 16 -> 32 -> ... -> 2 ** (n_layers + 3)
        down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
        self.encoder_layers = nn.ModuleList(
            [
                EncoderBlock(in_channels=in_ch, out_channels=out_ch)
                for in_ch, out_ch in zip(down_set[:-1], down_set[1:])
            ]
        )
        # Up-sampling path: the same channel widths mirrored back down to 1
        up_set = [1] + [2 ** (i + 4) for i in range(n_layers)]
        up_set.reverse()
        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlock(
                    # input channels are doubled for the concatenated skip connections,
                    # except for the first (deepest) decoder, which has no skip input
                    in_channels=in_ch if i == 0 else in_ch * 2,
                    out_channels=out_ch,
                    # 50% dropout on the first three decoder layers only
                    dropout_prob=0.5 if i < 3 else 0,
                )
                for i, (in_ch, out_ch) in enumerate(zip(up_set[:-1], up_set[1:]))
            ]
        )
        # Final convolution maps the single-channel decoder output back to the
        # input channel count; kernel_size=4 with dilation=2 and padding=3
        # leaves the spatial size unchanged.
        self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        # Encode, keeping each block's pre-activation output for the skip connections.
        encoder_outputs_pre_act = []
        x = input
        for down in self.encoder_layers:
            conv, x = down(x)
            encoder_outputs_pre_act.append(conv)
        for i, up in enumerate(self.decoder_layers):
            if i == 0:
                # The deepest decoder consumes the last encoder's pre-activation directly.
                x = up(encoder_outputs_pre_act.pop())
            else:
                # Merge the matching skip connection along the channel dimension.
                x = up(torch.cat([encoder_outputs_pre_act.pop(), x], dim=1))
        # Predict a soft mask in [0, 1] and apply it to the input spectrogram.
        mask = self.sigmoid(self.up_final(x))
        return mask * input
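

if __name__ == "__main__":
    # Minimal smoke test (an illustrative addition, not part of the original
    # app): run a random spectrogram-shaped batch through the network and
    # check that the masked output preserves the input shape. The 512x128
    # (freq, time) size is an assumption; any size divisible by 2 ** n_layers
    # works with the default n_layers=6.
    model = UNet(n_layers=6, in_channels=1).eval()
    spec = torch.rand(1, 1, 512, 128)
    with torch.no_grad():
        masked = model(spec)
    assert masked.shape == spec.shape
    print(masked.shape)  # torch.Size([1, 1, 512, 128])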