# Source: SAM-CAT-Seg / cat_seg/modeling/transformer/cat_seg_predictor.py
# Commit f8f62f3 ("initial commit") by seokju cho
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
# Modified by Jian Ding from: https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Conv2d
from .model import Aggregator
from cat_seg.third_party import clip
from cat_seg.third_party import imagenet_templates
import numpy as np
import open_clip
class CATSegPredictor(nn.Module):
    """CAT-Seg prediction head: a CLIP model plus a cost-aggregation transformer.

    ``__init__`` loads a CLIP model: OpenCLIP for ``"ViT-G"`` / ``"ViT-H"``, and the
    project's vendored ``clip`` otherwise. It also builds an ``Aggregator``.

    NOTE: in this variant, loading class names from ``train_class_json`` /
    ``test_class_json`` and precomputing text embeddings are disabled. ``forward``
    therefore expects the caller to assign ``self.text_features`` (used in training
    mode) and ``self.text_features_test`` (used in eval mode) beforehand. The
    disabled code built them as
    ``self.class_embeddings(names, templates, clip_model).permute(1, 0, 2).float()``.
    """

    @configurable
    def __init__(
        self,
        *,
        train_class_json: str,
        test_class_json: str,
        clip_pretrained: str,
        prompt_ensemble_type: str,
        text_guidance_dim: int,
        text_guidance_proj_dim: int,
        appearance_guidance_dim: int,
        appearance_guidance_proj_dim: int,
        prompt_depth: int,
        prompt_length: int,
        decoder_dims: list,
        decoder_guidance_dims: list,
        decoder_guidance_proj_dims: list,
        num_heads: int,
        num_layers: tuple,
        hidden_dims: tuple,
        pooling_sizes: tuple,
        feature_resolution: tuple,
        window_sizes: tuple,
        attention_type: str,
    ):
        """
        Args:
            train_class_json: path to the training class-name JSON. Currently unused
                (loading is disabled); kept for config compatibility.
            test_class_json: path to the evaluation class-name JSON. Currently unused.
            clip_pretrained: ``"ViT-H"`` selects OpenCLIP ``ViT-H-14``
                (``laion2b_s32b_b79k``). ``"ViT-G"`` selects OpenCLIP ``ViT-bigG-14``
                (``laion2b_s39b_b160k``). Both are created with
                ``force_image_size=336``. Any other value is passed to ``clip.load``.
            prompt_ensemble_type: one of ``"imagenet_select"``, ``"imagenet"``,
                ``"single"``. The chosen templates are stored on
                ``self.prompt_templates``.
            prompt_depth, prompt_length: forwarded to ``clip.load``. They are used only
                for the non-OpenCLIP branch.
            text_guidance_dim ... attention_type: forwarded to ``Aggregator``
                (``num_heads`` -> ``nheads``, ``hidden_dims`` -> ``hidden_dim``,
                ``pooling_sizes`` -> ``pooling_size``,
                ``window_sizes`` -> ``window_size``).

        Raises:
            NotImplementedError: if ``prompt_ensemble_type`` is not recognised.
        """
        super().__init__()

        # Validate the prompt ensemble before loading CLIP, so that a config typo
        # fails fast instead of after a (potentially large) model download.
        templates_by_type = {
            "imagenet_select": imagenet_templates.IMAGENET_TEMPLATES_SELECT,
            "imagenet": imagenet_templates.IMAGENET_TEMPLATES,
            "single": ['A photo of a {} in the scene'],
        }
        if prompt_ensemble_type not in templates_by_type:
            raise NotImplementedError(
                f"Unsupported prompt_ensemble_type {prompt_ensemble_type!r}; "
                f"expected one of {sorted(templates_by_type)}"
            )
        self.prompt_ensemble_type = prompt_ensemble_type
        # Not used inside this class. It is kept so callers can build text features
        # with class_embeddings(..., self.prompt_templates, ...).
        self.prompt_templates = templates_by_type[prompt_ensemble_type]

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # Stays None for OpenAI CLIP; class_embeddings then falls back to clip.tokenize.
        self.tokenizer = None
        if clip_pretrained in ("ViT-G", "ViT-H"):
            # OpenCLIP models.
            if clip_pretrained == "ViT-H":
                name, pretrain = "ViT-H-14", "laion2b_s32b_b79k"
            else:
                name, pretrain = "ViT-bigG-14", "laion2b_s39b_b160k"
            clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
                name,
                pretrained=pretrain,
                device=device,
                force_image_size=336,
            )
            self.tokenizer = open_clip.get_tokenizer(name)
        else:
            # OpenAI CLIP via the vendored cat_seg.third_party.clip, which accepts
            # prompt_depth / prompt_length.
            clip_model, clip_preprocess = clip.load(
                clip_pretrained,
                device=device,
                jit=False,
                prompt_depth=prompt_depth,
                prompt_length=prompt_length,
            )

        # Registered before the aggregator, which keeps the original submodule order.
        self.clip_model = clip_model.float()
        self.clip_preprocess = clip_preprocess

        self.transformer = Aggregator(
            text_guidance_dim=text_guidance_dim,
            text_guidance_proj_dim=text_guidance_proj_dim,
            appearance_guidance_dim=appearance_guidance_dim,
            appearance_guidance_proj_dim=appearance_guidance_proj_dim,
            decoder_dims=decoder_dims,
            decoder_guidance_dims=decoder_guidance_dims,
            decoder_guidance_proj_dims=decoder_guidance_proj_dims,
            num_layers=num_layers,
            nheads=num_heads,
            hidden_dim=hidden_dims,
            pooling_size=pooling_sizes,
            feature_resolution=feature_resolution,
            window_size=window_sizes,
            attention_type=attention_type,
        )

    @classmethod
    def from_config(cls, cfg):
        """Map a detectron2 config onto the keyword arguments of ``__init__``.

        Most values come from ``cfg.MODEL.SEM_SEG_HEAD``. ``prompt_ensemble_type``
        is read from ``cfg.MODEL.PROMPT_ENSEMBLE_TYPE`` instead. The ``*_AFFINITY_*``
        config keys feed the ``*_guidance_*`` arguments.
        """
        head = cfg.MODEL.SEM_SEG_HEAD
        return {
            "train_class_json": head.TRAIN_CLASS_JSON,
            "test_class_json": head.TEST_CLASS_JSON,
            "clip_pretrained": head.CLIP_PRETRAINED,
            "prompt_ensemble_type": cfg.MODEL.PROMPT_ENSEMBLE_TYPE,
            # Aggregator parameters:
            "text_guidance_dim": head.TEXT_AFFINITY_DIM,
            "text_guidance_proj_dim": head.TEXT_AFFINITY_PROJ_DIM,
            "appearance_guidance_dim": head.APPEARANCE_AFFINITY_DIM,
            "appearance_guidance_proj_dim": head.APPEARANCE_AFFINITY_PROJ_DIM,
            "decoder_dims": head.DECODER_DIMS,
            "decoder_guidance_dims": head.DECODER_AFFINITY_DIMS,
            "decoder_guidance_proj_dims": head.DECODER_AFFINITY_PROJ_DIMS,
            "prompt_depth": head.PROMPT_DEPTH,
            "prompt_length": head.PROMPT_LENGTH,
            "num_layers": head.NUM_LAYERS,
            "num_heads": head.NUM_HEADS,
            "hidden_dims": head.HIDDEN_DIMS,
            "pooling_sizes": head.POOLING_SIZES,
            "feature_resolution": head.FEATURE_RESOLUTION,
            "window_sizes": head.WINDOW_SIZES,
            "attention_type": head.ATTENTION_TYPE,
        }

    def forward(self, x, vis_affinity):
        """Run the aggregator on ``x``, the cached text embeddings and the visual guidance.

        Args:
            x: passed unchanged to the aggregator as its first argument. Its first
                dimension is used as the batch size for tiling the text embeddings.
            vis_affinity: mapping whose values are passed to the aggregator as a list,
                in reverse iteration order.

        Returns:
            Whatever ``self.transformer`` returns.

        Raises:
            AttributeError: if the required text-embedding attribute has not been
                assigned. ``text_features`` is needed in training mode and
                ``text_features_test`` in eval mode.
        """
        vis = list(vis_affinity.values())[::-1]

        attr = "text_features" if self.training else "text_features_test"
        text = getattr(self, attr, None)
        if text is None:
            # __init__ no longer computes these; fail with an actionable message
            # instead of a bare missing-attribute error.
            raise AttributeError(
                f"CATSegPredictor.{attr} is not set; assign text embeddings "
                "(e.g. built with class_embeddings) before calling forward()"
            )
        # Presumably (num_classes, num_templates, dim) -> (B, num_classes, num_templates, dim).
        # TODO confirm against the code that assigns the attribute.
        text = text.repeat(x.shape[0], 1, 1, 1)
        return self.transformer(x, text, vis)

    @torch.no_grad()
    def class_embeddings(self, classnames, templates, clip_model):
        """Encode class names into L2-normalised CLIP text embeddings (no autograd).

        Each class name is formatted into every template. A name containing ``', '``
        is treated as a list of synonyms: every synonym is formatted into every
        template, and the embeddings are then averaged per template and re-normalised.

        Args:
            classnames: iterable of class-name strings.
            templates: prompt templates with a single ``'{}'`` placeholder.
            clip_model: model exposing ``encode_text``.

        Returns:
            Per-class embeddings stacked along dim 1, on ``self.device``. The shape is
            ``(len(templates), num_classes, embed_dim)``, assuming ``encode_text``
            returns ``(num_texts, embed_dim)``.
        """
        # OpenCLIP models carry their own tokenizer; OpenAI CLIP uses clip.tokenize.
        tokenize = self.tokenizer if self.tokenizer is not None else clip.tokenize

        zeroshot_weights = []
        for classname in classnames:
            if ', ' in classname:
                # Template-major order (t0s0, t0s1, ..., t1s0, ...), so the reshape
                # below groups rows by template.
                synonyms = classname.split(', ')
                texts = [template.format(s) for template in templates for s in synonyms]
            else:
                texts = [template.format(classname) for template in templates]

            tokens = tokenize(texts).to(self.device)
            embeddings = clip_model.encode_text(tokens)
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
            if embeddings.shape[0] != len(templates):
                # Average the synonym embeddings within each template.
                embeddings = embeddings.reshape(
                    len(templates), -1, embeddings.shape[-1]
                ).mean(dim=1)
                embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
            zeroshot_weights.append(embeddings)

        return torch.stack(zeroshot_weights, dim=1).to(self.device)