# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
# Modified by Jian Ding from: https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d

from .model import Aggregator
from cat_seg.third_party import clip
from cat_seg.third_party import imagenet_templates

import numpy as np
import open_clip


class CATSegPredictor(nn.Module):
    @configurable
    def __init__(
        self,
        *,
        train_class_json: str,
        test_class_json: str,
        clip_pretrained: str,
        prompt_ensemble_type: str,
        text_guidance_dim: int,
        text_guidance_proj_dim: int,
        appearance_guidance_dim: int,
        appearance_guidance_proj_dim: int,
        prompt_depth: int,
        prompt_length: int,
        decoder_dims: list,
        decoder_guidance_dims: list,
        decoder_guidance_proj_dims: list,
        num_heads: int,
        num_layers: tuple,
        hidden_dims: tuple,
        pooling_sizes: tuple,
        feature_resolution: tuple,
        window_sizes: tuple,
        attention_type: str,
    ):
        """
        Args:
            train_class_json: path to a JSON list of training class names.
            test_class_json: path to a JSON list of test class names.
            clip_pretrained: CLIP variant; "ViT-G" / "ViT-H" load OpenCLIP
                checkpoints, any other value is passed to the bundled OpenAI
                CLIP loader.
            prompt_ensemble_type: "imagenet_select", "imagenet", or "single";
                selects the prompt-template set used for the text embeddings.
            prompt_depth, prompt_length: prompt-tuning parameters forwarded to
                the OpenAI CLIP loader.
            The remaining arguments configure the Aggregator (guidance and
            decoder dimensions, attention layout, window sizes, etc.).
        """
        super().__init__()
        import json
        # use class_texts in train_forward, and test_class_texts in test_forward
        #with open(train_class_json, 'r') as f_in:
        #    self.class_texts = json.load(f_in)
        #with open(test_class_json, 'r') as f_in:
        #    self.test_class_texts = json.load(f_in)
        #assert self.class_texts != None
        #if self.test_class_texts == None:
        #    self.test_class_texts = self.class_texts

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        self.tokenizer = None
        if clip_pretrained == "ViT-G" or clip_pretrained == "ViT-H":
            # for OpenCLIP models
            name, pretrain = ('ViT-H-14', 'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else ('ViT-bigG-14', 'laion2b_s39b_b160k')
            clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
                name,
                pretrained=pretrain,
                device=device,
                force_image_size=336,
            )
            self.tokenizer = open_clip.get_tokenizer(name)
        else:
            # for OpenAI models
            clip_model, clip_preprocess = clip.load(clip_pretrained, device=device, jit=False, prompt_depth=prompt_depth, prompt_length=prompt_length)

        self.prompt_ensemble_type = prompt_ensemble_type
        if self.prompt_ensemble_type == "imagenet_select":
            prompt_templates = imagenet_templates.IMAGENET_TEMPLATES_SELECT
        elif self.prompt_ensemble_type == "imagenet":
            prompt_templates = imagenet_templates.IMAGENET_TEMPLATES
        elif self.prompt_ensemble_type == "single":
            prompt_templates = ['A photo of a {} in the scene',]
        else:
            raise NotImplementedError

        #self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
        #self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()

        self.clip_model = clip_model.float()
        self.clip_preprocess = clip_preprocess

        transformer = Aggregator(
            text_guidance_dim=text_guidance_dim,
            text_guidance_proj_dim=text_guidance_proj_dim,
            appearance_guidance_dim=appearance_guidance_dim,
            appearance_guidance_proj_dim=appearance_guidance_proj_dim,
            decoder_dims=decoder_dims,
            decoder_guidance_dims=decoder_guidance_dims,
            decoder_guidance_proj_dims=decoder_guidance_proj_dims,
            num_layers=num_layers,
            nheads=num_heads,
            hidden_dim=hidden_dims,
            pooling_size=pooling_sizes,
            feature_resolution=feature_resolution,
            window_size=window_sizes,
            attention_type=attention_type
        )
        self.transformer = transformer
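
    # Maps detectron2 config keys to the keyword arguments of __init__ above;
    # detectron2's @configurable machinery calls this when the predictor is
    # built directly from a cfg.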
    @classmethod
    def from_config(cls, cfg):
        ret = {}

        ret["train_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON
        ret["test_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON
        ret["clip_pretrained"] = cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED
        ret["prompt_ensemble_type"] = cfg.MODEL.PROMPT_ENSEMBLE_TYPE

        # Aggregator parameters:
        ret["text_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_DIM
        ret["text_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_PROJ_DIM
        ret["appearance_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_DIM
        ret["appearance_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_PROJ_DIM

        ret["decoder_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS
        ret["decoder_guidance_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_DIMS
        ret["decoder_guidance_proj_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_PROJ_DIMS

        ret["prompt_depth"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH
        ret["prompt_length"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH

        ret["num_layers"] = cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS
        ret["num_heads"] = cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS
        ret["hidden_dims"] = cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS
        ret["pooling_sizes"] = cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES
        ret["feature_resolution"] = cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION
        ret["window_sizes"] = cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES
        ret["attention_type"] = cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE

        return ret
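
    # NOTE: self.text_features / self.text_features_test are not assigned in this
    # file because the JSON-based initialization in __init__ is commented out;
    # forward() assumes the caller has attached them beforehand, shaped
    # (num_classes, num_templates, embed_dim) as in the commented-out lines above.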
    def forward(self, x, vis_affinity):
        # Reverse the order of the multi-scale appearance guidance features.
        vis = [vis_affinity[k] for k in vis_affinity.keys()][::-1]
        # Select train/test text embeddings and broadcast them over the batch.
        text = self.text_features if self.training else self.text_features_test
        text = text.repeat(x.shape[0], 1, 1, 1)
        out = self.transformer(x, text, vis)
        return out

    def class_embeddings(self, classnames, templates, clip_model):
        zeroshot_weights = []
        for classname in classnames:
            if ', ' in classname:
                # Class names given as comma-separated synonyms: prompt each synonym.
                classname_splits = classname.split(', ')
                texts = []
                for template in templates:
                    for cls_split in classname_splits:
                        texts.append(template.format(cls_split))
            else:
                texts = [template.format(classname) for template in templates]  # format with class
            if self.tokenizer is not None:
                texts = self.tokenizer(texts).to(self.device)
            else:
                texts = clip.tokenize(texts).to(self.device)
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            if len(templates) != class_embeddings.shape[0]:
                # Average the synonym embeddings per template, then re-normalize.
                class_embeddings = class_embeddings.reshape(len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
                class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)
        return zeroshot_weights
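

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream CAT-Seg code).
# It assumes a detectron2 cfg carrying the MODEL.SEM_SEG_HEAD keys read in
# from_config, plus caller-provided class names, a cost/correlation feature
# tensor, and a dict of multi-scale appearance guidance features; `cfg`,
# `class_names`, `corr_features`, and `guidance` are placeholder names.
def _example_usage(cfg, class_names, corr_features, guidance):
    predictor = CATSegPredictor(**CATSegPredictor.from_config(cfg))

    # Pre-compute CLIP text embeddings and attach them, mirroring the
    # commented-out initialization in __init__: (T, N_cls, C) -> (N_cls, T, C).
    templates = imagenet_templates.IMAGENET_TEMPLATES_SELECT
    text_feats = predictor.class_embeddings(class_names, templates, predictor.clip_model)
    predictor.text_features = text_feats.permute(1, 0, 2).float()
    predictor.text_features_test = predictor.text_features

    # The Aggregator consumes the cost features, the per-class text embeddings
    # (broadcast over the batch inside forward), and the reversed guidance list.
    return predictor(corr_features, guidance)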