from typing import Tuple |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from detectron2.config import configurable |
|
from detectron2.data import MetadataCatalog |
|
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head |
|
from detectron2.modeling.backbone import Backbone |
|
from detectron2.modeling.postprocessing import sem_seg_postprocess |
|
from detectron2.structures import ImageList |
|
from detectron2.utils.memory import _ignore_torch_cuda_oom |
|
|
|
from einops import rearrange |
|
|
|
@META_ARCH_REGISTRY.register() |
|
class CATSeg(nn.Module): |
|
@configurable |
|
def __init__( |
|
self, |
|
*, |
|
backbone: Backbone, |
|
sem_seg_head: nn.Module, |
|
size_divisibility: int, |
|
        pixel_mean: Tuple[float, ...],

        pixel_std: Tuple[float, ...],

        clip_pixel_mean: Tuple[float, ...],

        clip_pixel_std: Tuple[float, ...],
|
train_class_json: str, |
|
test_class_json: str, |
|
sliding_window: bool, |
|
clip_finetune: str, |
|
backbone_multiplier: float, |
|
clip_pretrained: str, |
|
): |
|
""" |
|
Args: |
|
backbone: a backbone module, must follow detectron2's backbone interface |
|
sem_seg_head: a module that predicts semantic segmentation from backbone features |
|
""" |
|
super().__init__() |
|
self.backbone = backbone |
|
self.sem_seg_head = sem_seg_head |
|
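        # A negative size_divisibility means "defer to the backbone's own requirement".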
if size_divisibility < 0: |
|
size_divisibility = self.backbone.size_divisibility |
|
self.size_divisibility = size_divisibility |
|
|
|
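        # Per-channel normalization statistics; persistent=False keeps these

        # constants out of the checkpoint's state_dict.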
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) |
|
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) |
|
self.register_buffer("clip_pixel_mean", torch.Tensor(clip_pixel_mean).view(-1, 1, 1), False) |
|
self.register_buffer("clip_pixel_std", torch.Tensor(clip_pixel_std).view(-1, 1, 1), False) |
|
|
|
self.train_class_json = train_class_json |
|
self.test_class_json = test_class_json |
|
|
|
self.clip_finetune = clip_finetune |
|
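        # Select which CLIP parameters to fine-tune. Only the visual encoder is

        # ever trained: "prompt" trains prompt tokens, "attention" trains the

        # attention and positional parameters, "full" trains the whole visual

        # encoder, and anything else freezes CLIP entirely.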
        for name, params in self.sem_seg_head.predictor.clip_model.named_parameters():

            if "visual" in name:

                if clip_finetune == "prompt":

                    params.requires_grad = "prompt" in name

                elif clip_finetune == "attention":

                    params.requires_grad = "attn" in name or "position" in name

                elif clip_finetune == "full":

                    params.requires_grad = True

                else:

                    params.requires_grad = False

            else:

                params.requires_grad = False
|
|
|
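        # The backbone is trained only when it gets a positive LR multiplier;

        # the "norm0" layer stays frozen either way.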
finetune_backbone = backbone_multiplier > 0. |
|
for name, params in self.backbone.named_parameters(): |
|
if "norm0" in name: |
|
params.requires_grad = False |
|
else: |
|
params.requires_grad = finetune_backbone |
|
|
|
self.sliding_window = sliding_window |
|
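        # CLIP input resolution: 384x384 for the ViT-B/16 variant, 336x336

        # otherwise (e.g. for a ViT-L model pretrained at 336px).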
self.clip_resolution = (384, 384) if clip_pretrained == "ViT-B/16" else (336, 336) |
|
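        # Flipped to True if batched sliding-window inference runs out of CUDA

        # memory (see forward); afterwards windows are processed one at a time.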
self.sequential = False |
|
|
|
@classmethod |
|
def from_config(cls, cfg): |
|
backbone = build_backbone(cfg) |
|
sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape()) |
|
|
|
return { |
|
"backbone": backbone, |
|
"sem_seg_head": sem_seg_head, |
|
"size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, |
|
"pixel_mean": cfg.MODEL.PIXEL_MEAN, |
|
"pixel_std": cfg.MODEL.PIXEL_STD, |
|
"clip_pixel_mean": cfg.MODEL.CLIP_PIXEL_MEAN, |
|
"clip_pixel_std": cfg.MODEL.CLIP_PIXEL_STD, |
|
"train_class_json": cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON, |
|
"test_class_json": cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON, |
|
"sliding_window": cfg.TEST.SLIDING_WINDOW, |
|
"clip_finetune": cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE, |
|
"backbone_multiplier": cfg.SOLVER.BACKBONE_MULTIPLIER, |
|
"clip_pretrained": cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED, |
|
} |
|
|
|
@property |
|
def device(self): |
|
return self.pixel_mean.device |
|
|
|
def forward(self, batched_inputs): |
|
""" |
|
Args: |
|
batched_inputs: a list, batched outputs of :class:`DatasetMapper`. |
|
Each item in the list contains the inputs for one image. |
|
For now, each item in the list is a dict that contains: |
|
* "image": Tensor, image in (C, H, W) format. |
|
* "instances": per-region ground truth |
|
* Other information that's included in the original dicts, such as: |
|
"height", "width" (int): the output resolution of the model (may be different |
|
from input resolution), used in inference. |
|
Returns: |
|
list[dict]: |
|
each dict has the results for one image. The dict contains the following keys: |
|
|
|
* "sem_seg": |
|
A Tensor that represents the |
|
per-pixel segmentation prediced by the head. |
|
The prediction has shape KxHxW that represents the logits of |
|
each class for each pixel. |
|
""" |
|
images = [x["image"].to(self.device) for x in batched_inputs] |
|
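        # Sliding-window inference: first try all windows in one batch under

        # _ignore_torch_cuda_oom. If that raises a CUDA OOM, the context manager

        # suppresses it, execution falls through, and we permanently switch to

        # sequential (per-window) inference.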
if not self.training and self.sliding_window: |
|
if not self.sequential: |
|
with _ignore_torch_cuda_oom(): |
|
return self.inference_sliding_window(batched_inputs) |
|
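                # Only reached when the batched attempt above OOMed and the

                # error was suppressed: fall back to sequential mode.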
self.sequential = True |
|
return self.inference_sliding_window(batched_inputs) |
|
|
|
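        # Two normalized copies of the batch: CLIP statistics for the CLIP

        # encoder, dataset statistics for the backbone.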
clip_images = [(x - self.clip_pixel_mean) / self.clip_pixel_std for x in images] |
|
clip_images = ImageList.from_tensors(clip_images, self.size_divisibility) |
|
|
|
images = [(x - self.pixel_mean) / self.pixel_std for x in images] |
|
images = ImageList.from_tensors(images, self.size_divisibility) |
|
|
|
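        # Resize to the CLIP input resolution and extract dense (per-patch)

        # image embeddings from the CLIP visual encoder.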
        clip_images = F.interpolate(clip_images.tensor, size=self.clip_resolution, mode='bilinear', align_corners=False)
|
clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True) |
|
|
|
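        # The backbone always runs at a fixed 384x384 resolution.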
        images_resized = F.interpolate(images.tensor, size=(384, 384), mode='bilinear', align_corners=False)
|
features = self.backbone(images_resized) |
|
|
|
outputs = self.sem_seg_head(clip_features, features) |
|
if self.training: |
|
targets = torch.stack([x["sem_seg"].to(self.device) for x in batched_inputs], dim=0) |
|
outputs = F.interpolate(outputs, size=(targets.shape[-2], targets.shape[-1]), mode="bilinear", align_corners=False) |
|
|
|
num_classes = outputs.shape[1] |
|
mask = targets != self.sem_seg_head.ignore_value |
|
|
|
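            # Per-pixel multi-label BCE: build one-hot targets; pixels with the

            # ignore label keep an all-zero target row (they are not excluded

            # from the loss).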
            outputs = outputs.permute(0, 2, 3, 1)  # (B, H, W, K) to index with the pixel mask
|
_targets = torch.zeros(outputs.shape, device=self.device) |
|
_onehot = F.one_hot(targets[mask], num_classes=num_classes).float() |
|
_targets[mask] = _onehot |
|
|
|
loss = F.binary_cross_entropy_with_logits(outputs, _targets) |
|
losses = {"loss_sem_seg" : loss} |
|
return losses |
|
else: |
|
outputs = outputs.sigmoid() |
|
image_size = images.image_sizes[0] |
|
height = batched_inputs[0].get("height", image_size[0]) |
|
width = batched_inputs[0].get("width", image_size[1]) |
|
|
|
output = sem_seg_postprocess(outputs[0], image_size, height, width) |
|
processed_results = [{'sem_seg': output}] |
|
return processed_results |
|
|
|
|
|
@torch.no_grad() |
|
    def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333, out_res=(640, 640)):
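        """

        Sliding-window inference: resize the image to `out_res`, split it into

        overlapping `kernel` x `kernel` crops plus one global view, predict each

        crop, then fold the crops back together and average the result with the

        global prediction.

        """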
|
images = [x["image"].to(self.device, dtype=torch.float32) for x in batched_inputs] |
|
stride = int(kernel * (1 - overlap)) |
|
unfold = nn.Unfold(kernel_size=kernel, stride=stride) |
|
fold = nn.Fold(out_res, kernel_size=kernel, stride=stride) |
|
|
|
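        # Cut the resized image into overlapping crops and append a downsampled

        # global view as the last "crop".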
image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode='bilinear', align_corners=False).squeeze() |
|
        image = rearrange(unfold(image), "(C H W) L -> L C H W", C=3, H=kernel)
|
global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode='bilinear', align_corners=False) |
|
image = torch.cat((image, global_image), dim=0) |
|
|
|
images = (image - self.pixel_mean) / self.pixel_std |
|
clip_images = (image - self.clip_pixel_mean) / self.clip_pixel_std |
|
        clip_images = F.interpolate(clip_images, size=self.clip_resolution, mode='bilinear', align_corners=False)
|
clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True) |
|
|
|
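        # In sequential mode (after a CUDA OOM), predict one crop at a time to

        # bound peak memory.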
if self.sequential: |
|
outputs = [] |
|
for clip_feat, image in zip(clip_features, images): |
|
feature = self.backbone(image.unsqueeze(0)) |
|
output = self.sem_seg_head(clip_feat.unsqueeze(0), feature) |
|
outputs.append(output[0]) |
|
outputs = torch.stack(outputs, dim=0) |
|
else: |
|
features = self.backbone(images) |
|
outputs = self.sem_seg_head(clip_features, features) |
|
|
|
outputs = F.interpolate(outputs, size=kernel, mode="bilinear", align_corners=False) |
|
outputs = outputs.sigmoid() |
|
|
|
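        # Fold the per-crop predictions back to full resolution, normalizing by

        # the per-pixel overlap count, then average with the global prediction.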
global_output = outputs[-1:] |
|
        global_output = F.interpolate(global_output, size=out_res, mode='bilinear', align_corners=False)
|
outputs = outputs[:-1] |
|
        outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1, *out_res], device=self.device)))
|
outputs = (outputs + global_output) / 2. |
|
|
|
height = batched_inputs[0].get("height", out_res[0]) |
|
width = batched_inputs[0].get("width", out_res[1]) |
|
        output = sem_seg_postprocess(outputs[0], out_res, height, width)
|
return [{'sem_seg': output}] |
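

# Typical usage, as a sketch: CATSeg is instantiated through detectron2's config

# system once a project config registers the extra keys read in from_config above

# (the helper name and config path below are assumptions, not part of this file):

#

#   from detectron2.config import get_cfg

#   from detectron2.modeling import build_model

#

#   cfg = get_cfg()

#   add_cat_seg_config(cfg)                          # assumed project helper

#   cfg.merge_from_file("configs/catseg_vitb.yaml")  # hypothetical config path

#   model = build_model(cfg)                         # resolves META_ARCHITECTURE == "CATSeg"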