Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
import logging | |
from copy import deepcopy | |
from typing import Callable, Dict, List, Optional, Tuple, Union | |
from einops import rearrange | |
import fvcore.nn.weight_init as weight_init | |
from torch import nn | |
from torch.nn import functional as F | |
from detectron2.config import configurable | |
from detectron2.layers import Conv2d, ShapeSpec, get_norm | |
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY | |
from ..transformer.cat_seg_predictor import CATSegPredictor | |
@SEM_SEG_HEADS_REGISTRY.register()
class CATSegHead(nn.Module):
    """Semantic-segmentation head that reshapes token features and delegates to a predictor.

    The head itself holds no learnable layers besides the wrapped
    ``transformer_predictor``: ``forward`` only converts a token sequence
    into a 2-D feature map and passes it on.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        ignore_value: int = -1,
        # extra parameters
        feature_resolution: list,
        transformer_predictor: nn.Module,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: shapes (channels and stride) of the input features,
                keyed by feature name.
            num_classes: number of classes to predict.
            ignore_value: category id to be ignored during training.
            feature_resolution: two-element sequence ``(h, w)`` giving the
                spatial grid that the token sequence is folded back into.
            transformer_predictor: the module that makes the prediction; it
                is called as ``predictor(img_feat, guidance_features)``.
        """
        super().__init__()
        # Keep feature names ordered from the finest stride to the coarsest.
        ordered_shapes = sorted(input_shape.items(), key=lambda x: x[1].stride)
        self.in_features = [name for name, _ in ordered_shapes]
        self.ignore_value = ignore_value
        self.predictor = transformer_predictor
        self.num_classes = num_classes
        self.feature_resolution = feature_resolution

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """Translate a config node into the keyword arguments of ``__init__``.

        Args:
            cfg: config object; the ``MODEL.SEM_SEG_HEAD`` section is read here
                and the whole config is forwarded to ``CATSegPredictor``.
            input_shape: shapes of all available features; only those listed
                in ``cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES`` are kept.

        Returns:
            dict of keyword arguments for ``__init__``.
        """
        head_cfg = cfg.MODEL.SEM_SEG_HEAD
        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in head_cfg.IN_FEATURES
            },
            "ignore_value": head_cfg.IGNORE_VALUE,
            "num_classes": head_cfg.NUM_CLASSES,
            "feature_resolution": head_cfg.FEATURE_RESOLUTION,
            "transformer_predictor": CATSegPredictor(
                cfg,
            ),
        }

    def forward(self, features, guidance_features):
        """Fold token features into a feature map and run the predictor.

        Arguments:
            features: 3-D tensor laid out as (B, 1 + H*W, C), where
                ``(H, W) == self.feature_resolution``. The token at index 0
                of the second axis is dropped before reshaping
                (presumably a class token -- confirm against the backbone).
            guidance_features: passed to the predictor unchanged.

        Returns:
            Whatever ``self.predictor`` returns for the (B, C, H, W) feature
            map and ``guidance_features``.
        """
        height, width = self.feature_resolution[0], self.feature_resolution[1]
        # Skip the leading token, then turn the remaining H*W tokens into a grid.
        patch_tokens = features[:, 1:, :]
        img_feat = rearrange(patch_tokens, "b (h w) c->b c h w", h=height, w=width)
        return self.predictor(img_feat, guidance_features)