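"""
Depth decoders for the ClimateGAN generator: a factory function
(create_depth_decoder) and two decoder heads (BaseDepthDecoder and
DADADepthDecoder) predicting a depth map from the encoder's features.
"""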
import torch
import torch.nn as nn
import torch.nn.functional as F

from climategan.blocks import BaseDecoder, Conv2dBlock, InterpolateNearest2d
from climategan.utils import find_target_size


def create_depth_decoder(opts, no_init=False, verbose=0):
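    """
    Create a depth decoder from the opts.

    Args:
        opts (Dict): options for the run
        no_init (bool): unused here
        verbose (int): if > 0, print which decoder class was created

    Returns:
        nn.Module: a BaseDepthDecoder if opts.gen.d.architecture is "base",
            a DADADepthDecoder otherwise
    """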
    if opts.gen.d.architecture == "base":
        decoder = BaseDepthDecoder(opts)
        if "s" in opts.tasks:
            assert opts.gen.s.use_dada is False
        if "m" in opts.tasks:
            assert opts.gen.m.use_dada is False
    else:
        decoder = DADADepthDecoder(opts)

    if verbose > 0:
        print(f" - Add {decoder.__class__.__name__}")

    return decoder


class DADADepthDecoder(nn.Module):
    """
    Depth decoder based on the depth auxiliary task in the DADA paper
    """

    def __init__(self, opts):
        super().__init__()
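        # Channel count of the encoder's output features depends on the
        # backbone: 320 for a deeplabv3 encoder with a mobilenet backbone,
        # 2048 otherwise (e.g. a ResNet backbone).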
        if (
            opts.gen.encoder.architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "mobilenet"
        ):
            res_dim = 320
        else:
            res_dim = 2048

        mid_dim = 512
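
        # DADA-style feature fusion: when the masker (or segmentation head)
        # uses DADA, dec4 projects the 128-channel depth features back to the
        # encoder's dimension (res_dim) so they can be fused with the encoder
        # features downstream.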
        self.do_feat_fusion = False
        if opts.gen.m.use_dada or ("s" in opts.tasks and opts.gen.s.use_dada):
            self.do_feat_fusion = True
            self.dec4 = Conv2dBlock(
                128,
                res_dim,
                1,
                stride=1,
                padding=0,
                bias=True,
                activation="lrelu",
                norm="none",
            )

        self.relu = nn.ReLU(inplace=True)
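        # Depth head: 1x1 -> 3x3 -> 1x1 convolutions reducing the encoder
        # features from res_dim to 128 channels.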
        self.enc4_1 = Conv2dBlock(
            res_dim,
            mid_dim,
            1,
            stride=1,
            padding=0,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )
        self.enc4_2 = Conv2dBlock(
            mid_dim,
            mid_dim,
            3,
            stride=1,
            padding=1,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )
        self.enc4_3 = Conv2dBlock(
            mid_dim,
            128,
            1,
            stride=1,
            padding=0,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )
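        # Optional upsampling head: nearest-neighbor upsampling followed by a
        # conv block and a final 1x1 conv down to a single (depth) channel.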
        self.upsample = None
        if opts.gen.d.upsample_featuremaps:
            self.upsample = nn.Sequential(
                *[
                    InterpolateNearest2d(),
                    Conv2dBlock(
                        128,
                        32,
                        3,
                        stride=1,
                        padding=1,
                        bias=False,
                        activation="lrelu",
                        pad_type="reflect",
                        norm="batch",
                    ),
                    nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
                ]
            )
        self._target_size = find_target_size(opts, "d")
        print(
            " - {}: setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z):
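        """
        Predict a depth map from the encoder's features.

        Args:
            z (torch.Tensor | list | tuple): encoder features (if a list or
                tuple, the first element is used)

        Returns:
            tuple(torch.Tensor, torch.Tensor | None): (depth, z_depth) where
                z_depth is the depth feature map projected back to the
                encoder's dimension (for DADA feature fusion), or None if
                fusion is disabled
        """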
        if isinstance(z, (list, tuple)):
            z = z[0]
        z4_enc = self.enc4_1(z)
        z4_enc = self.enc4_2(z4_enc)
        z4_enc = self.enc4_3(z4_enc)

        z_depth = None
        if self.do_feat_fusion:
            z_depth = self.dec4(z4_enc)

        if self.upsample is not None:
            z4_enc = self.upsample(z4_enc)

        # depth is the channel-wise mean of the encoded features
        depth = torch.mean(z4_enc, dim=1, keepdim=True)
        # note: this assumes self._target_size is an int
        if depth.shape[-1] != self._target_size:
            depth = F.interpolate(
                depth,
                size=(384, 384),  # size used in MiDaS inference
                mode="bicubic",
                align_corners=False,
            )

            depth = F.interpolate(
                depth, (self._target_size, self._target_size), mode="nearest"
            )

        return depth, z_depth

    def __str__(self):
        return "DADA Depth Decoder"


class BaseDepthDecoder(BaseDecoder):
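    """
    Depth decoder reusing the generic BaseDecoder architecture, configured
    from opts.gen.d (input dimensions depend on the encoder backbone).
    """
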
    def __init__(self, opts):
        low_level_feats_dim = -1
        use_v3 = opts.gen.encoder.architecture == "deeplabv3"
        use_mobile_net = opts.gen.deeplabv3.backbone == "mobilenet"
        use_low = opts.gen.d.use_low_level_feats
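
        # Input (and optional low-level) feature dimensions depend on the
        # encoder backbone: 320 / 24 for a mobilenet deeplabv3 encoder,
        # 2048 / 256 for other deeplabv3 backbones, and 2048 with no
        # low-level features otherwise.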
        if use_v3 and use_mobile_net:
            input_dim = 320
            if use_low:
                low_level_feats_dim = 24
        elif use_v3:
            input_dim = 2048
            if use_low:
                low_level_feats_dim = 256
        else:
            input_dim = 2048

        n_upsample = 1 if opts.gen.d.upsample_featuremaps else 0
        output_dim = (
            1
            if not opts.gen.d.classify.enable
            else opts.gen.d.classify.linspace.buckets
        )

        self._target_size = find_target_size(opts, "d")
        print(
            " - {}: setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

        super().__init__(
            n_upsample=n_upsample,
            n_res=opts.gen.d.n_res,
            input_dim=input_dim,
            proj_dim=opts.gen.d.proj_dim,
            output_dim=output_dim,
            norm=opts.gen.d.norm,
            activ=opts.gen.d.activ,
            pad_type=opts.gen.d.pad_type,
            output_activ="none",
            low_level_feats_dim=low_level_feats_dim,
        )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w). If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z, cond=None):
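        """
        Decode the encoder's features into a depth map interpolated to
        self._target_size.

        Args:
            z (torch.Tensor): encoder features
            cond: unused here

        Returns:
            tuple(torch.Tensor, None): (depth predictions, None); None is
                returned as the second element (no z_depth for this decoder)
        """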
        if self._target_size is None:
            error = "self._target_size should be set with self.set_target_size()"
            error += " to interpolate depth to the target depth map's size"
            raise ValueError(error)

        d = super().forward(z)

        preds = F.interpolate(
            d, size=self._target_size, mode="bilinear", align_corners=True
        )

        return preds, None