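# NOTE: the following lines appear to be residue copied from a file-hosting
# web page (uploader label, commit message, commit hash, view links, file size).
# They are kept as comments so the module stays importable.
# ArchitSharma's picture
# Upload 16 files
# c716076
# raw
# history blame
# 1.27 kB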
from fastai.core import *
from fastai.torch_core import *
from fastai.vision import *
from fastai.vision.gan import AdaptiveLoss, accuracy_thresh_expand
# Keyword arguments applied to every critic convolution built through `_conv`:
# `leaky=0.2` together with spectral normalisation (`NormType.Spectral`).
_conv_args = {"leaky": 0.2, "norm_type": NormType.Spectral}
def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
    """Build a fastai ``conv_layer`` using the critic's shared keyword defaults.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        ks: kernel size.
        stride: convolution stride.
        **kwargs: extra keywords forwarded to ``conv_layer``. A keyword that is
            also present in ``_conv_args`` (e.g. ``leaky`` or ``norm_type``)
            overrides the shared default for this layer only.

    Returns:
        Whatever ``conv_layer`` returns for these arguments.
    """
    # Merge first so caller kwargs take precedence. Unpacking both mappings
    # straight into the call would raise "got multiple values for keyword
    # argument" as soon as a caller tries to override a shared default.
    layer_kwargs = {**_conv_args, **kwargs}
    return conv_layer(ni, nf, ks=ks, stride=stride, **layer_kwargs)
def custom_gan_critic(
    n_channels: int = 3, nf: int = 256, n_blocks: int = 3, p: float = 0.15
) -> nn.Sequential:
    """Critic to train a `GAN`.

    Layer layout (every conv is built through ``_conv``):
      * stem: 4x4 stride-2 conv ``n_channels -> nf``, then ``Dropout2d(p / 2)``;
      * ``n_blocks`` blocks, each: 3x3 stride-1 conv at constant width,
        ``Dropout2d(p)``, then a 4x4 stride-2 conv that doubles the width
        (with ``self_attention`` enabled on the first block only);
      * head: 3x3 stride-1 conv, then a 4x4 unpadded conv down to a single
        channel with no bias and no activation, followed by ``Flatten()``.

    Args:
        n_channels: channels of the input images.
        nf: width produced by the stem; doubled by every block.
        n_blocks: number of width-doubling, stride-2 blocks.
        p: dropout probability inside the blocks (the stem uses ``p / 2``).

    Returns:
        An ``nn.Sequential`` whose output is the flattened single-channel map
        of raw (non-activated) scores.
    """
    layers = [_conv(n_channels, nf, ks=4, stride=2), nn.Dropout2d(p / 2)]
    for i in range(n_blocks):
        layers += [
            _conv(nf, nf, ks=3, stride=1),
            nn.Dropout2d(p),
            # Self-attention only in the first block, i.e. the one right after the stem.
            _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
        ]
        # The next block's input width is this block's output width.
        nf *= 2
    layers += [
        _conv(nf, nf, ks=3, stride=1),
        # No activation: the output is meant to be fed to a with-logits loss.
        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        Flatten(),
    ]
    return nn.Sequential(*layers)
# Private sentinel meaning "no loss given: build the default critic loss".
# A sentinel (rather than None) keeps an explicit `loss_critic=None` forwarded
# to `Learner` unchanged.
_DEFAULT_CRITIC_LOSS = object()


def colorize_crit_learner(
    data: ImageDataBunch,
    loss_critic=_DEFAULT_CRITIC_LOSS,
    nf: int = 256,
    wd: float = 1e-3,
) -> Learner:
    """Create a ``Learner`` that trains the colorization critic on ``data``.

    Args:
        data: the data bunch the critic learner is built on.
        loss_critic: loss function for the critic. When omitted, a fresh
            ``AdaptiveLoss(nn.BCEWithLogitsLoss())`` is built for this call.
            Any value passed explicitly, including ``None``, is forwarded to
            ``Learner`` as-is.
        nf: stem width passed to ``custom_gan_critic``.
        wd: weight decay for the ``Learner``. The default ``1e-3`` matches
            the previously hard-coded value.

    Returns:
        A ``Learner`` wrapping ``custom_gan_critic(nf=nf)``, with
        ``accuracy_thresh_expand`` as its metric.
    """
    if loss_critic is _DEFAULT_CRITIC_LOSS:
        # Build per call. A default-argument instance is created once at import
        # time and would be shared, and mutable, across every learner.
        loss_critic = AdaptiveLoss(nn.BCEWithLogitsLoss())
    return Learner(
        data,
        custom_gan_critic(nf=nf),
        metrics=accuracy_thresh_expand,
        loss_func=loss_critic,
        wd=wd,
    )