import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d

from .model import Aggregator
from cat_seg.third_party import clip
from cat_seg.third_party import imagenet_templates

import numpy as np
import open_clip


class CATSegPredictor(nn.Module):
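    """
    Prediction head that pairs CLIP text embeddings of the class names with the image
    features: the class names are encoded with prompt ensembling in __init__, and
    forward() hands the image features, text embeddings, and reversed visual guidance
    maps to the Aggregator transformer to produce the segmentation output.
    """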
    @configurable
    def __init__(
        self,
        *,
        train_class_json: str,
        test_class_json: str,
        clip_pretrained: str,
        prompt_ensemble_type: str,
        text_guidance_dim: int,
        text_guidance_proj_dim: int,
        appearance_guidance_dim: int,
        appearance_guidance_proj_dim: int,
        prompt_depth: int,
        prompt_length: int,
        decoder_dims: list,
        decoder_guidance_dims: list,
        decoder_guidance_proj_dims: list,
        num_heads: int,
        num_layers: tuple,
        hidden_dims: tuple,
        pooling_sizes: tuple,
        feature_resolution: tuple,
        window_sizes: tuple,
        attention_type: str,
    ):
""" |
|
Args: |
|
|
|
""" |
|
        super().__init__()

        import json
        # Class names for training and testing are read from the JSON files given in the
        # config; if no separate test classes are provided, reuse the training classes.
        with open(train_class_json, 'r') as f_in:
            self.class_texts = json.load(f_in)
        with open(test_class_json, 'r') as f_in:
            self.test_class_texts = json.load(f_in)
        assert self.class_texts is not None
        if self.test_class_texts is None:
            self.test_class_texts = self.class_texts

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.tokenizer = None
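        # Two loading paths: the large ViT-H/ViT-G backbones come from open_clip (LAION-2B
        # weights), everything else goes through the CLIP loader bundled under
        # cat_seg.third_party, which also accepts the learnable-prompt arguments.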
        if clip_pretrained == "ViT-G" or clip_pretrained == "ViT-H":
            name, pretrain = ('ViT-H-14', 'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else ('ViT-bigG-14', 'laion2b_s39b_b160k')
            clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
                name,
                pretrained=pretrain,
                device=device,
                force_image_size=336,
            )
            self.tokenizer = open_clip.get_tokenizer(name)
        else:
            clip_model, clip_preprocess = clip.load(
                clip_pretrained, device=device, jit=False,
                prompt_depth=prompt_depth, prompt_length=prompt_length,
            )

        self.prompt_ensemble_type = prompt_ensemble_type

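        # Pick the prompt-template set used to ensemble the text embeddings for each class name.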
        if self.prompt_ensemble_type == "imagenet_select":
            prompt_templates = imagenet_templates.IMAGENET_TEMPLATES_SELECT
        elif self.prompt_ensemble_type == "imagenet":
            prompt_templates = imagenet_templates.IMAGENET_TEMPLATES
        elif self.prompt_ensemble_type == "single":
            prompt_templates = ['A photo of a {} in the scene',]
        else:
            raise NotImplementedError

        # Precompute prompt-ensembled CLIP text embeddings for both class splits;
        # forward() picks the train or test set depending on self.training.
        self.prompt_templates = prompt_templates
        self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
        self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()

        self.clip_model = clip_model.float()
        self.clip_preprocess = clip_preprocess

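        # Aggregator: the transformer module that combines the image features and text
        # embeddings with the guidance features to produce the segmentation prediction.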
        transformer = Aggregator(
            text_guidance_dim=text_guidance_dim,
            text_guidance_proj_dim=text_guidance_proj_dim,
            appearance_guidance_dim=appearance_guidance_dim,
            appearance_guidance_proj_dim=appearance_guidance_proj_dim,
            decoder_dims=decoder_dims,
            decoder_guidance_dims=decoder_guidance_dims,
            decoder_guidance_proj_dims=decoder_guidance_proj_dims,
            num_layers=num_layers,
            nheads=num_heads,
            hidden_dim=hidden_dims,
            pooling_size=pooling_sizes,
            feature_resolution=feature_resolution,
            window_size=window_sizes,
            attention_type=attention_type,
        )
        self.transformer = transformer

    @classmethod
    def from_config(cls, cfg):
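        # Translate the detectron2 config node into the keyword arguments expected by __init__.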
        ret = {}

        ret["train_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON
        ret["test_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON
        ret["clip_pretrained"] = cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED
        ret["prompt_ensemble_type"] = cfg.MODEL.PROMPT_ENSEMBLE_TYPE

        ret["text_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_DIM
        ret["text_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_PROJ_DIM
        ret["appearance_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_DIM
        ret["appearance_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_PROJ_DIM

        ret["decoder_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS
        ret["decoder_guidance_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_DIMS
        ret["decoder_guidance_proj_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_PROJ_DIMS

        ret["prompt_depth"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH
        ret["prompt_length"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH

        ret["num_layers"] = cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS
        ret["num_heads"] = cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS
        ret["hidden_dims"] = cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS
        ret["pooling_sizes"] = cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES
        ret["feature_resolution"] = cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION
        ret["window_sizes"] = cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES
        ret["attention_type"] = cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE

        return ret

    def forward(self, x, vis_affinity):
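        # The guidance feature maps arrive as a dict; reverse their order before handing
        # them to the aggregator, and pick the train or test text embeddings by mode.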
        vis = [vis_affinity[k] for k in vis_affinity.keys()][::-1]
        text = self.text_features if self.training else self.text_features_test
        text = text.repeat(x.shape[0], 1, 1, 1)
        out = self.transformer(x, text, vis)
        return out

    @torch.no_grad()
    def class_embeddings(self, classnames, templates, clip_model):
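        # Build text embeddings per class: each class name (or each comma-separated synonym)
        # is formatted into every prompt template, encoded with the CLIP text encoder,
        # L2-normalized, and averaged over the synonyms so that every class ends up with
        # len(templates) embeddings.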
        zeroshot_weights = []
        for classname in classnames:
            if ', ' in classname:
                classname_splits = classname.split(', ')
                texts = []
                for template in templates:
                    for cls_split in classname_splits:
                        texts.append(template.format(cls_split))
            else:
                texts = [template.format(classname) for template in templates]
            if self.tokenizer is not None:
                texts = self.tokenizer(texts).to(self.device)
            else:
                texts = clip.tokenize(texts).to(self.device)
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            if len(templates) != class_embeddings.shape[0]:
                class_embeddings = class_embeddings.reshape(len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
                class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)
        return zeroshot_weights