""" | |
Source url: https://github.com/Karel911/TRACER | |
Author: Min Seok Lee and Wooseok Shin | |
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. | |
License: Apache License 2.0 | |
Changes: | |
- Refactored code | |
- Removed unused code | |
- Added comments | |
""" | |
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional

from torch import Tensor

from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
from carvekit.ml.arch.tracerb7.att_modules import (
    RFB_Block,
    aggregation,
    ObjectAttention,
)


class TracerDecoder(nn.Module):
    """Tracer Decoder"""

    def __init__(
        self,
        encoder: EfficientEncoderB7,
        features_channels: Optional[List[int]] = None,
        rfb_channel: Optional[List[int]] = None,
    ):
        """
        Initialize the tracer decoder.

        Args:
            encoder: The encoder to use.
            features_channels: Channels of the backbone features at the four encoder stages. Default: [48, 80, 224, 640]
            rfb_channel: Output channels of the three RFB blocks. Default: [32, 64, 128]
        """
        super().__init__()
        if rfb_channel is None:
            rfb_channel = [32, 64, 128]
        if features_channels is None:
            features_channels = [48, 80, 224, 640]
        self.encoder = encoder
        self.features_channels = features_channels

        # Receptive Field Blocks on the three deepest encoder stages
        self.rfb2 = RFB_Block(self.features_channels[1], rfb_channel[0])
        self.rfb3 = RFB_Block(self.features_channels[2], rfb_channel[1])
        self.rfb4 = RFB_Block(self.features_channels[3], rfb_channel[2])

        # Multi-level aggregation of the RFB outputs into a coarse saliency map
        self.agg = aggregation(rfb_channel)

        # Object Attention modules that refine the map with shallower features
        self.ObjectAttention2 = ObjectAttention(
            channel=self.features_channels[1], kernel_size=3
        )
        self.ObjectAttention1 = ObjectAttention(
            channel=self.features_channels[0], kernel_size=3
        )

    def forward(self, inputs: torch.Tensor) -> Tensor:
        """
        Forward pass of the tracer decoder.

        Args:
            inputs: Preprocessed images.

        Returns:
            Tensor of predicted segmentation masks (sigmoid-activated, at input resolution).
        """
        features = self.encoder(inputs)

        # Receptive Field Blocks on the 1/8, 1/16 and 1/32 resolution features
        x3_rfb = self.rfb2(features[1])
        x4_rfb = self.rfb3(features[2])
        x5_rfb = self.rfb4(features[3])

        # Aggregate into a coarse map at 1/8 resolution, then upsample to full size
        D_0 = self.agg(x5_rfb, x4_rfb, x3_rfb)
        ds_map0 = F.interpolate(D_0, scale_factor=8, mode="bilinear")

        # First refinement with the 1/8 resolution encoder features
        D_1 = self.ObjectAttention2(D_0, features[1])
        ds_map1 = F.interpolate(D_1, scale_factor=8, mode="bilinear")

        # Second refinement with the 1/4 resolution encoder features
        ds_map = F.interpolate(D_1, scale_factor=2, mode="bilinear")
        D_2 = self.ObjectAttention1(ds_map, features[0])
        ds_map2 = F.interpolate(D_2, scale_factor=4, mode="bilinear")

        # Average the three full-resolution maps and squash to [0, 1]
        final_map = (ds_map2 + ds_map1 + ds_map0) / 3
        return torch.sigmoid(final_map)
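

# Minimal usage sketch (not part of the library API): it assumes EfficientEncoderB7
# can be constructed without arguments and that the input spatial size is divisible
# by 32, so the fixed upsampling factors above recover the input resolution.
if __name__ == "__main__":
    encoder = EfficientEncoderB7()  # assumed default construction
    model = TracerDecoder(encoder=encoder)  # default channel configuration
    model.eval()

    with torch.no_grad():
        images = torch.randn(1, 3, 640, 640)  # stand-in for preprocessed RGB images
        masks = model(images)  # sigmoid saliency map at input resolution
        print(masks.shape)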