Realcat
add: mast3r
fb98d2a
raw
history blame
2.65 kB
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Heads for downstream tasks
# --------------------------------------------------------
"""
A head is a module whose __init__ defines only the head hyperparameters.
A method setup(croconet) takes a CroCoNet and sets all layers according to the head and croconet attributes.
The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height'.
"""
import torch
import torch.nn as nn
from .dpt_block import DPTOutputAdapter
class PixelwiseTaskWithDPT(nn.Module):
    """DPT module for CroCo producing a pixel-wise prediction.

    ``__init__`` only records hyperparameters; the DPT layers are built in
    :meth:`setup` once the backbone (``croconet``) is known.

    Hook indices count encoder blocks first, then decoder blocks. An index
    below ``croconet.enc_depth`` taps an encoder block; any other index taps a
    decoder block.

    When ``hooks_idx`` is None, 4 hooks ending at the backbone's last block
    are used:
      * encoder-only: 4 encoder layers spaced by ``enc_depth // 4``;
      * encoder+decoder: 4 layers spaced by a step looked up from
        ``dec_depth`` in ``_DEC_DEPTH_TO_HOOK_STEP``. For decoder depths 12
        and 24, this gives the last encoder layer plus 3 equally spread
        decoder layers. For depth 8 (step 3), the first hook is the
        second-to-last encoder layer.
    """

    # Spacing between default hooks for each supported decoder depth.
    _DEC_DEPTH_TO_HOOK_STEP = {8: 3, 12: 4, 24: 8}
    # Default per-hook feature dims forwarded to DPTOutputAdapter.
    _DEFAULT_LAYER_DIMS = (96, 192, 384, 768)

    def __init__(self, *, hooks_idx=None, layer_dims=None,
                 output_width_ratio=1, num_channels=1, postprocess=None, **kwargs):
        """Store the head hyperparameters (no layers are created here).

        Args:
            hooks_idx: backbone block indices to tap; None means they are
                derived from the backbone in :meth:`setup`.
            layer_dims: per-hook feature dims passed to ``DPTOutputAdapter``;
                None means ``[96, 192, 384, 768]``.
            output_width_ratio: forwarded to ``DPTOutputAdapter``.
            num_channels: forwarded to ``DPTOutputAdapter``.
            postprocess: optional callable applied to the DPT output in
                :meth:`forward`.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.return_all_blocks = True  # backbone needs to return all layers
        self.postprocess = postprocess
        self.output_width_ratio = output_width_ratio
        self.num_channels = num_channels
        self.hooks_idx = hooks_idx
        # Build a fresh list per instance. A list literal as the default value
        # would be one object shared by every head created with the default.
        self.layer_dims = (list(self._DEFAULT_LAYER_DIMS)
                           if layer_dims is None else layer_dims)

    @classmethod
    def _default_hooks_idx(cls, croconet):
        """Return the 4 default hook indices for ``croconet``, in ascending order.

        Raises:
            KeyError: if the backbone has a decoder whose ``dec_depth`` is not
                a key of ``_DEC_DEPTH_TO_HOOK_STEP``.
        """
        if hasattr(croconet, 'dec_blocks'):  # encoder + decoder
            dec_depth = croconet.dec_depth
            try:
                step = cls._DEC_DEPTH_TO_HOOK_STEP[dec_depth]
            except KeyError:
                supported = sorted(cls._DEC_DEPTH_TO_HOOK_STEP)
                raise KeyError(
                    f'PixelwiseTaskWithDPT: no default hooks for dec_depth={dec_depth} '
                    f'(supported: {supported}); pass hooks_idx explicitly'
                ) from None
            last = dec_depth + croconet.enc_depth - 1
        else:  # encoder only
            step = croconet.enc_depth // 4
            last = croconet.enc_depth - 1
        return [last - i * step for i in range(3, -1, -1)]

    def setup(self, croconet):
        """Build the DPT adapter for ``croconet``.

        If no ``hooks_idx`` was given, it is derived from the backbone depths,
        stored on ``self.hooks_idx`` and printed.

        Args:
            croconet: backbone. This method reads ``enc_depth`` and
                ``enc_embed_dim``. It reads ``dec_depth`` when the backbone
                has ``dec_blocks`` and hooks must be derived. It reads
                ``dec_embed_dim`` for any hook index >= ``enc_depth``.
        """
        if self.hooks_idx is None:
            self.hooks_idx = self._default_hooks_idx(croconet)
            print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}')
        self.dpt = DPTOutputAdapter(
            output_width_ratio=self.output_width_ratio,
            num_channels=self.num_channels,
            hooks=self.hooks_idx,
            layer_dims=self.layer_dims,
        )
        # Hooks before enc_depth tap encoder blocks; the rest tap decoder blocks.
        dim_tokens = [croconet.enc_embed_dim if hook < croconet.enc_depth
                      else croconet.dec_embed_dim
                      for hook in self.hooks_idx]
        self.dpt.init(dim_tokens_enc=dim_tokens)

    def forward(self, x, img_info):
        """Run the DPT adapter, then the optional postprocess.

        Args:
            x: backbone features, passed unchanged to the DPT adapter
                (presumably the per-block token list; confirm against caller).
            img_info: mapping with at least the keys 'height' and 'width'.

        Returns:
            The adapter output, transformed by ``postprocess`` if one was given.
        """
        out = self.dpt(x, image_size=(img_info['height'], img_info['width']))
        if self.postprocess is not None:
            out = self.postprocess(out)
        return out