# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
# Modified by Jian Ding from: https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Conv2d
from .model import Aggregator
from cat_seg.third_party import clip
from cat_seg.third_party import imagenet_templates
import numpy as np
import open_clip
class CATSegPredictor(nn.Module):
    """CAT-Seg prediction head.

    Wraps (1) a CLIP model used to embed class-name prompts into text
    features and (2) an ``Aggregator`` transformer that fuses an input
    feature/cost tensor with those text features and multi-scale
    appearance guidance.

    NOTE(review): ``self.text_features`` / ``self.text_features_test`` are
    read in ``forward`` but their computation in ``__init__`` is commented
    out — presumably they are assigned externally (e.g. by the owning
    module after loading class lists); confirm before relying on
    ``forward`` in isolation.
    """

    @configurable
    def __init__(
        self,
        *,
        train_class_json: str,
        test_class_json: str,
        clip_pretrained: str,
        prompt_ensemble_type: str,
        text_guidance_dim: int,
        text_guidance_proj_dim: int,
        appearance_guidance_dim: int,
        appearance_guidance_proj_dim: int,
        prompt_depth: int,
        prompt_length: int,
        decoder_dims: list,
        decoder_guidance_dims: list,
        decoder_guidance_proj_dims: list,
        num_heads: int,
        num_layers: tuple,
        hidden_dims: tuple,
        pooling_sizes: tuple,
        feature_resolution: tuple,
        window_sizes: tuple,
        attention_type: str,
    ):
        """Build the CLIP backbone and the Aggregator transformer.

        Args:
            train_class_json: path to a JSON list of training class names
                (loading is currently commented out below).
            test_class_json: path to a JSON list of test class names
                (loading is currently commented out below).
            clip_pretrained: CLIP variant. ``"ViT-G"``/``"ViT-H"`` select
                OpenCLIP LAION checkpoints; any other value is passed to the
                project's patched OpenAI ``clip.load``.
            prompt_ensemble_type: one of ``"imagenet_select"``,
                ``"imagenet"``, ``"single"`` — selects the prompt-template
                set used for text-embedding ensembling.
            text_guidance_dim: text guidance channels for the Aggregator.
            text_guidance_proj_dim: projected text guidance channels.
            appearance_guidance_dim: appearance guidance channels.
            appearance_guidance_proj_dim: projected appearance channels.
            prompt_depth: prompt-tuning depth (OpenAI CLIP loader only;
                not passed to OpenCLIP).
            prompt_length: prompt-tuning length (OpenAI CLIP loader only).
            decoder_dims: decoder channel sizes for the Aggregator.
            decoder_guidance_dims: decoder guidance channel sizes.
            decoder_guidance_proj_dims: projected decoder guidance sizes.
            num_heads: attention heads in the Aggregator.
            num_layers: per-stage layer counts for the Aggregator.
            hidden_dims: hidden sizes for the Aggregator.
            pooling_sizes: pooling sizes for the Aggregator.
            feature_resolution: spatial resolution of the input features.
            window_sizes: attention window sizes.
            attention_type: attention flavor used by the Aggregator.
        """
        super().__init__()
        import json
        # use class_texts in train_forward, and test_class_texts in test_forward
        #with open(train_class_json, 'r') as f_in:
        #    self.class_texts = json.load(f_in)
        #with open(test_class_json, 'r') as f_in:
        #    self.test_class_texts = json.load(f_in)
        #assert self.class_texts != None
        #if self.test_class_texts == None:
        #    self.test_class_texts = self.class_texts
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Stays None for OpenAI-CLIP variants; class_embeddings() then falls
        # back to clip.tokenize.
        self.tokenizer = None
        if clip_pretrained == "ViT-G" or clip_pretrained == "ViT-H":
            # for OpenCLIP models
            name, pretrain = ('ViT-H-14', 'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else ('ViT-bigG-14', 'laion2b_s39b_b160k')
            clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
                name,
                pretrained=pretrain,
                device=device,
                force_image_size=336,)
            self.tokenizer = open_clip.get_tokenizer(name)
        else:
            # for OpenAI models (project-patched loader with prompt tuning)
            clip_model, clip_preprocess = clip.load(clip_pretrained, device=device, jit=False, prompt_depth=prompt_depth, prompt_length=prompt_length)

        self.prompt_ensemble_type = prompt_ensemble_type

        if self.prompt_ensemble_type == "imagenet_select":
            prompt_templates = imagenet_templates.IMAGENET_TEMPLATES_SELECT
        elif self.prompt_ensemble_type == "imagenet":
            prompt_templates = imagenet_templates.IMAGENET_TEMPLATES
        elif self.prompt_ensemble_type == "single":
            prompt_templates = ['A photo of a {} in the scene',]
        else:
            raise NotImplementedError

        #self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
        #self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()

        # Cast CLIP weights to fp32 (loader may yield fp16 on CUDA).
        self.clip_model = clip_model.float()
        self.clip_preprocess = clip_preprocess

        transformer = Aggregator(
            text_guidance_dim=text_guidance_dim,
            text_guidance_proj_dim=text_guidance_proj_dim,
            appearance_guidance_dim=appearance_guidance_dim,
            appearance_guidance_proj_dim=appearance_guidance_proj_dim,
            decoder_dims=decoder_dims,
            decoder_guidance_dims=decoder_guidance_dims,
            decoder_guidance_proj_dims=decoder_guidance_proj_dims,
            num_layers=num_layers,
            nheads=num_heads,
            hidden_dim=hidden_dims,
            pooling_size=pooling_sizes,
            feature_resolution=feature_resolution,
            window_size=window_sizes,
            attention_type=attention_type
        )
        self.transformer = transformer

    @classmethod
    def from_config(cls, cfg):#, in_channels, mask_classification):
        """Translate a detectron2 config node into ``__init__`` kwargs.

        Returns:
            dict mapping every ``__init__`` keyword argument to its value
            under ``cfg.MODEL.SEM_SEG_HEAD`` (plus ``cfg.MODEL.PROMPT_ENSEMBLE_TYPE``).
        """
        ret = {}

        ret["train_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON
        ret["test_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON
        ret["clip_pretrained"] = cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED
        ret["prompt_ensemble_type"] = cfg.MODEL.PROMPT_ENSEMBLE_TYPE

        # Aggregator parameters:
        ret["text_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_DIM
        ret["text_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_PROJ_DIM
        ret["appearance_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_DIM
        ret["appearance_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_PROJ_DIM

        ret["decoder_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS
        ret["decoder_guidance_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_DIMS
        ret["decoder_guidance_proj_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_PROJ_DIMS

        ret["prompt_depth"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH
        ret["prompt_length"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH

        ret["num_layers"] = cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS
        ret["num_heads"] = cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS
        ret["hidden_dims"] = cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS
        ret["pooling_sizes"] = cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES
        ret["feature_resolution"] = cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION
        ret["window_sizes"] = cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES
        ret["attention_type"] = cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE

        return ret

    def forward(self, x, vis_affinity):
        """Fuse input features with class-text embeddings via the Aggregator.

        Args:
            x: input feature/cost tensor; the batch size is taken from
                ``x.shape[0]``.
            vis_affinity: dict of multi-scale guidance features; its values
                are consumed in reverse insertion order.

        Returns:
            Whatever ``self.transformer`` (the Aggregator) returns.
        """
        vis = [vis_affinity[k] for k in vis_affinity.keys()][::-1]
        # NOTE(review): text_features / text_features_test are not assigned in
        # __init__ (their computation is commented out) — they must be set
        # externally or this line raises AttributeError.
        text = self.text_features if self.training else self.text_features_test
        # Broadcast the class-text embeddings across the batch dimension.
        text = text.repeat(x.shape[0], 1, 1, 1)
        out = self.transformer(x, text, vis)
        return out

    @torch.no_grad()
    def class_embeddings(self, classnames, templates, clip_model):
        """Build L2-normalized CLIP text embeddings for each class name.

        Each class name is formatted into every prompt template. Names
        containing ``', '`` are treated as synonym lists: every fragment is
        formatted with every template, and the resulting embeddings are
        averaged back to one embedding per template, then re-normalized.

        Args:
            classnames: iterable of class-name strings (possibly
                comma-separated synonym lists).
            templates: prompt templates, each with one ``'{}'`` placeholder.
            clip_model: model exposing ``encode_text``.

        Returns:
            Tensor of per-class embeddings stacked along dim=1
            (presumably shaped (num_templates, num_classes, embed_dim) —
            confirm against callers).
        """
        zeroshot_weights = []
        for classname in classnames:
            if ', ' in classname:
                # Synonym list: cross every template with every fragment.
                classname_splits = classname.split(', ')
                texts = []
                for template in templates:
                    for cls_split in classname_splits:
                        texts.append(template.format(cls_split))
            else:
                texts = [template.format(classname) for template in templates]  # format with class
            # OpenCLIP models carry their own tokenizer; OpenAI CLIP falls
            # back to the project's clip.tokenize.
            if self.tokenizer is not None:
                texts = self.tokenizer(texts).to(self.device)
            else:
                texts = clip.tokenize(texts).to(self.device)
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            if len(templates) != class_embeddings.shape[0]:
                # Synonym case: average fragment embeddings per template,
                # then re-normalize so each row is unit length again.
                class_embeddings = class_embeddings.reshape(len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
                class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)
        return zeroshot_weights