diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..1a66a3a9a1405f07aea398f222adbf6f449d3146 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,36 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..94a25f7f4cb416c083d265558da75d457237d671 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..43184f07307c95e7d3fd796e1372c4882026899c --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +--- +title: Ov Seg +emoji: 📊 +colorFrom: red +colorTo: pink +sdk: gradio +sdk_version: 3.8.2 +app_file: app.py +pinned: false +license: cc-by-nc-4.0 +duplicated_from: facebook/ov-seg +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..cd9b5260e98a6bbdce7c0dcdb36bd3780587b4d2 --- /dev/null +++ b/app.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved
+
+import multiprocessing as mp
+
+import numpy as np
+from PIL import Image
+
+
+try:
+    import detectron2
+except:
+    import os
+    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
+
+from detectron2.config import get_cfg
+
+from detectron2.projects.deeplab import add_deeplab_config
+from detectron2.data.detection_utils import read_image
+from open_vocab_seg import add_ovseg_config
+from open_vocab_seg.utils import VisualizationDemo, SAMVisualizationDemo
+
+import gradio as gr
+
+import gdown
+
+# ckpt_url = 'https://drive.google.com/uc?id=1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy'
+# output = './ovseg_swinbase_vitL14_ft_mpt.pth'
+# gdown.download(ckpt_url, output, quiet=False)
+
+def setup_cfg(config_file):
+    # load config from file and command-line arguments
+    cfg = get_cfg()
+    add_deeplab_config(cfg)
+    add_ovseg_config(cfg)
+    cfg.merge_from_file(config_file)
+    cfg.freeze()
+    return cfg
+
+
+def inference(class_names, proposal_gen, granularity, input_img):
+    mp.set_start_method("spawn", force=True)
+    config_file = './ovseg_swinB_vitL_demo.yaml'
+    cfg = setup_cfg(config_file)
+    if proposal_gen == 'MaskFormer':
+        demo = VisualizationDemo(cfg)
+    elif proposal_gen == 'Segment_Anything':
+        demo = SAMVisualizationDemo(cfg, granularity, './sam_vit_l_0b3195.pth', './ovseg_clip_l_9a1909.pth')
+    class_names = class_names.split(',')
+    img = read_image(input_img, format="BGR")
+    _, visualized_output = demo.run_on_image(img, class_names)
+
+    return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB')
+
+
+examples = [['Saturn V, toys, desk, wall, sunflowers, white roses, chrysanthemums, carnations, green dianthus', 'Segment_Anything', 0.8, './resources/demo_samples/sample_01.jpeg'],
+            ['red bench, yellow bench, blue bench, brown bench, green bench, blue chair, yellow chair, green chair, brown chair, yellow square painting, barrel, buddha statue', 'Segment_Anything', 0.8, './resources/demo_samples/sample_04.png'],
+            ['pillow, pipe, sweater, shirt, jeans jacket, shoes, cabinet, handbag, photo frame', 'Segment_Anything', 0.8, './resources/demo_samples/sample_05.png'],
+            ['Saturn V, toys, blossom', 'MaskFormer', 1.0, './resources/demo_samples/sample_01.jpeg'],
+            ['Oculus, Ukulele', 'MaskFormer', 1.0, './resources/demo_samples/sample_03.jpeg'],
+            ['Golden gate, yacht', 'MaskFormer', 1.0, './resources/demo_samples/sample_02.jpeg'],]
+output_labels = ['segmentation map']
+
+title = 'OVSeg (+ Segment_Anything)'
+
+description = """
+[NEW!] We incorporate OVSeg CLIP with Segment_Anything, enabling text prompts for SAM.
+Gradio Demo for Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP. \n
+OVSeg can perform open-vocabulary segmentation; you may input more classes (separated by commas). You may click one of the examples or upload your own image. \n
+It might take some time to process. Cheers!
+

(Colab only supports the MaskFormer proposal generator) Don't want to wait in the queue? Open In Colab

+""" + +article = """ +

+
+Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP
+
+|
+GitHub Repo

+""" + +gr.Interface( + inference, + inputs=[ + gr.Textbox( + lines=1, placeholder=None, default='', label='class names'), + gr.Radio(["Segment_Anything", "MaskFormer"], label="Proposal generator", default="Segment_Anything"), + gr.Slider(0, 1.0, 0.8, label="For Segment_Anything only, granularity of masks from 0 (most coarse) to 1 (most precise)"), + gr.Image(type='filepath'), + ], + outputs=gr.outputs.Image(label='segmentation map'), + title=title, + description=description, + article=article, + examples=examples).launch(enable_queue=True) diff --git a/open_vocab_seg/.DS_Store b/open_vocab_seg/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..571ecdfd99ba3951eec8fd0206519a409e86c2b1 Binary files /dev/null and b/open_vocab_seg/.DS_Store differ diff --git a/open_vocab_seg/__init__.py b/open_vocab_seg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b54fce14b8a029f1355bc8b74c20884e880ee9c4 --- /dev/null +++ b/open_vocab_seg/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from . import data +from . import modeling +from .config import add_ovseg_config + +from .test_time_augmentation import SemanticSegmentorWithTTA +from .ovseg_model import OVSeg, OVSegDEMO diff --git a/open_vocab_seg/config.py b/open_vocab_seg/config.py new file mode 100644 index 0000000000000000000000000000000000000000..400e9a05d4995e3f3401b34a22ae687b2c9c90e0 --- /dev/null +++ b/open_vocab_seg/config.py @@ -0,0 +1,133 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from detectron2.config import CfgNode as CN + + +def add_mask_former_default_config(cfg): + # data config + # select the dataset mapper + cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic" + # Color augmentation + cfg.INPUT.COLOR_AUG_SSD = False + # We retry random cropping until no single category in semantic segmentation GT occupies more + # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. + cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 + # Pad image and segmentation GT in dataset mapper. + cfg.INPUT.SIZE_DIVISIBILITY = -1 + + # solver config + # test batch size + cfg.SOLVER.TEST_IMS_PER_BATCH = 1 + # weight decay on embedding + cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0 + # optimizer + cfg.SOLVER.OPTIMIZER = "ADAMW" + cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 + + # mask_former model config + cfg.MODEL.MASK_FORMER = CN() + + # loss + cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True + cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1 + cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0 + cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0 + + # transformer config + cfg.MODEL.MASK_FORMER.NHEADS = 8 + cfg.MODEL.MASK_FORMER.DROPOUT = 0.1 + cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 + cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0 + cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6 + cfg.MODEL.MASK_FORMER.PRE_NORM = False + + cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 + cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100 + + cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5" + cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False + + # mask_former inference config + cfg.MODEL.MASK_FORMER.TEST = CN() + cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False + cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0 + cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0 + cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False + + # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. 
ResNet) + # you can use this config to override + cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32 + + # pixel decoder config + cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 + # adding transformer in pixel decoder + cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0 + # pixel decoder + cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder" + + # swin transformer backbone + cfg.MODEL.SWIN = CN() + cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224 + cfg.MODEL.SWIN.PATCH_SIZE = 4 + cfg.MODEL.SWIN.EMBED_DIM = 96 + cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] + cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] + cfg.MODEL.SWIN.WINDOW_SIZE = 7 + cfg.MODEL.SWIN.MLP_RATIO = 4.0 + cfg.MODEL.SWIN.QKV_BIAS = True + cfg.MODEL.SWIN.QK_SCALE = None + cfg.MODEL.SWIN.NORM_INDICES = None + cfg.MODEL.SWIN.PROJECTION = False + cfg.MODEL.SWIN.PROJECT_DIM = 256 + cfg.MODEL.SWIN.DROP_RATE = 0.0 + cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0 + cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3 + cfg.MODEL.SWIN.APE = False + cfg.MODEL.SWIN.PATCH_NORM = True + cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"] + + +def add_our_config(cfg): + cfg.TEST.SLIDING_WINDOW = False + cfg.TEST.SLIDING_TILE_SIZE = 224 + cfg.TEST.SLIDING_OVERLAP = 2 / 3.0 + # whether to use dense crf + cfg.TEST.DENSE_CRF = False + cfg.DATASETS.SAMPLE_PER_CLASS = -1 + cfg.DATASETS.SAMPLE_SEED = 0 + # embedding head + cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM = 512 + cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM = 1024 + cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS = 2 + # clip_adapter + cfg.MODEL.CLIP_ADAPTER = CN() + cfg.MODEL.CLIP_ADAPTER.TEXT_TEMPLATES = "vild" + # for predefined + cfg.MODEL.CLIP_ADAPTER.PREDEFINED_PROMPT_TEMPLATES = ["a photo of a {}."] + # for learnable prompt + cfg.MODEL.CLIP_ADAPTER.PROMPT_CHECKPOINT = "" + cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME = "ViT-B/16" + cfg.MODEL.CLIP_ADAPTER.MASK_FILL = "mean" + cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO = 1.0 + cfg.MODEL.CLIP_ADAPTER.MASK_THR = 0.4 + cfg.MODEL.CLIP_ADAPTER.MASK_MATTING = False + cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED = True + cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE = True + cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT = 0.7 + # for mask prompt + cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH = 3 + cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD = False + + # wandb + cfg.WANDB = CN() + cfg.WANDB.PROJECT = "open_vocab_seg" + cfg.WANDB.NAME = None + + +def add_ovseg_config(cfg): + """ + Add config for open_vocab_seg. + """ + add_mask_former_default_config(cfg) + add_our_config(cfg) diff --git a/open_vocab_seg/data/.DS_Store b/open_vocab_seg/data/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e3548df3784b4020dc5b9d6383241cbe099cb0df Binary files /dev/null and b/open_vocab_seg/data/.DS_Store differ diff --git a/open_vocab_seg/data/__init__.py b/open_vocab_seg/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..970e2c8ce7f90afab089bf84e249af5ee7124951 --- /dev/null +++ b/open_vocab_seg/data/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from .dataset_mappers import * +from . import datasets +from .build import ( + build_detection_train_loader, + build_detection_test_loader, +) diff --git a/open_vocab_seg/data/augmentations.py b/open_vocab_seg/data/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..44e4906d4827812fa707f50e703f253a64ab6e43 --- /dev/null +++ b/open_vocab_seg/data/augmentations.py @@ -0,0 +1,202 @@ +# Copyright (c) Facebook, Inc. 
and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import math +import numbers +import numpy as np +from detectron2.data.transforms.augmentation import Augmentation +from detectron2.data.transforms.transform import ( + CropTransform, + ResizeTransform, + TransformList, +) +from PIL import Image +from fvcore.transforms.transform import PadTransform + + +def mask2box(mask: np.ndarray): + # use naive way + row = np.nonzero(mask.sum(axis=0))[0] + if len(row) == 0: + return None + x1 = row.min() + x2 = row.max() + col = np.nonzero(mask.sum(axis=1))[0] + y1 = col.min() + y2 = col.max() + return x1, y1, x2 + 1 - x1, y2 + 1 - y1 + + +def expand_box(x, y, w, h, expand_ratio=1.0, max_h=None, max_w=None): + cx = x + 0.5 * w + cy = y + 0.5 * h + w = w * expand_ratio + h = h * expand_ratio + box = [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h] + if max_h is not None: + box[1] = max(0, box[1]) + box[3] = min(max_h - 1, box[3]) + if max_w is not None: + box[0] = max(0, box[0]) + box[2] = min(max_w - 1, box[2]) + box[2] = box[2] - box[0] + box[3] = box[3] - box[1] + + return [int(b) for b in box] + + +class CropImageWithMask(Augmentation): + def __init__(self, expand_ratio=1.0, mode="choice"): + if isinstance(expand_ratio, numbers.Number): + expand_ratio = (expand_ratio, expand_ratio) + self.mode = mode + self.expand_ratio = expand_ratio + if self.mode == "range": + assert len(expand_ratio) == 2 and expand_ratio[0] < expand_ratio[1] + + def get_transform(self, image, sem_seg, category_id): + input_size = image.shape[:2] + bin_mask = sem_seg == category_id + x, y, w, h = mask2box(bin_mask) + if self.mode == "choice": + expand_ratio = np.random.choice(self.expand_ratio) + else: + expand_ratio = np.random.uniform(self.expand_ratio[0], self.expand_ratio[1]) + x, y, w, h = expand_box(x, y, w, h, expand_ratio, *input_size) + w = max(w, 1) + h = max(h, 1) + return CropTransform(x, y, w, h, input_size[1], input_size[0]) + + +class CropImageWithBox(Augmentation): + def __init__(self, expand_ratio=1.0, mode="choice"): + if isinstance(expand_ratio, numbers.Number): + expand_ratio = (expand_ratio, expand_ratio) + self.mode = mode + self.expand_ratio = expand_ratio + if self.mode == "range": + assert len(expand_ratio) == 2 and expand_ratio[0] < expand_ratio[1] + + def get_transform(self, image, boxes): + input_size = image.shape[:2] + x, y, x2, y2 = boxes[0] + w = x2 - x + 1 + h = y2 - y + 1 + if self.mode == "choice": + expand_ratio = np.random.choice(self.expand_ratio) + else: + expand_ratio = np.random.uniform(self.expand_ratio[0], self.expand_ratio[1]) + x, y, w, h = expand_box(x, y, w, h, expand_ratio, *input_size) + w = max(w, 1) + h = max(h, 1) + return CropTransform(x, y, w, h, input_size[1], input_size[0]) + + +class RandomResizedCrop(Augmentation): + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation=Image.BILINEAR, + ): + if isinstance(size, int): + size = (size, size) + else: + assert isinstance(size, (tuple, list)) and len(size) == 2 + + self.size = size + + self.scale = scale + self.ratio = ratio + self.interpolation = interpolation + + def get_transform(self, image): + height, width = image.shape[:2] + area = height * width + + log_ratio = np.log(np.array(self.ratio)) + is_success = False + for _ in range(10): + target_area = area * np.random.uniform(self.scale[0], self.scale[1]) + aspect_ratio = np.exp(np.random.uniform(log_ratio[0], log_ratio[1])) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = 
int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = np.random.randint(0, width - w + 1) + j = np.random.randint(0, height - h + 1) + + is_success = True + break + + if not is_success: + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (width - w) // 2 + j = (height - h) // 2 + return TransformList( + [ + CropTransform(i, j, w, h, width, height), + ResizeTransform( + h, w, self.size[1], self.size[0], interp=self.interpolation + ), + ] + ) + + +class CenterCrop(Augmentation): + def __init__(self, size, seg_ignore_label): + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + elif isinstance(size, (tuple, list)) and len(size) == 1: + size = (size[0], size[0]) + self.size = size + self.seg_ignore_label = seg_ignore_label + + def get_transform(self, image): + + image_height, image_width = image.shape[:2] + crop_height, crop_width = self.size + + transforms = [] + if crop_width > image_width or crop_height > image_height: + padding_ltrb = [ + (crop_width - image_width) // 2 if crop_width > image_width else 0, + (crop_height - image_height) // 2 if crop_height > image_height else 0, + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, + (crop_height - image_height + 1) // 2 + if crop_height > image_height + else 0, + ] + transforms.append( + PadTransform( + *padding_ltrb, + orig_w=image_width, + orig_h=image_height, + seg_pad_value=self.seg_ignore_label + ) + ) + image_width, image_height = ( + image_width + padding_ltrb[0] + padding_ltrb[2], + image_height + padding_ltrb[1] + padding_ltrb[3], + ) + + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + transforms.append( + CropTransform( + crop_left, crop_top, crop_width, crop_height, image_width, image_height + ) + ) + return TransformList(transforms) diff --git a/open_vocab_seg/data/build.py b/open_vocab_seg/data/build.py new file mode 100644 index 0000000000000000000000000000000000000000..bcd3b9dcebb86c319b91a632c25bcf7827292c3f --- /dev/null +++ b/open_vocab_seg/data/build.py @@ -0,0 +1,344 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import itertools +import logging +import numpy as np +from collections import Counter +import torch.utils.data +from tabulate import tabulate +from termcolor import colored + +from detectron2.utils.logger import _log_api_usage, log_first_n +from detectron2.data.catalog import DatasetCatalog, MetadataCatalog +import torch.utils.data +from detectron2.config import configurable +from detectron2.data.build import ( + build_batch_data_loader, + trivial_batch_collator, + load_proposals_into_dataset, + filter_images_with_only_crowd_annotations, + filter_images_with_few_keypoints, + print_instances_class_histogram, +) + +from detectron2.data.common import DatasetFromList, MapDataset +from detectron2.data.dataset_mapper import DatasetMapper +from detectron2.data.detection_utils import check_metadata_consistency +from detectron2.data.samplers import ( + InferenceSampler, + RandomSubsetTrainingSampler, + RepeatFactorTrainingSampler, + TrainingSampler, +) + +""" +This file contains the default logic to build a dataloader for training or testing. 
+""" + +__all__ = [ + "build_detection_train_loader", + "build_detection_test_loader", +] + + +def print_classification_instances_class_histogram(dataset_dicts, class_names): + """ + Args: + dataset_dicts (list[dict]): list of dataset dicts. + class_names (list[str]): list of class names (zero-indexed). + """ + num_classes = len(class_names) + hist_bins = np.arange(num_classes + 1) + histogram = np.zeros((num_classes,), dtype=np.int) + for entry in dataset_dicts: + classes = np.asarray([entry["category_id"]], dtype=np.int) + if len(classes): + assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}" + assert ( + classes.max() < num_classes + ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes" + histogram += np.histogram(classes, bins=hist_bins)[0] + + N_COLS = min(6, len(class_names) * 2) + + def short_name(x): + # make long class names shorter. useful for lvis + if len(x) > 13: + return x[:11] + ".." + return x + + data = list( + itertools.chain( + *[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)] + ) + ) + total_num_instances = sum(data[1::2]) + data.extend([None] * (N_COLS - (len(data) % N_COLS))) + if num_classes > 1: + data.extend(["total", total_num_instances]) + data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)]) + table = tabulate( + data, + headers=["category", "#instances"] * (N_COLS // 2), + tablefmt="pipe", + numalign="left", + stralign="center", + ) + log_first_n( + logging.INFO, + "Distribution of instances among all {} categories:\n".format(num_classes) + + colored(table, "cyan"), + key="message", + ) + + +def wrap_metas(dataset_dict, **kwargs): + def _assign_attr(data_dict: dict, **kwargs): + assert not any( + [key in data_dict for key in kwargs] + ), "Assigned attributes should not exist in the original sample." + data_dict.update(kwargs) + return data_dict + + return [_assign_attr(sample, meta=kwargs) for sample in dataset_dict] + + +def get_detection_dataset_dicts( + names, filter_empty=True, min_keypoints=0, proposal_files=None +): + """ + Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation. + + Args: + names (str or list[str]): a dataset name or a list of dataset names + filter_empty (bool): whether to filter out images without instance annotations + min_keypoints (int): filter out images with fewer keypoints than + `min_keypoints`. Set to 0 to do nothing. + proposal_files (list[str]): if given, a list of object proposal files + that match each dataset in `names`. + + Returns: + list[dict]: a list of dicts following the standard dataset dict format. 
+ """ + if isinstance(names, str): + names = [names] + assert len(names), names + dataset_dicts = [ + wrap_metas(DatasetCatalog.get(dataset_name), dataset_name=dataset_name) + for dataset_name in names + ] + for dataset_name, dicts in zip(names, dataset_dicts): + assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + + if proposal_files is not None: + assert len(names) == len(proposal_files) + # load precomputed proposals from proposal files + dataset_dicts = [ + load_proposals_into_dataset(dataset_i_dicts, proposal_file) + for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files) + ] + + dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) + + has_instances = "annotations" in dataset_dicts[0] + if filter_empty and has_instances: + dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts) + if min_keypoints > 0 and has_instances: + dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints) + + if has_instances: + try: + class_names = MetadataCatalog.get(names[0]).thing_classes + check_metadata_consistency("thing_classes", names) + print_instances_class_histogram(dataset_dicts, class_names) + except AttributeError: # class names are not available for this dataset + pass + + assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names)) + return dataset_dicts + + +def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None): + if dataset is None: + dataset = get_detection_dataset_dicts( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON + else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN + if cfg.MODEL.LOAD_PROPOSALS + else None, + ) + _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0]) + + if mapper is None: + mapper = DatasetMapper(cfg, True) + + if sampler is None: + sampler_name = cfg.DATALOADER.SAMPLER_TRAIN + logger = logging.getLogger(__name__) + logger.info("Using training sampler {}".format(sampler_name)) + if sampler_name == "TrainingSampler": + sampler = TrainingSampler(len(dataset)) + elif sampler_name == "RepeatFactorTrainingSampler": + repeat_factors = ( + RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( + dataset, cfg.DATALOADER.REPEAT_THRESHOLD + ) + ) + sampler = RepeatFactorTrainingSampler(repeat_factors) + elif sampler_name == "RandomSubsetTrainingSampler": + sampler = RandomSubsetTrainingSampler( + len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO + ) + else: + raise ValueError("Unknown training sampler: {}".format(sampler_name)) + + return { + "dataset": dataset, + "sampler": sampler, + "mapper": mapper, + "total_batch_size": cfg.SOLVER.IMS_PER_BATCH, + "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING, + "num_workers": cfg.DATALOADER.NUM_WORKERS, + } + + +# TODO can allow dataset as an iterable or IterableDataset to make this function more general +@configurable(from_config=_train_loader_from_config) +def build_detection_train_loader( + dataset, + *, + mapper, + sampler=None, + total_batch_size, + aspect_ratio_grouping=True, + num_workers=0, +): + """ + Build a dataloader for object detection with some default features. + This interface is experimental. + + Args: + dataset (list or torch.utils.data.Dataset): a list of dataset dicts, + or a map-style pytorch dataset. They can be obtained by using + :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. 
+ mapper (callable): a callable which takes a sample (dict) from dataset and + returns the format to be consumed by the model. + When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``. + sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces + indices to be applied on ``dataset``. Default to :class:`TrainingSampler`, + which coordinates an infinite random shuffle sequence across all workers. + total_batch_size (int): total batch size across all workers. Batching + simply puts data into a list. + aspect_ratio_grouping (bool): whether to group images with similar + aspect ratio for efficiency. When enabled, it requires each + element in dataset be a dict with keys "width" and "height". + num_workers (int): number of parallel data loading workers + + Returns: + torch.utils.data.DataLoader: + a dataloader. Each output from it is a ``list[mapped_element]`` of length + ``total_batch_size / num_workers``, where ``mapped_element`` is produced + by the ``mapper``. + """ + if isinstance(dataset, list): + dataset = DatasetFromList(dataset, copy=False) + if mapper is not None: + dataset = MapDataset(dataset, mapper) + if sampler is None: + sampler = TrainingSampler(len(dataset)) + assert isinstance(sampler, torch.utils.data.sampler.Sampler) + return build_batch_data_loader( + dataset, + sampler, + total_batch_size, + aspect_ratio_grouping=aspect_ratio_grouping, + num_workers=num_workers, + ) + + +def _test_loader_from_config(cfg, dataset_name, mapper=None): + """ + Uses the given `dataset_name` argument (instead of the names in cfg), because the + standard practice is to evaluate each test set individually (not combining them). + """ + if isinstance(dataset_name, str): + dataset_name = [dataset_name] + + dataset = get_detection_dataset_dicts( + dataset_name, + filter_empty=False, + proposal_files=[ + cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] + for x in dataset_name + ] + if cfg.MODEL.LOAD_PROPOSALS + else None, + ) + if mapper is None: + mapper = DatasetMapper(cfg, False) + return { + "dataset": dataset, + "mapper": mapper, + "num_workers": 0, + "samples_per_gpu": cfg.SOLVER.TEST_IMS_PER_BATCH, + } + + +@configurable(from_config=_test_loader_from_config) +def build_detection_test_loader( + dataset, *, mapper, sampler=None, num_workers=0, samples_per_gpu=1 +): + """ + Similar to `build_detection_train_loader`, but uses a batch size of 1, + and :class:`InferenceSampler`. This sampler coordinates all workers to + produce the exact set of all samples. + This interface is experimental. + + Args: + dataset (list or torch.utils.data.Dataset): a list of dataset dicts, + or a map-style pytorch dataset. They can be obtained by using + :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. + mapper (callable): a callable which takes a sample (dict) from dataset + and returns the format to be consumed by the model. + When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. + sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces + indices to be applied on ``dataset``. Default to :class:`InferenceSampler`, + which splits the dataset across all workers. + num_workers (int): number of parallel data loading workers + + Returns: + DataLoader: a torch DataLoader, that loads the given detection + dataset, with test-time transformation and batching. 
+ + Examples: + :: + data_loader = build_detection_test_loader( + DatasetRegistry.get("my_test"), + mapper=DatasetMapper(...)) + + # or, instantiate with a CfgNode: + data_loader = build_detection_test_loader(cfg, "my_test") + """ + if isinstance(dataset, list): + dataset = DatasetFromList(dataset, copy=False) + if mapper is not None: + dataset = MapDataset(dataset, mapper) + if sampler is None: + sampler = InferenceSampler(len(dataset)) + # Always use 1 image per worker during inference since this is the + # standard when reporting inference time in papers. + batch_sampler = torch.utils.data.sampler.BatchSampler( + sampler, samples_per_gpu, drop_last=False + ) + data_loader = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_sampler=batch_sampler, + collate_fn=trivial_batch_collator, + ) + return data_loader + diff --git a/open_vocab_seg/data/dataset_mappers/__init__.py b/open_vocab_seg/data/dataset_mappers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f63cd5c034fcb60af8c78431205ae9b410f33250 --- /dev/null +++ b/open_vocab_seg/data/dataset_mappers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper diff --git a/open_vocab_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py b/open_vocab_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..2836579942cf91c726cb34cbbd2d137c975bee37 --- /dev/null +++ b/open_vocab_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py @@ -0,0 +1,208 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import copy +import logging + +import numpy as np +import torch +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.data import MetadataCatalog +from detectron2.data import detection_utils as utils +from detectron2.data import transforms as T +from detectron2.projects.point_rend import ColorAugSSDTransform +from detectron2.structures import BitMasks, Instances + +__all__ = ["MaskFormerSemanticDatasetMapper"] + + +class MaskFormerSemanticDatasetMapper: + """ + A callable which takes a dataset dict in Detectron2 Dataset format, + and map it into a format used by MaskFormer for semantic segmentation. + + The callable currently does the following: + + 1. Read the image from "file_name" + 2. Applies geometric transforms to the image and annotation + 3. Find and applies suitable cropping to the image and annotation + 4. Prepare image and annotation to Tensors + """ + + @configurable + def __init__( + self, + is_train=True, + *, + augmentations, + image_format, + ignore_label, + size_divisibility, + ): + """ + NOTE: this interface is experimental. + Args: + is_train: for training or inference + augmentations: a list of augmentations or deterministic transforms to apply + image_format: an image format supported by :func:`detection_utils.read_image`. 
+ ignore_label: the label that is ignored to evaluation + size_divisibility: pad image size to be divisible by this value + """ + self.is_train = is_train + self.tfm_gens = augmentations + self.img_format = image_format + self.ignore_label = ignore_label + self.size_divisibility = size_divisibility + + logger = logging.getLogger(__name__) + mode = "training" if is_train else "inference" + logger.info( + f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}" + ) + + @classmethod + def from_config(cls, cfg, is_train=True): + # Build augmentation + if is_train: + augs = [ + T.ResizeShortestEdge( + cfg.INPUT.MIN_SIZE_TRAIN, + cfg.INPUT.MAX_SIZE_TRAIN, + cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING, + ) + ] + if cfg.INPUT.CROP.ENABLED: + augs.append( + T.RandomCrop_CategoryAreaConstraint( + cfg.INPUT.CROP.TYPE, + cfg.INPUT.CROP.SIZE, + cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA, + cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + ) + ) + if cfg.INPUT.COLOR_AUG_SSD: + augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT)) + augs.append(T.RandomFlip()) + + # Assume always applies to the training set. + dataset_names = cfg.DATASETS.TRAIN + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + sample_style = "choice" + augs = [T.ResizeShortestEdge(min_size, max_size, sample_style)] + dataset_names = cfg.DATASETS.TEST + meta = MetadataCatalog.get(dataset_names[0]) + ignore_label = meta.ignore_label + + ret = { + "is_train": is_train, + "augmentations": augs, + "image_format": cfg.INPUT.FORMAT, + "ignore_label": ignore_label, + "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY if is_train else -1, + } + return ret + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + # assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!" + + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + image = utils.read_image(dataset_dict["file_name"], format=self.img_format) + utils.check_image_size(dataset_dict, image) + + if "sem_seg_file_name" in dataset_dict: + # PyTorch transformation not implemented for uint16, so converting it to double first + sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype( + "double" + ) + else: + sem_seg_gt = None + + if sem_seg_gt is None: + raise ValueError( + "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format( + dataset_dict["file_name"] + ) + ) + + aug_input = T.AugInput(image, sem_seg=sem_seg_gt) + aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input) + image = aug_input.image + sem_seg_gt = aug_input.sem_seg + + # Pad image and segmentation label here! 
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + if sem_seg_gt is not None: + sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) + + if self.size_divisibility > 0: + image_size = (image.shape[-2], image.shape[-1]) + padding_size = [ + 0, + self.size_divisibility - image_size[1], + 0, + self.size_divisibility - image_size[0], + ] + image = F.pad(image, padding_size, value=128).contiguous() + if sem_seg_gt is not None: + sem_seg_gt = F.pad( + sem_seg_gt, padding_size, value=self.ignore_label + ).contiguous() + + image_shape = (image.shape[-2], image.shape[-1]) # h, w + + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. + dataset_dict["image"] = image + + if sem_seg_gt is not None: + dataset_dict["sem_seg"] = sem_seg_gt.long() + + if "annotations" in dataset_dict: + raise ValueError( + "Semantic segmentation dataset should not have 'annotations'." + ) + + # Prepare per-category binary masks + if sem_seg_gt is not None: + sem_seg_gt = sem_seg_gt.numpy() + instances = Instances(image_shape) + classes = np.unique(sem_seg_gt) + # remove ignored region + classes = classes[classes != self.ignore_label] + instances.gt_classes = torch.tensor(classes, dtype=torch.int64) + + masks = [] + for class_id in classes: + masks.append(sem_seg_gt == class_id) + + if len(masks) == 0: + # Some image does not have annotation (all ignored) + instances.gt_masks = torch.zeros( + (0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]) + ) + else: + masks = BitMasks( + torch.stack( + [ + torch.from_numpy(np.ascontiguousarray(x.copy())) + for x in masks + ] + ) + ) + instances.gt_masks = masks.tensor + + dataset_dict["instances"] = instances + + return dataset_dict diff --git a/open_vocab_seg/data/datasets/__init__.py b/open_vocab_seg/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..352792b6fcdbffefa229d5d67a5c7375769fa345 --- /dev/null +++ b/open_vocab_seg/data/datasets/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from . import register_coco_stuff, register_voc_seg +from . import register_cc3m +from . import register_ade20k_full +from . import register_pascal_context \ No newline at end of file diff --git a/open_vocab_seg/data/datasets/csv_data.py b/open_vocab_seg/data/datasets/csv_data.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4c9e52b0b792d49c48fe8bc2693be5ea879581 --- /dev/null +++ b/open_vocab_seg/data/datasets/csv_data.py @@ -0,0 +1,459 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+import ast +import json +import logging +import math +import os +import random +import sys +import time +from dataclasses import dataclass +from multiprocessing import Value + +import braceexpand +import numpy as np +import pandas as pd +import torch +import torchvision.datasets as datasets +import webdataset as wds +from PIL import Image +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info +from torch.utils.data.distributed import DistributedSampler +from webdataset.filters import _shuffle +from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from clip import tokenize + + +class CsvDataset(Dataset): + def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"): + logging.debug(f'Loading csv data from {input_filename}.') + df = pd.read_csv(input_filename, sep=sep) + + self.images = df[img_key].tolist() + self.captions = df[caption_key].tolist() + self.transforms = transforms + logging.debug('Done loading data.') + + def __len__(self): + return len(self.captions) + + def __getitem__(self, idx): + images = self.transforms(Image.open(str(self.images[idx]))) + texts = tokenize([str(self.captions[idx])])[0] + return images, texts + + +class SharedEpoch: + def __init__(self, epoch: int = 0): + self.shared_epoch = Value('i', epoch) + + def set_value(self, epoch): + self.shared_epoch.value = epoch + + def get_value(self): + return self.shared_epoch.value + + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler = None + shared_epoch: SharedEpoch = None + + def set_epoch(self, epoch): + if self.shared_epoch is not None: + self.shared_epoch.set_value(epoch) + if self.sampler is not None and isinstance(self.sampler, DistributedSampler): + self.sampler.set_epoch(epoch) + + +def preprocess_txt(text): + return tokenize([str(text)])[0] + + +def get_dataset_size(shards): + shards_list = list(braceexpand.braceexpand(shards)) + dir_path = os.path.dirname(shards) + sizes_filename = os.path.join(dir_path, 'sizes.json') + len_filename = os.path.join(dir_path, '__len__') + if os.path.exists(sizes_filename): + sizes = json.load(open(sizes_filename, 'r')) + total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list]) + elif os.path.exists(len_filename): + # FIXME this used to be eval(open(...)) but that seemed rather unsafe + total_size = ast.literal_eval(open(len_filename, 'r').read()) + else: + total_size = None # num samples undefined + # some common dataset sizes (at time of authors last download) + # CC3M (train): 2905954 + # CC12M: 10968539 + # LAION-400M: 407332084 + # LAION-2B (english): 2170337258 + num_shards = len(shards_list) + return total_size, num_shards + + +def get_imagenet(args, preprocess_fns, split): + assert split in ["train", "val", "v2"] + is_train = split == "train" + preprocess_train, preprocess_val = preprocess_fns + + if split == "v2": + from imagenetv2_pytorch import ImageNetV2Dataset + dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) + else: + if is_train: + data_path = args.imagenet_train + preprocess_fn = preprocess_train + else: + data_path = args.imagenet_val + preprocess_fn = preprocess_val + assert data_path + + dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) + + if is_train: + idxs = np.zeros(len(dataset.targets)) + target_array = np.array(dataset.targets) + k = 50 + for c in range(1000): + m 
= target_array == c + n = len(idxs[m]) + arr = np.zeros(n) + arr[:k] = 1 + np.random.shuffle(arr) + idxs[m] = arr + + idxs = idxs.astype('int') + sampler = SubsetRandomSampler(np.where(idxs)[0]) + else: + sampler = None + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=sampler, + ) + + return DataInfo(dataloader=dataloader, sampler=sampler) + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def filter_no_caption(sample): + return 'txt' in sample + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') + return True + + +def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + # FIXME webdataset version throws if suffix in current_sample, but we have a potential for + # this happening in the current LAION400m dataset if a tar ends with same prefix as the next + # begins, rare, but can happen since prefix aren't unique across tar files in that dataset + if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: + if valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_to_samples_nothrow(src, handler=log_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) + samples = group_by_keys_nothrow(files, handler=handler) + return samples + + +def pytorch_worker_seed(): + """get dataloader worker seed from pytorch""" + worker_info = get_worker_info() + if worker_info is not None: + # favour the seed already created for pytorch dataloader workers if it exists + return worker_info.seed + # fallback to wds rank based seed + return wds.utils.pytorch_worker_seed() + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +class detshuffle2(wds.PipelineStage): + def __init__( + self, + bufsize=1000, + initial=100, + seed=0, + epoch=-1, + ): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). 
+ self.epoch += 1 + epoch = self.epoch + rng = random.Random() + if self.seed < 0: + seed = pytorch_worker_seed() + epoch + else: + seed = self.seed + epoch + rng.seed(seed) + return _shuffle(src, self.bufsize, self.initial, rng) + + +class ResampledShards2(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, + epoch=-1, + ): + """Sample shards from the shard list with replacement. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = wds.shardlists.expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.rng = random.Random() + self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed + self.deterministic = deterministic + self.epoch = epoch + + def __iter__(self): + """Return an iterator over the shards.""" + if isinstance(self.epoch, SharedEpoch): + epoch = self.epoch.get_value() + else: + # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + if self.deterministic: + # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed + self.rng.seed(self.worker_seed() + epoch) + for _ in range(self.nshards): + yield dict(url=self.rng.choice(self.urls)) + + +def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False): + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + resampled = getattr(args, 'dataset_resampled', False) and is_train + + num_samples, num_shards = get_dataset_size(input_shards) + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + 'Currently, number of dataset samples must be specified for training dataset. 
' + 'Please specify via `--train-num-samples` if no dataset length info present.') + else: + num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified + + shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc + if resampled: + pipeline = [ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)] + else: + pipeline = [wds.SimpleShardList(input_shards)] + + # at this point we have an iterator over all the shards + if is_train: + if not resampled: + pipeline.extend([ + detshuffle2( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + epoch=shared_epoch, + ), + wds.split_by_node, + wds.split_by_worker, + ]) + pipeline.extend([ + # at this point, we have an iterator over the shards assigned to each worker at each node + tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + ), + ]) + else: + pipeline.extend([ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + pipeline.extend([ + wds.select(filter_no_caption), + wds.decode("pilrgb", handler=log_and_continue), + wds.rename(image="jpg;png", text="txt"), + wds.map_dict(image=preprocess_img, text=preprocess_txt), + wds.to_tuple("image", "text"), + wds.batched(args.batch_size, partial=not is_train), + ]) + + dataset = wds.DataPipeline(*pipeline) + if is_train: + if not resampled: + assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' + # roll over and repeat a few samples to get same number of full batches on each node + round_fn = math.floor if floor else math.ceil + global_batch_size = args.batch_size * args.world_size + num_batches = round_fn(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + persistent_workers=True, + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? 
+ # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) + + +def get_csv_dataset(args, preprocess_fn, is_train, epoch=0): + input_filename = args.train_data if is_train else args.val_data + assert input_filename + dataset = CsvDataset( + input_filename, + preprocess_fn, + img_key=args.csv_img_key, + caption_key=args.csv_caption_key, + sep=args.csv_separator) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_dataset_fn(data_path, dataset_type): + if dataset_type == "webdataset": + return get_wds_dataset + elif dataset_type == "csv": + return get_csv_dataset + elif dataset_type == "auto": + ext = data_path.split('.')[-1] + if ext in ['csv', 'tsv']: + return get_csv_dataset + elif ext in ['tar']: + return get_wds_dataset + else: + raise ValueError( + f"Tried to figure out dataset type, but failed for extention {ext}.") + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + +def get_data(args, preprocess_fns, epoch=0): + preprocess_train, preprocess_val = preprocess_fns + data = {} + + if args.train_data: + data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( + args, preprocess_train, is_train=True, epoch=epoch) + + if args.val_data: + data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( + args, preprocess_val, is_train=False) + + if args.imagenet_val is not None: + data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val") + + if args.imagenet_v2 is not None: + data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2") + + return data diff --git a/open_vocab_seg/data/datasets/register_ade20k_full.py b/open_vocab_seg/data/datasets/register_ade20k_full.py new file mode 100644 index 0000000000000000000000000000000000000000..7ba35274c8ba7f03cbe92621f944c8368794497f --- /dev/null +++ b/open_vocab_seg/data/datasets/register_ade20k_full.py @@ -0,0 +1,995 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets import load_sem_seg + +ADE20K_SEM_SEG_FULL_CATEGORIES = [ + {"name": "wall", "id": 2978, "trainId": 0}, + {"name": "building, edifice", "id": 312, "trainId": 1}, + {"name": "sky", "id": 2420, "trainId": 2}, + {"name": "tree", "id": 2855, "trainId": 3}, + {"name": "road, route", "id": 2131, "trainId": 4}, + {"name": "floor, flooring", "id": 976, "trainId": 5}, + {"name": "ceiling", "id": 447, "trainId": 6}, + {"name": "bed", "id": 165, "trainId": 7}, + {"name": "sidewalk, pavement", "id": 2377, "trainId": 8}, + {"name": "earth, ground", "id": 838, "trainId": 9}, + {"name": "cabinet", "id": 350, "trainId": 10}, + { + "name": "person, individual, someone, somebody, mortal, soul", + "id": 1831, + "trainId": 11, + }, + {"name": "grass", "id": 1125, "trainId": 12}, + {"name": "windowpane, window", "id": 3055, "trainId": 13}, + {"name": "car, auto, automobile, machine, motorcar", "id": 401, "trainId": 14}, + {"name": "mountain, mount", "id": 1610, "trainId": 15}, + {"name": "plant, flora, plant life", "id": 1910, "trainId": 16}, + {"name": "table", "id": 2684, "trainId": 17}, + {"name": "chair", "id": 471, "trainId": 18}, + {"name": "curtain, drape, drapery, mantle, pall", "id": 687, "trainId": 19}, + {"name": "door", "id": 774, "trainId": 20}, + {"name": "sofa, couch, lounge", "id": 2473, "trainId": 21}, + {"name": "sea", "id": 2264, "trainId": 22}, + {"name": "painting, picture", "id": 1735, "trainId": 23}, + {"name": "water", "id": 2994, "trainId": 24}, + {"name": "mirror", "id": 1564, "trainId": 25}, + {"name": "house", "id": 1276, "trainId": 26}, + {"name": "rug, carpet, carpeting", "id": 2178, "trainId": 27}, + {"name": "shelf", "id": 2329, "trainId": 28}, + {"name": "armchair", "id": 57, "trainId": 29}, + {"name": "fence, fencing", "id": 907, "trainId": 30}, + {"name": "field", "id": 913, "trainId": 31}, + {"name": "lamp", "id": 1395, "trainId": 32}, + {"name": "rock, stone", "id": 2138, "trainId": 33}, + {"name": "seat", "id": 2272, "trainId": 34}, + {"name": "river", "id": 2128, "trainId": 35}, + {"name": "desk", "id": 724, "trainId": 36}, + {"name": "bathtub, bathing tub, bath, tub", "id": 155, "trainId": 37}, + {"name": "railing, rail", "id": 2053, "trainId": 38}, + {"name": "signboard, sign", "id": 2380, "trainId": 39}, + {"name": "cushion", "id": 689, "trainId": 40}, + {"name": "path", "id": 1788, "trainId": 41}, + {"name": "work surface", "id": 3087, "trainId": 42}, + {"name": "stairs, steps", "id": 2530, "trainId": 43}, + {"name": "column, pillar", "id": 581, "trainId": 44}, + {"name": "sink", "id": 2388, "trainId": 45}, + {"name": "wardrobe, closet, press", "id": 2985, "trainId": 46}, + {"name": "snow", "id": 2454, "trainId": 47}, + {"name": "refrigerator, icebox", "id": 2096, "trainId": 48}, + {"name": "base, pedestal, stand", "id": 137, "trainId": 49}, + {"name": "bridge, span", "id": 294, "trainId": 50}, + {"name": "blind, screen", "id": 212, "trainId": 51}, + {"name": "runway", "id": 2185, "trainId": 52}, + {"name": "cliff, drop, drop-off", "id": 524, "trainId": 53}, + {"name": "sand", "id": 2212, "trainId": 54}, + {"name": "fireplace, hearth, open fireplace", "id": 943, "trainId": 55}, + {"name": "pillow", "id": 1869, "trainId": 56}, + {"name": "screen door, screen", "id": 2251, "trainId": 57}, + { + "name": "toilet, can, commode, crapper, pot, potty, stool, throne", + "id": 2793, + "trainId": 58, + }, + {"name": "skyscraper", "id": 2423, "trainId": 59}, + {"name": 
"grandstand, covered stand", "id": 1121, "trainId": 60}, + {"name": "box", "id": 266, "trainId": 61}, + {"name": "pool table, billiard table, snooker table", "id": 1948, "trainId": 62}, + {"name": "palm, palm tree", "id": 1744, "trainId": 63}, + {"name": "double door", "id": 783, "trainId": 64}, + {"name": "coffee table, cocktail table", "id": 571, "trainId": 65}, + {"name": "counter", "id": 627, "trainId": 66}, + {"name": "countertop", "id": 629, "trainId": 67}, + {"name": "chest of drawers, chest, bureau, dresser", "id": 491, "trainId": 68}, + {"name": "kitchen island", "id": 1374, "trainId": 69}, + {"name": "boat", "id": 223, "trainId": 70}, + {"name": "waterfall, falls", "id": 3016, "trainId": 71}, + { + "name": "stove, kitchen stove, range, kitchen range, cooking stove", + "id": 2598, + "trainId": 72, + }, + {"name": "flower", "id": 978, "trainId": 73}, + {"name": "bookcase", "id": 239, "trainId": 74}, + {"name": "controls", "id": 608, "trainId": 75}, + {"name": "book", "id": 236, "trainId": 76}, + {"name": "stairway, staircase", "id": 2531, "trainId": 77}, + {"name": "streetlight, street lamp", "id": 2616, "trainId": 78}, + { + "name": "computer, computing machine, computing device, data processor, electronic computer, information processing system", + "id": 591, + "trainId": 79, + }, + { + "name": "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle", + "id": 327, + "trainId": 80, + }, + {"name": "swivel chair", "id": 2679, "trainId": 81}, + {"name": "light, light source", "id": 1451, "trainId": 82}, + {"name": "bench", "id": 181, "trainId": 83}, + {"name": "case, display case, showcase, vitrine", "id": 420, "trainId": 84}, + {"name": "towel", "id": 2821, "trainId": 85}, + {"name": "fountain", "id": 1023, "trainId": 86}, + {"name": "embankment", "id": 855, "trainId": 87}, + { + "name": "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box", + "id": 2733, + "trainId": 88, + }, + {"name": "van", "id": 2928, "trainId": 89}, + {"name": "hill", "id": 1240, "trainId": 90}, + {"name": "awning, sunshade, sunblind", "id": 77, "trainId": 91}, + {"name": "poster, posting, placard, notice, bill, card", "id": 1969, "trainId": 92}, + {"name": "truck, motortruck", "id": 2880, "trainId": 93}, + {"name": "airplane, aeroplane, plane", "id": 14, "trainId": 94}, + {"name": "pole", "id": 1936, "trainId": 95}, + {"name": "tower", "id": 2828, "trainId": 96}, + {"name": "court", "id": 631, "trainId": 97}, + {"name": "ball", "id": 103, "trainId": 98}, + { + "name": "aircraft carrier, carrier, flattop, attack aircraft carrier", + "id": 3144, + "trainId": 99, + }, + {"name": "buffet, counter, sideboard", "id": 308, "trainId": 100}, + {"name": "hovel, hut, hutch, shack, shanty", "id": 1282, "trainId": 101}, + {"name": "apparel, wearing apparel, dress, clothes", "id": 38, "trainId": 102}, + {"name": "minibike, motorbike", "id": 1563, "trainId": 103}, + { + "name": "animal, animate being, beast, brute, creature, fauna", + "id": 29, + "trainId": 104, + }, + {"name": "chandelier, pendant, pendent", "id": 480, "trainId": 105}, + {"name": "step, stair", "id": 2569, "trainId": 106}, + {"name": "booth, cubicle, stall, kiosk", "id": 247, "trainId": 107}, + {"name": "bicycle, bike, wheel, cycle", "id": 187, "trainId": 108}, + {"name": "doorframe, doorcase", "id": 778, "trainId": 109}, + {"name": "sconce", "id": 2243, "trainId": 110}, + {"name": "pond", "id": 1941, "trainId": 111}, + {"name": "trade name, brand name, 
brand, marque", "id": 2833, "trainId": 112}, + { + "name": "bannister, banister, balustrade, balusters, handrail", + "id": 120, + "trainId": 113, + }, + {"name": "bag", "id": 95, "trainId": 114}, + {"name": "traffic light, traffic signal, stoplight", "id": 2836, "trainId": 115}, + {"name": "gazebo", "id": 1087, "trainId": 116}, + {"name": "escalator, moving staircase, moving stairway", "id": 868, "trainId": 117}, + {"name": "land, ground, soil", "id": 1401, "trainId": 118}, + {"name": "board, plank", "id": 220, "trainId": 119}, + {"name": "arcade machine", "id": 47, "trainId": 120}, + {"name": "eiderdown, duvet, continental quilt", "id": 843, "trainId": 121}, + {"name": "bar", "id": 123, "trainId": 122}, + {"name": "stall, stand, sales booth", "id": 2537, "trainId": 123}, + {"name": "playground", "id": 1927, "trainId": 124}, + {"name": "ship", "id": 2337, "trainId": 125}, + {"name": "ottoman, pouf, pouffe, puff, hassock", "id": 1702, "trainId": 126}, + { + "name": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", + "id": 64, + "trainId": 127, + }, + {"name": "bottle", "id": 249, "trainId": 128}, + {"name": "cradle", "id": 642, "trainId": 129}, + {"name": "pot, flowerpot", "id": 1981, "trainId": 130}, + { + "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter", + "id": 609, + "trainId": 131, + }, + {"name": "train, railroad train", "id": 2840, "trainId": 132}, + {"name": "stool", "id": 2586, "trainId": 133}, + {"name": "lake", "id": 1393, "trainId": 134}, + {"name": "tank, storage tank", "id": 2704, "trainId": 135}, + {"name": "ice, water ice", "id": 1304, "trainId": 136}, + {"name": "basket, handbasket", "id": 146, "trainId": 137}, + {"name": "manhole", "id": 1494, "trainId": 138}, + {"name": "tent, collapsible shelter", "id": 2739, "trainId": 139}, + {"name": "canopy", "id": 389, "trainId": 140}, + {"name": "microwave, microwave oven", "id": 1551, "trainId": 141}, + {"name": "barrel, cask", "id": 131, "trainId": 142}, + {"name": "dirt track", "id": 738, "trainId": 143}, + {"name": "beam", "id": 161, "trainId": 144}, + {"name": "dishwasher, dish washer, dishwashing machine", "id": 747, "trainId": 145}, + {"name": "plate", "id": 1919, "trainId": 146}, + {"name": "screen, crt screen", "id": 3109, "trainId": 147}, + {"name": "ruins", "id": 2179, "trainId": 148}, + {"name": "washer, automatic washer, washing machine", "id": 2989, "trainId": 149}, + {"name": "blanket, cover", "id": 206, "trainId": 150}, + {"name": "plaything, toy", "id": 1930, "trainId": 151}, + {"name": "food, solid food", "id": 1002, "trainId": 152}, + {"name": "screen, silver screen, projection screen", "id": 2254, "trainId": 153}, + {"name": "oven", "id": 1708, "trainId": 154}, + {"name": "stage", "id": 2526, "trainId": 155}, + {"name": "beacon, lighthouse, beacon light, pharos", "id": 160, "trainId": 156}, + {"name": "umbrella", "id": 2901, "trainId": 157}, + {"name": "sculpture", "id": 2262, "trainId": 158}, + {"name": "aqueduct", "id": 44, "trainId": 159}, + {"name": "container", "id": 597, "trainId": 160}, + {"name": "scaffolding, staging", "id": 2235, "trainId": 161}, + {"name": "hood, exhaust hood", "id": 1260, "trainId": 162}, + {"name": "curb, curbing, kerb", "id": 682, "trainId": 163}, + {"name": "roller coaster", "id": 2151, "trainId": 164}, + {"name": "horse, equus caballus", "id": 3107, "trainId": 165}, + {"name": "catwalk", "id": 432, "trainId": 166}, + {"name": "glass, drinking glass", "id": 1098, "trainId": 167}, + {"name": "vase", "id": 
2932, "trainId": 168}, + {"name": "central reservation", "id": 461, "trainId": 169}, + {"name": "carousel", "id": 410, "trainId": 170}, + {"name": "radiator", "id": 2046, "trainId": 171}, + {"name": "closet", "id": 533, "trainId": 172}, + {"name": "machine", "id": 1481, "trainId": 173}, + {"name": "pier, wharf, wharfage, dock", "id": 1858, "trainId": 174}, + {"name": "fan", "id": 894, "trainId": 175}, + {"name": "inflatable bounce game", "id": 1322, "trainId": 176}, + {"name": "pitch", "id": 1891, "trainId": 177}, + {"name": "paper", "id": 1756, "trainId": 178}, + {"name": "arcade, colonnade", "id": 49, "trainId": 179}, + {"name": "hot tub", "id": 1272, "trainId": 180}, + {"name": "helicopter", "id": 1229, "trainId": 181}, + {"name": "tray", "id": 2850, "trainId": 182}, + {"name": "partition, divider", "id": 1784, "trainId": 183}, + {"name": "vineyard", "id": 2962, "trainId": 184}, + {"name": "bowl", "id": 259, "trainId": 185}, + {"name": "bullring", "id": 319, "trainId": 186}, + {"name": "flag", "id": 954, "trainId": 187}, + {"name": "pot", "id": 1974, "trainId": 188}, + {"name": "footbridge, overcrossing, pedestrian bridge", "id": 1013, "trainId": 189}, + {"name": "shower", "id": 2356, "trainId": 190}, + { + "name": "bag, traveling bag, travelling bag, grip, suitcase", + "id": 97, + "trainId": 191, + }, + {"name": "bulletin board, notice board", "id": 318, "trainId": 192}, + {"name": "confessional booth", "id": 592, "trainId": 193}, + {"name": "trunk, tree trunk, bole", "id": 2885, "trainId": 194}, + {"name": "forest", "id": 1017, "trainId": 195}, + {"name": "elevator door", "id": 851, "trainId": 196}, + {"name": "laptop, laptop computer", "id": 1407, "trainId": 197}, + {"name": "instrument panel", "id": 1332, "trainId": 198}, + {"name": "bucket, pail", "id": 303, "trainId": 199}, + {"name": "tapestry, tapis", "id": 2714, "trainId": 200}, + {"name": "platform", "id": 1924, "trainId": 201}, + {"name": "jacket", "id": 1346, "trainId": 202}, + {"name": "gate", "id": 1081, "trainId": 203}, + {"name": "monitor, monitoring device", "id": 1583, "trainId": 204}, + { + "name": "telephone booth, phone booth, call box, telephone box, telephone kiosk", + "id": 2727, + "trainId": 205, + }, + {"name": "spotlight, spot", "id": 2509, "trainId": 206}, + {"name": "ring", "id": 2123, "trainId": 207}, + {"name": "control panel", "id": 602, "trainId": 208}, + {"name": "blackboard, chalkboard", "id": 202, "trainId": 209}, + {"name": "air conditioner, air conditioning", "id": 10, "trainId": 210}, + {"name": "chest", "id": 490, "trainId": 211}, + {"name": "clock", "id": 530, "trainId": 212}, + {"name": "sand dune", "id": 2213, "trainId": 213}, + {"name": "pipe, pipage, piping", "id": 1884, "trainId": 214}, + {"name": "vault", "id": 2934, "trainId": 215}, + {"name": "table football", "id": 2687, "trainId": 216}, + {"name": "cannon", "id": 387, "trainId": 217}, + {"name": "swimming pool, swimming bath, natatorium", "id": 2668, "trainId": 218}, + {"name": "fluorescent, fluorescent fixture", "id": 982, "trainId": 219}, + {"name": "statue", "id": 2547, "trainId": 220}, + { + "name": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", + "id": 1474, + "trainId": 221, + }, + {"name": "exhibitor", "id": 877, "trainId": 222}, + {"name": "ladder", "id": 1391, "trainId": 223}, + {"name": "carport", "id": 414, "trainId": 224}, + {"name": "dam", "id": 698, "trainId": 225}, + {"name": "pulpit", "id": 2019, "trainId": 226}, + {"name": "skylight, fanlight", "id": 2422, "trainId": 227}, + {"name": "water 
tower", "id": 3010, "trainId": 228}, + {"name": "grill, grille, grillwork", "id": 1139, "trainId": 229}, + {"name": "display board", "id": 753, "trainId": 230}, + {"name": "pane, pane of glass, window glass", "id": 1747, "trainId": 231}, + {"name": "rubbish, trash, scrap", "id": 2175, "trainId": 232}, + {"name": "ice rink", "id": 1301, "trainId": 233}, + {"name": "fruit", "id": 1033, "trainId": 234}, + {"name": "patio", "id": 1789, "trainId": 235}, + {"name": "vending machine", "id": 2939, "trainId": 236}, + {"name": "telephone, phone, telephone set", "id": 2730, "trainId": 237}, + {"name": "net", "id": 1652, "trainId": 238}, + { + "name": "backpack, back pack, knapsack, packsack, rucksack, haversack", + "id": 90, + "trainId": 239, + }, + {"name": "jar", "id": 1349, "trainId": 240}, + {"name": "track", "id": 2830, "trainId": 241}, + {"name": "magazine", "id": 1485, "trainId": 242}, + {"name": "shutter", "id": 2370, "trainId": 243}, + {"name": "roof", "id": 2155, "trainId": 244}, + {"name": "banner, streamer", "id": 118, "trainId": 245}, + {"name": "landfill", "id": 1402, "trainId": 246}, + {"name": "post", "id": 1957, "trainId": 247}, + {"name": "altarpiece, reredos", "id": 3130, "trainId": 248}, + {"name": "hat, chapeau, lid", "id": 1197, "trainId": 249}, + {"name": "arch, archway", "id": 52, "trainId": 250}, + {"name": "table game", "id": 2688, "trainId": 251}, + {"name": "bag, handbag, pocketbook, purse", "id": 96, "trainId": 252}, + {"name": "document, written document, papers", "id": 762, "trainId": 253}, + {"name": "dome", "id": 772, "trainId": 254}, + {"name": "pier", "id": 1857, "trainId": 255}, + {"name": "shanties", "id": 2315, "trainId": 256}, + {"name": "forecourt", "id": 1016, "trainId": 257}, + {"name": "crane", "id": 643, "trainId": 258}, + {"name": "dog, domestic dog, canis familiaris", "id": 3105, "trainId": 259}, + {"name": "piano, pianoforte, forte-piano", "id": 1849, "trainId": 260}, + {"name": "drawing", "id": 791, "trainId": 261}, + {"name": "cabin", "id": 349, "trainId": 262}, + { + "name": "ad, advertisement, advertizement, advertising, advertizing, advert", + "id": 6, + "trainId": 263, + }, + {"name": "amphitheater, amphitheatre, coliseum", "id": 3114, "trainId": 264}, + {"name": "monument", "id": 1587, "trainId": 265}, + {"name": "henhouse", "id": 1233, "trainId": 266}, + {"name": "cockpit", "id": 559, "trainId": 267}, + {"name": "heater, warmer", "id": 1223, "trainId": 268}, + {"name": "windmill, aerogenerator, wind generator", "id": 3049, "trainId": 269}, + {"name": "pool", "id": 1943, "trainId": 270}, + {"name": "elevator, lift", "id": 853, "trainId": 271}, + {"name": "decoration, ornament, ornamentation", "id": 709, "trainId": 272}, + {"name": "labyrinth", "id": 1390, "trainId": 273}, + {"name": "text, textual matter", "id": 2748, "trainId": 274}, + {"name": "printer", "id": 2007, "trainId": 275}, + {"name": "mezzanine, first balcony", "id": 1546, "trainId": 276}, + {"name": "mattress", "id": 1513, "trainId": 277}, + {"name": "straw", "id": 2600, "trainId": 278}, + {"name": "stalls", "id": 2538, "trainId": 279}, + {"name": "patio, terrace", "id": 1790, "trainId": 280}, + {"name": "billboard, hoarding", "id": 194, "trainId": 281}, + {"name": "bus stop", "id": 326, "trainId": 282}, + {"name": "trouser, pant", "id": 2877, "trainId": 283}, + {"name": "console table, console", "id": 594, "trainId": 284}, + {"name": "rack", "id": 2036, "trainId": 285}, + {"name": "notebook", "id": 1662, "trainId": 286}, + {"name": "shrine", "id": 2366, "trainId": 287}, + {"name": 
"pantry", "id": 1754, "trainId": 288}, + {"name": "cart", "id": 418, "trainId": 289}, + {"name": "steam shovel", "id": 2553, "trainId": 290}, + {"name": "porch", "id": 1951, "trainId": 291}, + {"name": "postbox, mailbox, letter box", "id": 1963, "trainId": 292}, + {"name": "figurine, statuette", "id": 918, "trainId": 293}, + {"name": "recycling bin", "id": 2086, "trainId": 294}, + {"name": "folding screen", "id": 997, "trainId": 295}, + {"name": "telescope", "id": 2731, "trainId": 296}, + {"name": "deck chair, beach chair", "id": 704, "trainId": 297}, + {"name": "kennel", "id": 1365, "trainId": 298}, + {"name": "coffee maker", "id": 569, "trainId": 299}, + {"name": "altar, communion table, lord's table", "id": 3108, "trainId": 300}, + {"name": "fish", "id": 948, "trainId": 301}, + {"name": "easel", "id": 839, "trainId": 302}, + {"name": "artificial golf green", "id": 63, "trainId": 303}, + {"name": "iceberg", "id": 1305, "trainId": 304}, + {"name": "candlestick, candle holder", "id": 378, "trainId": 305}, + {"name": "shower stall, shower bath", "id": 2362, "trainId": 306}, + {"name": "television stand", "id": 2734, "trainId": 307}, + { + "name": "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle", + "id": 2982, + "trainId": 308, + }, + {"name": "skeleton", "id": 2398, "trainId": 309}, + {"name": "grand piano, grand", "id": 1119, "trainId": 310}, + {"name": "candy, confect", "id": 382, "trainId": 311}, + {"name": "grille door", "id": 1141, "trainId": 312}, + {"name": "pedestal, plinth, footstall", "id": 1805, "trainId": 313}, + {"name": "jersey, t-shirt, tee shirt", "id": 3102, "trainId": 314}, + {"name": "shoe", "id": 2341, "trainId": 315}, + {"name": "gravestone, headstone, tombstone", "id": 1131, "trainId": 316}, + {"name": "shanty", "id": 2316, "trainId": 317}, + {"name": "structure", "id": 2626, "trainId": 318}, + {"name": "rocking chair, rocker", "id": 3104, "trainId": 319}, + {"name": "bird", "id": 198, "trainId": 320}, + {"name": "place mat", "id": 1896, "trainId": 321}, + {"name": "tomb", "id": 2800, "trainId": 322}, + {"name": "big top", "id": 190, "trainId": 323}, + { + "name": "gas pump, gasoline pump, petrol pump, island dispenser", + "id": 3131, + "trainId": 324, + }, + {"name": "lockers", "id": 1463, "trainId": 325}, + {"name": "cage", "id": 357, "trainId": 326}, + {"name": "finger", "id": 929, "trainId": 327}, + {"name": "bleachers", "id": 209, "trainId": 328}, + {"name": "ferris wheel", "id": 912, "trainId": 329}, + {"name": "hairdresser chair", "id": 1164, "trainId": 330}, + {"name": "mat", "id": 1509, "trainId": 331}, + {"name": "stands", "id": 2539, "trainId": 332}, + {"name": "aquarium, fish tank, marine museum", "id": 3116, "trainId": 333}, + { + "name": "streetcar, tram, tramcar, trolley, trolley car", + "id": 2615, + "trainId": 334, + }, + {"name": "napkin, table napkin, serviette", "id": 1644, "trainId": 335}, + {"name": "dummy", "id": 818, "trainId": 336}, + {"name": "booklet, brochure, folder, leaflet, pamphlet", "id": 242, "trainId": 337}, + {"name": "sand trap", "id": 2217, "trainId": 338}, + {"name": "shop, store", "id": 2347, "trainId": 339}, + {"name": "table cloth", "id": 2686, "trainId": 340}, + {"name": "service station", "id": 2300, "trainId": 341}, + {"name": "coffin", "id": 572, "trainId": 342}, + {"name": "drawer", "id": 789, "trainId": 343}, + {"name": "cages", "id": 358, "trainId": 344}, + {"name": "slot machine, coin machine", "id": 2443, "trainId": 345}, + {"name": "balcony", "id": 101, "trainId": 346}, + 
{"name": "volleyball court", "id": 2969, "trainId": 347}, + {"name": "table tennis", "id": 2692, "trainId": 348}, + {"name": "control table", "id": 606, "trainId": 349}, + {"name": "shirt", "id": 2339, "trainId": 350}, + {"name": "merchandise, ware, product", "id": 1533, "trainId": 351}, + {"name": "railway", "id": 2060, "trainId": 352}, + {"name": "parterre", "id": 1782, "trainId": 353}, + {"name": "chimney", "id": 495, "trainId": 354}, + {"name": "can, tin, tin can", "id": 371, "trainId": 355}, + {"name": "tanks", "id": 2707, "trainId": 356}, + {"name": "fabric, cloth, material, textile", "id": 889, "trainId": 357}, + {"name": "alga, algae", "id": 3156, "trainId": 358}, + {"name": "system", "id": 2683, "trainId": 359}, + {"name": "map", "id": 1499, "trainId": 360}, + {"name": "greenhouse", "id": 1135, "trainId": 361}, + {"name": "mug", "id": 1619, "trainId": 362}, + {"name": "barbecue", "id": 125, "trainId": 363}, + {"name": "trailer", "id": 2838, "trainId": 364}, + { + "name": "toilet tissue, toilet paper, bathroom tissue", + "id": 2792, + "trainId": 365, + }, + {"name": "organ", "id": 1695, "trainId": 366}, + {"name": "dishrag, dishcloth", "id": 746, "trainId": 367}, + {"name": "island", "id": 1343, "trainId": 368}, + {"name": "keyboard", "id": 1370, "trainId": 369}, + {"name": "trench", "id": 2858, "trainId": 370}, + {"name": "basket, basketball hoop, hoop", "id": 145, "trainId": 371}, + {"name": "steering wheel, wheel", "id": 2565, "trainId": 372}, + {"name": "pitcher, ewer", "id": 1892, "trainId": 373}, + {"name": "goal", "id": 1103, "trainId": 374}, + {"name": "bread, breadstuff, staff of life", "id": 286, "trainId": 375}, + {"name": "beds", "id": 170, "trainId": 376}, + {"name": "wood", "id": 3073, "trainId": 377}, + {"name": "file cabinet", "id": 922, "trainId": 378}, + {"name": "newspaper, paper", "id": 1655, "trainId": 379}, + {"name": "motorboat", "id": 1602, "trainId": 380}, + {"name": "rope", "id": 2160, "trainId": 381}, + {"name": "guitar", "id": 1151, "trainId": 382}, + {"name": "rubble", "id": 2176, "trainId": 383}, + {"name": "scarf", "id": 2239, "trainId": 384}, + {"name": "barrels", "id": 132, "trainId": 385}, + {"name": "cap", "id": 394, "trainId": 386}, + {"name": "leaves", "id": 1424, "trainId": 387}, + {"name": "control tower", "id": 607, "trainId": 388}, + {"name": "dashboard", "id": 700, "trainId": 389}, + {"name": "bandstand", "id": 116, "trainId": 390}, + {"name": "lectern", "id": 1425, "trainId": 391}, + {"name": "switch, electric switch, electrical switch", "id": 2676, "trainId": 392}, + {"name": "baseboard, mopboard, skirting board", "id": 141, "trainId": 393}, + {"name": "shower room", "id": 2360, "trainId": 394}, + {"name": "smoke", "id": 2449, "trainId": 395}, + {"name": "faucet, spigot", "id": 897, "trainId": 396}, + {"name": "bulldozer", "id": 317, "trainId": 397}, + {"name": "saucepan", "id": 2228, "trainId": 398}, + {"name": "shops", "id": 2351, "trainId": 399}, + {"name": "meter", "id": 1543, "trainId": 400}, + {"name": "crevasse", "id": 656, "trainId": 401}, + {"name": "gear", "id": 1088, "trainId": 402}, + {"name": "candelabrum, candelabra", "id": 373, "trainId": 403}, + {"name": "sofa bed", "id": 2472, "trainId": 404}, + {"name": "tunnel", "id": 2892, "trainId": 405}, + {"name": "pallet", "id": 1740, "trainId": 406}, + {"name": "wire, conducting wire", "id": 3067, "trainId": 407}, + {"name": "kettle, boiler", "id": 1367, "trainId": 408}, + {"name": "bidet", "id": 188, "trainId": 409}, + { + "name": "baby buggy, baby carriage, carriage, 
perambulator, pram, stroller, go-cart, pushchair, pusher", + "id": 79, + "trainId": 410, + }, + {"name": "music stand", "id": 1633, "trainId": 411}, + {"name": "pipe, tube", "id": 1885, "trainId": 412}, + {"name": "cup", "id": 677, "trainId": 413}, + {"name": "parking meter", "id": 1779, "trainId": 414}, + {"name": "ice hockey rink", "id": 1297, "trainId": 415}, + {"name": "shelter", "id": 2334, "trainId": 416}, + {"name": "weeds", "id": 3027, "trainId": 417}, + {"name": "temple", "id": 2735, "trainId": 418}, + {"name": "patty, cake", "id": 1791, "trainId": 419}, + {"name": "ski slope", "id": 2405, "trainId": 420}, + {"name": "panel", "id": 1748, "trainId": 421}, + {"name": "wallet", "id": 2983, "trainId": 422}, + {"name": "wheel", "id": 3035, "trainId": 423}, + {"name": "towel rack, towel horse", "id": 2824, "trainId": 424}, + {"name": "roundabout", "id": 2168, "trainId": 425}, + {"name": "canister, cannister, tin", "id": 385, "trainId": 426}, + {"name": "rod", "id": 2148, "trainId": 427}, + {"name": "soap dispenser", "id": 2465, "trainId": 428}, + {"name": "bell", "id": 175, "trainId": 429}, + {"name": "canvas", "id": 390, "trainId": 430}, + {"name": "box office, ticket office, ticket booth", "id": 268, "trainId": 431}, + {"name": "teacup", "id": 2722, "trainId": 432}, + {"name": "trellis", "id": 2857, "trainId": 433}, + {"name": "workbench", "id": 3088, "trainId": 434}, + {"name": "valley, vale", "id": 2926, "trainId": 435}, + {"name": "toaster", "id": 2782, "trainId": 436}, + {"name": "knife", "id": 1378, "trainId": 437}, + {"name": "podium", "id": 1934, "trainId": 438}, + {"name": "ramp", "id": 2072, "trainId": 439}, + {"name": "tumble dryer", "id": 2889, "trainId": 440}, + {"name": "fireplug, fire hydrant, plug", "id": 944, "trainId": 441}, + {"name": "gym shoe, sneaker, tennis shoe", "id": 1158, "trainId": 442}, + {"name": "lab bench", "id": 1383, "trainId": 443}, + {"name": "equipment", "id": 867, "trainId": 444}, + {"name": "rocky formation", "id": 2145, "trainId": 445}, + {"name": "plastic", "id": 1915, "trainId": 446}, + {"name": "calendar", "id": 361, "trainId": 447}, + {"name": "caravan", "id": 402, "trainId": 448}, + {"name": "check-in-desk", "id": 482, "trainId": 449}, + {"name": "ticket counter", "id": 2761, "trainId": 450}, + {"name": "brush", "id": 300, "trainId": 451}, + {"name": "mill", "id": 1554, "trainId": 452}, + {"name": "covered bridge", "id": 636, "trainId": 453}, + {"name": "bowling alley", "id": 260, "trainId": 454}, + {"name": "hanger", "id": 1186, "trainId": 455}, + {"name": "excavator", "id": 871, "trainId": 456}, + {"name": "trestle", "id": 2859, "trainId": 457}, + {"name": "revolving door", "id": 2103, "trainId": 458}, + {"name": "blast furnace", "id": 208, "trainId": 459}, + {"name": "scale, weighing machine", "id": 2236, "trainId": 460}, + {"name": "projector", "id": 2012, "trainId": 461}, + {"name": "soap", "id": 2462, "trainId": 462}, + {"name": "locker", "id": 1462, "trainId": 463}, + {"name": "tractor", "id": 2832, "trainId": 464}, + {"name": "stretcher", "id": 2617, "trainId": 465}, + {"name": "frame", "id": 1024, "trainId": 466}, + {"name": "grating", "id": 1129, "trainId": 467}, + {"name": "alembic", "id": 18, "trainId": 468}, + {"name": "candle, taper, wax light", "id": 376, "trainId": 469}, + {"name": "barrier", "id": 134, "trainId": 470}, + {"name": "cardboard", "id": 407, "trainId": 471}, + {"name": "cave", "id": 434, "trainId": 472}, + {"name": "puddle", "id": 2017, "trainId": 473}, + {"name": "tarp", "id": 2717, "trainId": 474}, + {"name": 
"price tag", "id": 2005, "trainId": 475}, + {"name": "watchtower", "id": 2993, "trainId": 476}, + {"name": "meters", "id": 1545, "trainId": 477}, + { + "name": "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb", + "id": 1445, + "trainId": 478, + }, + {"name": "tracks", "id": 2831, "trainId": 479}, + {"name": "hair dryer", "id": 1161, "trainId": 480}, + {"name": "skirt", "id": 2411, "trainId": 481}, + {"name": "viaduct", "id": 2949, "trainId": 482}, + {"name": "paper towel", "id": 1769, "trainId": 483}, + {"name": "coat", "id": 552, "trainId": 484}, + {"name": "sheet", "id": 2327, "trainId": 485}, + {"name": "fire extinguisher, extinguisher, asphyxiator", "id": 939, "trainId": 486}, + {"name": "water wheel", "id": 3013, "trainId": 487}, + {"name": "pottery, clayware", "id": 1986, "trainId": 488}, + {"name": "magazine rack", "id": 1486, "trainId": 489}, + {"name": "teapot", "id": 2723, "trainId": 490}, + {"name": "microphone, mike", "id": 1549, "trainId": 491}, + {"name": "support", "id": 2649, "trainId": 492}, + {"name": "forklift", "id": 1020, "trainId": 493}, + {"name": "canyon", "id": 392, "trainId": 494}, + {"name": "cash register, register", "id": 422, "trainId": 495}, + {"name": "leaf, leafage, foliage", "id": 1419, "trainId": 496}, + {"name": "remote control, remote", "id": 2099, "trainId": 497}, + {"name": "soap dish", "id": 2464, "trainId": 498}, + {"name": "windshield, windscreen", "id": 3058, "trainId": 499}, + {"name": "cat", "id": 430, "trainId": 500}, + {"name": "cue, cue stick, pool cue, pool stick", "id": 675, "trainId": 501}, + {"name": "vent, venthole, vent-hole, blowhole", "id": 2941, "trainId": 502}, + {"name": "videos", "id": 2955, "trainId": 503}, + {"name": "shovel", "id": 2355, "trainId": 504}, + {"name": "eaves", "id": 840, "trainId": 505}, + {"name": "antenna, aerial, transmitting aerial", "id": 32, "trainId": 506}, + {"name": "shipyard", "id": 2338, "trainId": 507}, + {"name": "hen, biddy", "id": 1232, "trainId": 508}, + {"name": "traffic cone", "id": 2834, "trainId": 509}, + {"name": "washing machines", "id": 2991, "trainId": 510}, + {"name": "truck crane", "id": 2879, "trainId": 511}, + {"name": "cds", "id": 444, "trainId": 512}, + {"name": "niche", "id": 1657, "trainId": 513}, + {"name": "scoreboard", "id": 2246, "trainId": 514}, + {"name": "briefcase", "id": 296, "trainId": 515}, + {"name": "boot", "id": 245, "trainId": 516}, + {"name": "sweater, jumper", "id": 2661, "trainId": 517}, + {"name": "hay", "id": 1202, "trainId": 518}, + {"name": "pack", "id": 1714, "trainId": 519}, + {"name": "bottle rack", "id": 251, "trainId": 520}, + {"name": "glacier", "id": 1095, "trainId": 521}, + {"name": "pergola", "id": 1828, "trainId": 522}, + {"name": "building materials", "id": 311, "trainId": 523}, + {"name": "television camera", "id": 2732, "trainId": 524}, + {"name": "first floor", "id": 947, "trainId": 525}, + {"name": "rifle", "id": 2115, "trainId": 526}, + {"name": "tennis table", "id": 2738, "trainId": 527}, + {"name": "stadium", "id": 2525, "trainId": 528}, + {"name": "safety belt", "id": 2194, "trainId": 529}, + {"name": "cover", "id": 634, "trainId": 530}, + {"name": "dish rack", "id": 740, "trainId": 531}, + {"name": "synthesizer", "id": 2682, "trainId": 532}, + {"name": "pumpkin", "id": 2020, "trainId": 533}, + {"name": "gutter", "id": 1156, "trainId": 534}, + {"name": "fruit stand", "id": 1036, "trainId": 535}, + {"name": "ice floe, floe", "id": 1295, "trainId": 536}, + {"name": "handle, grip, handgrip, hold", "id": 1181, 
"trainId": 537}, + {"name": "wheelchair", "id": 3037, "trainId": 538}, + {"name": "mousepad, mouse mat", "id": 1614, "trainId": 539}, + {"name": "diploma", "id": 736, "trainId": 540}, + {"name": "fairground ride", "id": 893, "trainId": 541}, + {"name": "radio", "id": 2047, "trainId": 542}, + {"name": "hotplate", "id": 1274, "trainId": 543}, + {"name": "junk", "id": 1361, "trainId": 544}, + {"name": "wheelbarrow", "id": 3036, "trainId": 545}, + {"name": "stream", "id": 2606, "trainId": 546}, + {"name": "toll plaza", "id": 2797, "trainId": 547}, + {"name": "punching bag", "id": 2022, "trainId": 548}, + {"name": "trough", "id": 2876, "trainId": 549}, + {"name": "throne", "id": 2758, "trainId": 550}, + {"name": "chair desk", "id": 472, "trainId": 551}, + {"name": "weighbridge", "id": 3028, "trainId": 552}, + {"name": "extractor fan", "id": 882, "trainId": 553}, + {"name": "hanging clothes", "id": 1189, "trainId": 554}, + {"name": "dish, dish aerial, dish antenna, saucer", "id": 743, "trainId": 555}, + {"name": "alarm clock, alarm", "id": 3122, "trainId": 556}, + {"name": "ski lift", "id": 2401, "trainId": 557}, + {"name": "chain", "id": 468, "trainId": 558}, + {"name": "garage", "id": 1061, "trainId": 559}, + {"name": "mechanical shovel", "id": 1523, "trainId": 560}, + {"name": "wine rack", "id": 3059, "trainId": 561}, + {"name": "tramway", "id": 2843, "trainId": 562}, + {"name": "treadmill", "id": 2853, "trainId": 563}, + {"name": "menu", "id": 1529, "trainId": 564}, + {"name": "block", "id": 214, "trainId": 565}, + {"name": "well", "id": 3032, "trainId": 566}, + {"name": "witness stand", "id": 3071, "trainId": 567}, + {"name": "branch", "id": 277, "trainId": 568}, + {"name": "duck", "id": 813, "trainId": 569}, + {"name": "casserole", "id": 426, "trainId": 570}, + {"name": "frying pan", "id": 1039, "trainId": 571}, + {"name": "desk organizer", "id": 727, "trainId": 572}, + {"name": "mast", "id": 1508, "trainId": 573}, + {"name": "spectacles, specs, eyeglasses, glasses", "id": 2490, "trainId": 574}, + {"name": "service elevator", "id": 2299, "trainId": 575}, + {"name": "dollhouse", "id": 768, "trainId": 576}, + {"name": "hammock", "id": 1172, "trainId": 577}, + {"name": "clothes hanging", "id": 537, "trainId": 578}, + {"name": "photocopier", "id": 1847, "trainId": 579}, + {"name": "notepad", "id": 1664, "trainId": 580}, + {"name": "golf cart", "id": 1110, "trainId": 581}, + {"name": "footpath", "id": 1014, "trainId": 582}, + {"name": "cross", "id": 662, "trainId": 583}, + {"name": "baptismal font", "id": 121, "trainId": 584}, + {"name": "boiler", "id": 227, "trainId": 585}, + {"name": "skip", "id": 2410, "trainId": 586}, + {"name": "rotisserie", "id": 2165, "trainId": 587}, + {"name": "tables", "id": 2696, "trainId": 588}, + {"name": "water mill", "id": 3005, "trainId": 589}, + {"name": "helmet", "id": 1231, "trainId": 590}, + {"name": "cover curtain", "id": 635, "trainId": 591}, + {"name": "brick", "id": 292, "trainId": 592}, + {"name": "table runner", "id": 2690, "trainId": 593}, + {"name": "ashtray", "id": 65, "trainId": 594}, + {"name": "street box", "id": 2607, "trainId": 595}, + {"name": "stick", "id": 2574, "trainId": 596}, + {"name": "hangers", "id": 1188, "trainId": 597}, + {"name": "cells", "id": 456, "trainId": 598}, + {"name": "urinal", "id": 2913, "trainId": 599}, + {"name": "centerpiece", "id": 459, "trainId": 600}, + {"name": "portable fridge", "id": 1955, "trainId": 601}, + {"name": "dvds", "id": 827, "trainId": 602}, + {"name": "golf club", "id": 1111, "trainId": 603}, + 
{"name": "skirting board", "id": 2412, "trainId": 604}, + {"name": "water cooler", "id": 2997, "trainId": 605}, + {"name": "clipboard", "id": 528, "trainId": 606}, + {"name": "camera, photographic camera", "id": 366, "trainId": 607}, + {"name": "pigeonhole", "id": 1863, "trainId": 608}, + {"name": "chips", "id": 500, "trainId": 609}, + {"name": "food processor", "id": 1001, "trainId": 610}, + {"name": "post box", "id": 1958, "trainId": 611}, + {"name": "lid", "id": 1441, "trainId": 612}, + {"name": "drum", "id": 809, "trainId": 613}, + {"name": "blender", "id": 210, "trainId": 614}, + {"name": "cave entrance", "id": 435, "trainId": 615}, + {"name": "dental chair", "id": 718, "trainId": 616}, + {"name": "obelisk", "id": 1674, "trainId": 617}, + {"name": "canoe", "id": 388, "trainId": 618}, + {"name": "mobile", "id": 1572, "trainId": 619}, + {"name": "monitors", "id": 1584, "trainId": 620}, + {"name": "pool ball", "id": 1944, "trainId": 621}, + {"name": "cue rack", "id": 674, "trainId": 622}, + {"name": "baggage carts", "id": 99, "trainId": 623}, + {"name": "shore", "id": 2352, "trainId": 624}, + {"name": "fork", "id": 1019, "trainId": 625}, + {"name": "paper filer", "id": 1763, "trainId": 626}, + {"name": "bicycle rack", "id": 185, "trainId": 627}, + {"name": "coat rack", "id": 554, "trainId": 628}, + {"name": "garland", "id": 1066, "trainId": 629}, + {"name": "sports bag", "id": 2508, "trainId": 630}, + {"name": "fish tank", "id": 951, "trainId": 631}, + {"name": "towel dispenser", "id": 2822, "trainId": 632}, + {"name": "carriage", "id": 415, "trainId": 633}, + {"name": "brochure", "id": 297, "trainId": 634}, + {"name": "plaque", "id": 1914, "trainId": 635}, + {"name": "stringer", "id": 2619, "trainId": 636}, + {"name": "iron", "id": 1338, "trainId": 637}, + {"name": "spoon", "id": 2505, "trainId": 638}, + {"name": "flag pole", "id": 955, "trainId": 639}, + {"name": "toilet brush", "id": 2786, "trainId": 640}, + {"name": "book stand", "id": 238, "trainId": 641}, + {"name": "water faucet, water tap, tap, hydrant", "id": 3000, "trainId": 642}, + {"name": "ticket office", "id": 2763, "trainId": 643}, + {"name": "broom", "id": 299, "trainId": 644}, + {"name": "dvd", "id": 822, "trainId": 645}, + {"name": "ice bucket", "id": 1288, "trainId": 646}, + {"name": "carapace, shell, cuticle, shield", "id": 3101, "trainId": 647}, + {"name": "tureen", "id": 2894, "trainId": 648}, + {"name": "folders", "id": 992, "trainId": 649}, + {"name": "chess", "id": 489, "trainId": 650}, + {"name": "root", "id": 2157, "trainId": 651}, + {"name": "sewing machine", "id": 2309, "trainId": 652}, + {"name": "model", "id": 1576, "trainId": 653}, + {"name": "pen", "id": 1810, "trainId": 654}, + {"name": "violin", "id": 2964, "trainId": 655}, + {"name": "sweatshirt", "id": 2662, "trainId": 656}, + {"name": "recycling materials", "id": 2087, "trainId": 657}, + {"name": "mitten", "id": 1569, "trainId": 658}, + {"name": "chopping board, cutting board", "id": 503, "trainId": 659}, + {"name": "mask", "id": 1505, "trainId": 660}, + {"name": "log", "id": 1468, "trainId": 661}, + {"name": "mouse, computer mouse", "id": 1613, "trainId": 662}, + {"name": "grill", "id": 1138, "trainId": 663}, + {"name": "hole", "id": 1256, "trainId": 664}, + {"name": "target", "id": 2715, "trainId": 665}, + {"name": "trash bag", "id": 2846, "trainId": 666}, + {"name": "chalk", "id": 477, "trainId": 667}, + {"name": "sticks", "id": 2576, "trainId": 668}, + {"name": "balloon", "id": 108, "trainId": 669}, + {"name": "score", "id": 2245, "trainId": 
670}, + {"name": "hair spray", "id": 1162, "trainId": 671}, + {"name": "roll", "id": 2149, "trainId": 672}, + {"name": "runner", "id": 2183, "trainId": 673}, + {"name": "engine", "id": 858, "trainId": 674}, + {"name": "inflatable glove", "id": 1324, "trainId": 675}, + {"name": "games", "id": 1055, "trainId": 676}, + {"name": "pallets", "id": 1741, "trainId": 677}, + {"name": "baskets", "id": 149, "trainId": 678}, + {"name": "coop", "id": 615, "trainId": 679}, + {"name": "dvd player", "id": 825, "trainId": 680}, + {"name": "rocking horse", "id": 2143, "trainId": 681}, + {"name": "buckets", "id": 304, "trainId": 682}, + {"name": "bread rolls", "id": 283, "trainId": 683}, + {"name": "shawl", "id": 2322, "trainId": 684}, + {"name": "watering can", "id": 3017, "trainId": 685}, + {"name": "spotlights", "id": 2510, "trainId": 686}, + {"name": "post-it", "id": 1960, "trainId": 687}, + {"name": "bowls", "id": 265, "trainId": 688}, + {"name": "security camera", "id": 2282, "trainId": 689}, + {"name": "runner cloth", "id": 2184, "trainId": 690}, + {"name": "lock", "id": 1461, "trainId": 691}, + {"name": "alarm, warning device, alarm system", "id": 3113, "trainId": 692}, + {"name": "side", "id": 2372, "trainId": 693}, + {"name": "roulette", "id": 2166, "trainId": 694}, + {"name": "bone", "id": 232, "trainId": 695}, + {"name": "cutlery", "id": 693, "trainId": 696}, + {"name": "pool balls", "id": 1945, "trainId": 697}, + {"name": "wheels", "id": 3039, "trainId": 698}, + {"name": "spice rack", "id": 2494, "trainId": 699}, + {"name": "plant pots", "id": 1908, "trainId": 700}, + {"name": "towel ring", "id": 2827, "trainId": 701}, + {"name": "bread box", "id": 280, "trainId": 702}, + {"name": "video", "id": 2950, "trainId": 703}, + {"name": "funfair", "id": 1044, "trainId": 704}, + {"name": "breads", "id": 288, "trainId": 705}, + {"name": "tripod", "id": 2863, "trainId": 706}, + {"name": "ironing board", "id": 1342, "trainId": 707}, + {"name": "skimmer", "id": 2409, "trainId": 708}, + {"name": "hollow", "id": 1258, "trainId": 709}, + {"name": "scratching post", "id": 2249, "trainId": 710}, + {"name": "tricycle", "id": 2862, "trainId": 711}, + {"name": "file box", "id": 920, "trainId": 712}, + {"name": "mountain pass", "id": 1607, "trainId": 713}, + {"name": "tombstones", "id": 2802, "trainId": 714}, + {"name": "cooker", "id": 610, "trainId": 715}, + {"name": "card game, cards", "id": 3129, "trainId": 716}, + {"name": "golf bag", "id": 1108, "trainId": 717}, + {"name": "towel paper", "id": 2823, "trainId": 718}, + {"name": "chaise lounge", "id": 476, "trainId": 719}, + {"name": "sun", "id": 2641, "trainId": 720}, + {"name": "toilet paper holder", "id": 2788, "trainId": 721}, + {"name": "rake", "id": 2070, "trainId": 722}, + {"name": "key", "id": 1368, "trainId": 723}, + {"name": "umbrella stand", "id": 2903, "trainId": 724}, + {"name": "dartboard", "id": 699, "trainId": 725}, + {"name": "transformer", "id": 2844, "trainId": 726}, + {"name": "fireplace utensils", "id": 942, "trainId": 727}, + {"name": "sweatshirts", "id": 2663, "trainId": 728}, + { + "name": "cellular telephone, cellular phone, cellphone, cell, mobile phone", + "id": 457, + "trainId": 729, + }, + {"name": "tallboy", "id": 2701, "trainId": 730}, + {"name": "stapler", "id": 2540, "trainId": 731}, + {"name": "sauna", "id": 2231, "trainId": 732}, + {"name": "test tube", "id": 2746, "trainId": 733}, + {"name": "palette", "id": 1738, "trainId": 734}, + {"name": "shopping carts", "id": 2350, "trainId": 735}, + {"name": "tools", "id": 2808, 
"trainId": 736}, + {"name": "push button, push, button", "id": 2025, "trainId": 737}, + {"name": "star", "id": 2541, "trainId": 738}, + {"name": "roof rack", "id": 2156, "trainId": 739}, + {"name": "barbed wire", "id": 126, "trainId": 740}, + {"name": "spray", "id": 2512, "trainId": 741}, + {"name": "ear", "id": 831, "trainId": 742}, + {"name": "sponge", "id": 2503, "trainId": 743}, + {"name": "racket", "id": 2039, "trainId": 744}, + {"name": "tins", "id": 2774, "trainId": 745}, + {"name": "eyeglasses", "id": 886, "trainId": 746}, + {"name": "file", "id": 919, "trainId": 747}, + {"name": "scarfs", "id": 2240, "trainId": 748}, + {"name": "sugar bowl", "id": 2636, "trainId": 749}, + {"name": "flip flop", "id": 963, "trainId": 750}, + {"name": "headstones", "id": 1218, "trainId": 751}, + {"name": "laptop bag", "id": 1406, "trainId": 752}, + {"name": "leash", "id": 1420, "trainId": 753}, + {"name": "climbing frame", "id": 526, "trainId": 754}, + {"name": "suit hanger", "id": 2639, "trainId": 755}, + {"name": "floor spotlight", "id": 975, "trainId": 756}, + {"name": "plate rack", "id": 1921, "trainId": 757}, + {"name": "sewer", "id": 2305, "trainId": 758}, + {"name": "hard drive", "id": 1193, "trainId": 759}, + {"name": "sprinkler", "id": 2517, "trainId": 760}, + {"name": "tools box", "id": 2809, "trainId": 761}, + {"name": "necklace", "id": 1647, "trainId": 762}, + {"name": "bulbs", "id": 314, "trainId": 763}, + {"name": "steel industry", "id": 2560, "trainId": 764}, + {"name": "club", "id": 545, "trainId": 765}, + {"name": "jack", "id": 1345, "trainId": 766}, + {"name": "door bars", "id": 775, "trainId": 767}, + { + "name": "control panel, instrument panel, control board, board, panel", + "id": 603, + "trainId": 768, + }, + {"name": "hairbrush", "id": 1163, "trainId": 769}, + {"name": "napkin holder", "id": 1641, "trainId": 770}, + {"name": "office", "id": 1678, "trainId": 771}, + {"name": "smoke detector", "id": 2450, "trainId": 772}, + {"name": "utensils", "id": 2915, "trainId": 773}, + {"name": "apron", "id": 42, "trainId": 774}, + {"name": "scissors", "id": 2242, "trainId": 775}, + {"name": "terminal", "id": 2741, "trainId": 776}, + {"name": "grinder", "id": 1143, "trainId": 777}, + {"name": "entry phone", "id": 862, "trainId": 778}, + {"name": "newspaper stand", "id": 1654, "trainId": 779}, + {"name": "pepper shaker", "id": 1826, "trainId": 780}, + {"name": "onions", "id": 1689, "trainId": 781}, + { + "name": "central processing unit, cpu, c p u , central processor, processor, mainframe", + "id": 3124, + "trainId": 782, + }, + {"name": "tape", "id": 2710, "trainId": 783}, + {"name": "bat", "id": 152, "trainId": 784}, + {"name": "coaster", "id": 549, "trainId": 785}, + {"name": "calculator", "id": 360, "trainId": 786}, + {"name": "potatoes", "id": 1982, "trainId": 787}, + {"name": "luggage rack", "id": 1478, "trainId": 788}, + {"name": "salt", "id": 2203, "trainId": 789}, + {"name": "street number", "id": 2612, "trainId": 790}, + {"name": "viewpoint", "id": 2956, "trainId": 791}, + {"name": "sword", "id": 2681, "trainId": 792}, + {"name": "cd", "id": 437, "trainId": 793}, + {"name": "rowing machine", "id": 2171, "trainId": 794}, + {"name": "plug", "id": 1933, "trainId": 795}, + {"name": "andiron, firedog, dog, dog-iron", "id": 3110, "trainId": 796}, + {"name": "pepper", "id": 1824, "trainId": 797}, + {"name": "tongs", "id": 2803, "trainId": 798}, + {"name": "bonfire", "id": 234, "trainId": 799}, + {"name": "dog dish", "id": 764, "trainId": 800}, + {"name": "belt", "id": 177, "trainId": 
801}, + {"name": "dumbbells", "id": 817, "trainId": 802}, + {"name": "videocassette recorder, vcr", "id": 3145, "trainId": 803}, + {"name": "hook", "id": 1262, "trainId": 804}, + {"name": "envelopes", "id": 864, "trainId": 805}, + {"name": "shower faucet", "id": 2359, "trainId": 806}, + {"name": "watch", "id": 2992, "trainId": 807}, + {"name": "padlock", "id": 1725, "trainId": 808}, + {"name": "swimming pool ladder", "id": 2667, "trainId": 809}, + {"name": "spanners", "id": 2484, "trainId": 810}, + {"name": "gravy boat", "id": 1133, "trainId": 811}, + {"name": "notice board", "id": 1667, "trainId": 812}, + {"name": "trash bags", "id": 2847, "trainId": 813}, + {"name": "fire alarm", "id": 932, "trainId": 814}, + {"name": "ladle", "id": 1392, "trainId": 815}, + {"name": "stethoscope", "id": 2573, "trainId": 816}, + {"name": "rocket", "id": 2140, "trainId": 817}, + {"name": "funnel", "id": 1046, "trainId": 818}, + {"name": "bowling pins", "id": 264, "trainId": 819}, + {"name": "valve", "id": 2927, "trainId": 820}, + {"name": "thermometer", "id": 2752, "trainId": 821}, + {"name": "cups", "id": 679, "trainId": 822}, + {"name": "spice jar", "id": 2493, "trainId": 823}, + {"name": "night light", "id": 1658, "trainId": 824}, + {"name": "soaps", "id": 2466, "trainId": 825}, + {"name": "games table", "id": 1057, "trainId": 826}, + {"name": "slotted spoon", "id": 2444, "trainId": 827}, + {"name": "reel", "id": 2093, "trainId": 828}, + {"name": "scourer", "id": 2248, "trainId": 829}, + {"name": "sleeping robe", "id": 2432, "trainId": 830}, + {"name": "desk mat", "id": 726, "trainId": 831}, + {"name": "dumbbell", "id": 816, "trainId": 832}, + {"name": "hammer", "id": 1171, "trainId": 833}, + {"name": "tie", "id": 2766, "trainId": 834}, + {"name": "typewriter", "id": 2900, "trainId": 835}, + {"name": "shaker", "id": 2313, "trainId": 836}, + {"name": "cheese dish", "id": 488, "trainId": 837}, + {"name": "sea star", "id": 2265, "trainId": 838}, + {"name": "racquet", "id": 2043, "trainId": 839}, + {"name": "butane gas cylinder", "id": 332, "trainId": 840}, + {"name": "paper weight", "id": 1771, "trainId": 841}, + {"name": "shaving brush", "id": 2320, "trainId": 842}, + {"name": "sunglasses", "id": 2646, "trainId": 843}, + {"name": "gear shift", "id": 1089, "trainId": 844}, + {"name": "towel rail", "id": 2826, "trainId": 845}, + {"name": "adding machine, totalizer, totaliser", "id": 3148, "trainId": 846}, +] + + +def _get_ade20k_full_meta(): + stuff_ids = [k["id"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES] + assert len(stuff_ids) == 847, len(stuff_ids) + + stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)} + stuff_classes = [k["name"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES] + + ret = { + "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, + "stuff_classes": stuff_classes, + } + return ret + + +def register_all_ade20k_full(root): + meta = _get_ade20k_full_meta() + for name, dirname in [("val", "validation")]: + image_dir = os.path.join(root, "ADE20K_2021_17_01/images_detectron2", dirname) + gt_dir = os.path.join(root, "ADE20K_2021_17_01/annotations_detectron2", dirname) + name = f"ade20k_full_sem_seg_{name}" + DatasetCatalog.register( + name, + lambda x=image_dir, y=gt_dir: load_sem_seg( + y, x, gt_ext="tif", image_ext="jpg" + ), + ) + MetadataCatalog.get(name).set( + stuff_classes=meta["stuff_classes"][:], + thing_classes=meta["stuff_classes"][:], # the same as stuff_classes + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + 
ignore_label=65535, # NOTE: gt is saved in 16-bit TIFF images + ) + + +_root = os.getenv("DETECTRON2_DATASETS", "datasets") +register_all_ade20k_full(_root) diff --git a/open_vocab_seg/data/datasets/register_cc3m.py b/open_vocab_seg/data/datasets/register_cc3m.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa5cb07bc99b574505b6319835750789bb3ee26 --- /dev/null +++ b/open_vocab_seg/data/datasets/register_cc3m.py @@ -0,0 +1,457 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +import pandas as pd +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets import load_sem_seg +from detectron2.utils.file_io import PathManager + + +COCO_CATEGORIES = [ + {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, + {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, + {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, + {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, + {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, + {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, + {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, + {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, + {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, + {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, + {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, + {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, + {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, + {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, + {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, + {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, + {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, + {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, + {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, + {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, + {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, + {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, + {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, + {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, + {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, + {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, + {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, + {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, + {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, + {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, + {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, + {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, + {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, + {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, + {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, + {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, + {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, + {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, + {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, + {"color": [197, 226, 255], "isthing": 1, "id": 44, 
"name": "bottle"}, + {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, + {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, + {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, + {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, + {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, + {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, + {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, + {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, + {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, + {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, + {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, + {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, + {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, + {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, + {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, + {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, + {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, + {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, + {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, + {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, + {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, + {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, + {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, + {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, + {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, + {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, + {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, + {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, + {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, + {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, + {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, + {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, + {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, + {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, + {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, + {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, + {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, + {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, + {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, + {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, + {"id": 92, "name": "banner", "supercategory": "textile"}, + {"id": 93, "name": "blanket", "supercategory": "textile"}, + {"id": 94, "name": "branch", "supercategory": "plant"}, + {"id": 95, "name": "bridge", "supercategory": "building"}, + {"id": 96, "name": "building-other", "supercategory": "building"}, + {"id": 97, "name": "bush", "supercategory": "plant"}, + {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"}, + {"id": 99, "name": "cage", "supercategory": "structural"}, + {"id": 100, "name": "cardboard", "supercategory": "raw-material"}, + {"id": 101, "name": "carpet", "supercategory": "floor"}, + {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"}, + {"id": 
103, "name": "ceiling-tile", "supercategory": "ceiling"}, + {"id": 104, "name": "cloth", "supercategory": "textile"}, + {"id": 105, "name": "clothes", "supercategory": "textile"}, + {"id": 106, "name": "clouds", "supercategory": "sky"}, + {"id": 107, "name": "counter", "supercategory": "furniture-stuff"}, + {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"}, + {"id": 109, "name": "curtain", "supercategory": "textile"}, + {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"}, + {"id": 111, "name": "dirt", "supercategory": "ground"}, + {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"}, + {"id": 113, "name": "fence", "supercategory": "structural"}, + {"id": 114, "name": "floor-marble", "supercategory": "floor"}, + {"id": 115, "name": "floor-other", "supercategory": "floor"}, + {"id": 116, "name": "floor-stone", "supercategory": "floor"}, + {"id": 117, "name": "floor-tile", "supercategory": "floor"}, + {"id": 118, "name": "floor-wood", "supercategory": "floor"}, + {"id": 119, "name": "flower", "supercategory": "plant"}, + {"id": 120, "name": "fog", "supercategory": "water"}, + {"id": 121, "name": "food-other", "supercategory": "food-stuff"}, + {"id": 122, "name": "fruit", "supercategory": "food-stuff"}, + {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"}, + {"id": 124, "name": "grass", "supercategory": "plant"}, + {"id": 125, "name": "gravel", "supercategory": "ground"}, + {"id": 126, "name": "ground-other", "supercategory": "ground"}, + {"id": 127, "name": "hill", "supercategory": "solid"}, + {"id": 128, "name": "house", "supercategory": "building"}, + {"id": 129, "name": "leaves", "supercategory": "plant"}, + {"id": 130, "name": "light", "supercategory": "furniture-stuff"}, + {"id": 131, "name": "mat", "supercategory": "textile"}, + {"id": 132, "name": "metal", "supercategory": "raw-material"}, + {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"}, + {"id": 134, "name": "moss", "supercategory": "plant"}, + {"id": 135, "name": "mountain", "supercategory": "solid"}, + {"id": 136, "name": "mud", "supercategory": "ground"}, + {"id": 137, "name": "napkin", "supercategory": "textile"}, + {"id": 138, "name": "net", "supercategory": "structural"}, + {"id": 139, "name": "paper", "supercategory": "raw-material"}, + {"id": 140, "name": "pavement", "supercategory": "ground"}, + {"id": 141, "name": "pillow", "supercategory": "textile"}, + {"id": 142, "name": "plant-other", "supercategory": "plant"}, + {"id": 143, "name": "plastic", "supercategory": "raw-material"}, + {"id": 144, "name": "platform", "supercategory": "ground"}, + {"id": 145, "name": "playingfield", "supercategory": "ground"}, + {"id": 146, "name": "railing", "supercategory": "structural"}, + {"id": 147, "name": "railroad", "supercategory": "ground"}, + {"id": 148, "name": "river", "supercategory": "water"}, + {"id": 149, "name": "road", "supercategory": "ground"}, + {"id": 150, "name": "rock", "supercategory": "solid"}, + {"id": 151, "name": "roof", "supercategory": "building"}, + {"id": 152, "name": "rug", "supercategory": "textile"}, + {"id": 153, "name": "salad", "supercategory": "food-stuff"}, + {"id": 154, "name": "sand", "supercategory": "ground"}, + {"id": 155, "name": "sea", "supercategory": "water"}, + {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"}, + {"id": 157, "name": "sky-other", "supercategory": "sky"}, + {"id": 158, "name": "skyscraper", "supercategory": "building"}, + {"id": 159, "name": "snow", 
"supercategory": "ground"}, + {"id": 160, "name": "solid-other", "supercategory": "solid"}, + {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"}, + {"id": 162, "name": "stone", "supercategory": "solid"}, + {"id": 163, "name": "straw", "supercategory": "plant"}, + {"id": 164, "name": "structural-other", "supercategory": "structural"}, + {"id": 165, "name": "table", "supercategory": "furniture-stuff"}, + {"id": 166, "name": "tent", "supercategory": "building"}, + {"id": 167, "name": "textile-other", "supercategory": "textile"}, + {"id": 168, "name": "towel", "supercategory": "textile"}, + {"id": 169, "name": "tree", "supercategory": "plant"}, + {"id": 170, "name": "vegetable", "supercategory": "food-stuff"}, + {"id": 171, "name": "wall-brick", "supercategory": "wall"}, + {"id": 172, "name": "wall-concrete", "supercategory": "wall"}, + {"id": 173, "name": "wall-other", "supercategory": "wall"}, + {"id": 174, "name": "wall-panel", "supercategory": "wall"}, + {"id": 175, "name": "wall-stone", "supercategory": "wall"}, + {"id": 176, "name": "wall-tile", "supercategory": "wall"}, + {"id": 177, "name": "wall-wood", "supercategory": "wall"}, + {"id": 178, "name": "water-other", "supercategory": "water"}, + {"id": 179, "name": "waterdrops", "supercategory": "water"}, + {"id": 180, "name": "window-blind", "supercategory": "window"}, + {"id": 181, "name": "window-other", "supercategory": "window"}, + {"id": 182, "name": "wood", "supercategory": "solid"}, +] + + +ADE20K_150_CATEGORIES = [ + {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"}, + {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"}, + {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"}, + {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"}, + {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"}, + {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"}, + {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"}, + {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"}, + {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "}, + {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"}, + {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"}, + {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"}, + {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"}, + {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"}, + {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"}, + {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"}, + {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"}, + {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"}, + {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"}, + {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"}, + {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"}, + {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"}, + {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"}, + {"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"}, + {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"}, + {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"}, + {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"}, + {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"}, + {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": 
"rug"}, + {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"}, + {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"}, + {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"}, + {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"}, + {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"}, + {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"}, + {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"}, + {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"}, + {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"}, + {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"}, + {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"}, + {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"}, + {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"}, + {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"}, + {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"}, + { + "color": [6, 51, 255], + "id": 44, + "isthing": 1, + "name": "chest of drawers, chest, bureau, dresser", + }, + {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"}, + {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"}, + {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"}, + {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"}, + {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"}, + {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"}, + {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"}, + {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"}, + {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"}, + {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"}, + {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"}, + { + "color": [255, 71, 0], + "id": 56, + "isthing": 1, + "name": "pool table, billiard table, snooker table", + }, + {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"}, + {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"}, + {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"}, + {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"}, + {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"}, + {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"}, + {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"}, + {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"}, + { + "color": [0, 255, 133], + "id": 65, + "isthing": 1, + "name": "toilet, can, commode, crapper, pot, potty, stool, throne", + }, + {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"}, + {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"}, + {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"}, + {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"}, + {"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"}, + {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"}, + {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"}, + {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"}, + {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"}, + {"color": 
[10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"}, + {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"}, + {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"}, + {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"}, + {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"}, + {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"}, + {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"}, + {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"}, + {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"}, + {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"}, + {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"}, + {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"}, + {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"}, + {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"}, + {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"}, + {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"}, + {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"}, + {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"}, + {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"}, + {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"}, + { + "color": [0, 122, 255], + "id": 95, + "isthing": 1, + "name": "bannister, banister, balustrade, balusters, handrail", + }, + { + "color": [0, 255, 163], + "id": 96, + "isthing": 0, + "name": "escalator, moving staircase, moving stairway", + }, + { + "color": [255, 153, 0], + "id": 97, + "isthing": 1, + "name": "ottoman, pouf, pouffe, puff, hassock", + }, + {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"}, + {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"}, + { + "color": [143, 255, 0], + "id": 100, + "isthing": 0, + "name": "poster, posting, placard, notice, bill, card", + }, + {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"}, + {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"}, + {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"}, + {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"}, + { + "color": [133, 0, 255], + "id": 105, + "isthing": 0, + "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter", + }, + {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"}, + { + "color": [184, 0, 255], + "id": 107, + "isthing": 1, + "name": "washer, automatic washer, washing machine", + }, + {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"}, + {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"}, + {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"}, + {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"}, + {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"}, + {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"}, + {"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"}, + {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"}, + {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"}, + {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"}, + {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"}, + {"color": [255, 0, 163], "id": 119, 
"isthing": 1, "name": "ball"}, + {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"}, + {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"}, + {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"}, + {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"}, + {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"}, + {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"}, + {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"}, + {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"}, + {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"}, + {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"}, + {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"}, + {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"}, + {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"}, + {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"}, + {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"}, + {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"}, + {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"}, + {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"}, + {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"}, + {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"}, + {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"}, + {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"}, + {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"}, + {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"}, + {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"}, + {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"}, + {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"}, + {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"}, + {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"}, + {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"}, +] + +TEST_CATEGORIES = [ + {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "Oculus"}, + {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "Ukulele"}, +] + +COCO_BASE_CATEGORIES = [ + c + for i, c in enumerate(COCO_CATEGORIES) + if c["id"] - 1 + not in [20, 24, 32, 33, 40, 56, 86, 99, 105, 123, 144, 147, 148, 168, 171] +] +COCO_NOVEL_CATEGORIES = [ + c + for i, c in enumerate(COCO_CATEGORIES) + if c["id"] - 1 + in [20, 24, 32, 33, 40, 56, 86, 99, 105, 123, 144, 147, 148, 168, 171] +] + + +def load_cc_image(csv_file, img_key='filepath', caption_key='title', sep="\t"): + print(f'Loading csv data from {csv_file}.') + df = pd.read_csv(csv_file, sep=sep) + + input_files = df[img_key].tolist() + captions = df[caption_key].tolist() + + print("Loaded {} images".format(len(input_files))) + + dataset_dicts = [] + for (img_path, text) in zip(input_files, captions): + record = {} + record["file_name"] = img_path + record["caption"] = text + dataset_dicts.append(record) + + return dataset_dicts + + +def _get_coco_stuff_meta(cat_list): + # Id 0 is reserved for ignore_label, we change ignore_label for 0 + # to 255 in our pre-processing. 
+ stuff_ids = [k["id"] for k in cat_list] + + # For semantic segmentation, this mapping maps from contiguous stuff id + # (in [0, 91], used in models) to ids in the dataset (used for processing results) + stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)} + stuff_classes = [k["name"] for k in cat_list] + + ret = { + "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, + "stuff_classes": stuff_classes, + } + return ret + + +def register_cc_3m(csv_file): + + meta = _get_coco_stuff_meta(TEST_CATEGORIES) + name = "cc_3m_train" + + DatasetCatalog.register( + name, + lambda x=csv_file: load_cc_image(x), + ) + MetadataCatalog.get(name).set( + csv_file=csv_file, + evaluator_type="dummy", + ignore_label=255, + **meta, + ) + + +# _csv_file = "/home/jeffliang/zsseg/datasets/coco/coco_train_merge_captions.csv" +_csv_file = "/home/jeffliang/zsseg/configs/masked_images/pred/samples.csv" +register_cc_3m(_csv_file) diff --git a/open_vocab_seg/data/datasets/register_coco_stuff.py b/open_vocab_seg/data/datasets/register_coco_stuff.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a0f5b571a971fe20ebc8932d27499de856a565 --- /dev/null +++ b/open_vocab_seg/data/datasets/register_coco_stuff.py @@ -0,0 +1,250 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets import load_sem_seg + + +COCO_CATEGORIES = [ + {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, + {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, + {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, + {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, + {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, + {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, + {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, + {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, + {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, + {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, + {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, + {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, + {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, + {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, + {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, + {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, + {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, + {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, + {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, + {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, + {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, + {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, + {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, + {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, + {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, + {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, + {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, + {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, + {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, + {"color": [255, 99, 164], "isthing": 1, "id": 
34, "name": "frisbee"}, + {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, + {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, + {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, + {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, + {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, + {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, + {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, + {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, + {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, + {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, + {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, + {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, + {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, + {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, + {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, + {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, + {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, + {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, + {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, + {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, + {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, + {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, + {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, + {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, + {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, + {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, + {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, + {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, + {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, + {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, + {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, + {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, + {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, + {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, + {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, + {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, + {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, + {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, + {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, + {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, + {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, + {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, + {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, + {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, + {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, + {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, + {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, + {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, + {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, + {"color": [191, 162, 208], 
"isthing": 1, "id": 90, "name": "toothbrush"}, + {"id": 92, "name": "banner", "supercategory": "textile"}, + {"id": 93, "name": "blanket", "supercategory": "textile"}, + {"id": 94, "name": "branch", "supercategory": "plant"}, + {"id": 95, "name": "bridge", "supercategory": "building"}, + {"id": 96, "name": "building-other", "supercategory": "building"}, + {"id": 97, "name": "bush", "supercategory": "plant"}, + {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"}, + {"id": 99, "name": "cage", "supercategory": "structural"}, + {"id": 100, "name": "cardboard", "supercategory": "raw-material"}, + {"id": 101, "name": "carpet", "supercategory": "floor"}, + {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"}, + {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"}, + {"id": 104, "name": "cloth", "supercategory": "textile"}, + {"id": 105, "name": "clothes", "supercategory": "textile"}, + {"id": 106, "name": "clouds", "supercategory": "sky"}, + {"id": 107, "name": "counter", "supercategory": "furniture-stuff"}, + {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"}, + {"id": 109, "name": "curtain", "supercategory": "textile"}, + {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"}, + {"id": 111, "name": "dirt", "supercategory": "ground"}, + {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"}, + {"id": 113, "name": "fence", "supercategory": "structural"}, + {"id": 114, "name": "floor-marble", "supercategory": "floor"}, + {"id": 115, "name": "floor-other", "supercategory": "floor"}, + {"id": 116, "name": "floor-stone", "supercategory": "floor"}, + {"id": 117, "name": "floor-tile", "supercategory": "floor"}, + {"id": 118, "name": "floor-wood", "supercategory": "floor"}, + {"id": 119, "name": "flower", "supercategory": "plant"}, + {"id": 120, "name": "fog", "supercategory": "water"}, + {"id": 121, "name": "food-other", "supercategory": "food-stuff"}, + {"id": 122, "name": "fruit", "supercategory": "food-stuff"}, + {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"}, + {"id": 124, "name": "grass", "supercategory": "plant"}, + {"id": 125, "name": "gravel", "supercategory": "ground"}, + {"id": 126, "name": "ground-other", "supercategory": "ground"}, + {"id": 127, "name": "hill", "supercategory": "solid"}, + {"id": 128, "name": "house", "supercategory": "building"}, + {"id": 129, "name": "leaves", "supercategory": "plant"}, + {"id": 130, "name": "light", "supercategory": "furniture-stuff"}, + {"id": 131, "name": "mat", "supercategory": "textile"}, + {"id": 132, "name": "metal", "supercategory": "raw-material"}, + {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"}, + {"id": 134, "name": "moss", "supercategory": "plant"}, + {"id": 135, "name": "mountain", "supercategory": "solid"}, + {"id": 136, "name": "mud", "supercategory": "ground"}, + {"id": 137, "name": "napkin", "supercategory": "textile"}, + {"id": 138, "name": "net", "supercategory": "structural"}, + {"id": 139, "name": "paper", "supercategory": "raw-material"}, + {"id": 140, "name": "pavement", "supercategory": "ground"}, + {"id": 141, "name": "pillow", "supercategory": "textile"}, + {"id": 142, "name": "plant-other", "supercategory": "plant"}, + {"id": 143, "name": "plastic", "supercategory": "raw-material"}, + {"id": 144, "name": "platform", "supercategory": "ground"}, + {"id": 145, "name": "playingfield", "supercategory": "ground"}, + {"id": 146, "name": "railing", "supercategory": "structural"}, + {"id": 147, "name": 
"railroad", "supercategory": "ground"}, + {"id": 148, "name": "river", "supercategory": "water"}, + {"id": 149, "name": "road", "supercategory": "ground"}, + {"id": 150, "name": "rock", "supercategory": "solid"}, + {"id": 151, "name": "roof", "supercategory": "building"}, + {"id": 152, "name": "rug", "supercategory": "textile"}, + {"id": 153, "name": "salad", "supercategory": "food-stuff"}, + {"id": 154, "name": "sand", "supercategory": "ground"}, + {"id": 155, "name": "sea", "supercategory": "water"}, + {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"}, + {"id": 157, "name": "sky-other", "supercategory": "sky"}, + {"id": 158, "name": "skyscraper", "supercategory": "building"}, + {"id": 159, "name": "snow", "supercategory": "ground"}, + {"id": 160, "name": "solid-other", "supercategory": "solid"}, + {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"}, + {"id": 162, "name": "stone", "supercategory": "solid"}, + {"id": 163, "name": "straw", "supercategory": "plant"}, + {"id": 164, "name": "structural-other", "supercategory": "structural"}, + {"id": 165, "name": "table", "supercategory": "furniture-stuff"}, + {"id": 166, "name": "tent", "supercategory": "building"}, + {"id": 167, "name": "textile-other", "supercategory": "textile"}, + {"id": 168, "name": "towel", "supercategory": "textile"}, + {"id": 169, "name": "tree", "supercategory": "plant"}, + {"id": 170, "name": "vegetable", "supercategory": "food-stuff"}, + {"id": 171, "name": "wall-brick", "supercategory": "wall"}, + {"id": 172, "name": "wall-concrete", "supercategory": "wall"}, + {"id": 173, "name": "wall-other", "supercategory": "wall"}, + {"id": 174, "name": "wall-panel", "supercategory": "wall"}, + {"id": 175, "name": "wall-stone", "supercategory": "wall"}, + {"id": 176, "name": "wall-tile", "supercategory": "wall"}, + {"id": 177, "name": "wall-wood", "supercategory": "wall"}, + {"id": 178, "name": "water-other", "supercategory": "water"}, + {"id": 179, "name": "waterdrops", "supercategory": "water"}, + {"id": 180, "name": "window-blind", "supercategory": "window"}, + {"id": 181, "name": "window-other", "supercategory": "window"}, + {"id": 182, "name": "wood", "supercategory": "solid"}, +] + +def _get_coco_stuff_meta(cat_list): + # Id 0 is reserved for ignore_label, we change ignore_label for 0 + # to 255 in our pre-processing. 
+ stuff_ids = [k["id"] for k in cat_list] + + # For semantic segmentation, this mapping maps from contiguous stuff id + # (in [0, 91], used in models) to ids in the dataset (used for processing results) + stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)} + stuff_classes = [k["name"] for k in cat_list] + + ret = { + "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, + "stuff_classes": stuff_classes, + } + return ret + + +def register_all_coco_stuff_10k(root): + root = os.path.join(root, "coco", "coco_stuff_10k") + meta = _get_coco_stuff_meta(COCO_CATEGORIES) + for name, image_dirname, sem_seg_dirname in [ + ("train", "images_detectron2/train", "annotations_detectron2/train"), + ]: + image_dir = os.path.join(root, image_dirname) + gt_dir = os.path.join(root, sem_seg_dirname) + name = f"coco_2017_{name}_stuff_10k_sem_seg" + DatasetCatalog.register( + name, + lambda x=image_dir, y=gt_dir: load_sem_seg( + y, x, gt_ext="png", image_ext="jpg" + ), + ) + MetadataCatalog.get(name).set( + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + ignore_label=255, + **meta, + ) + + +def register_all_coco_stuff(root): + root = os.path.join(root, "coco") + meta = _get_coco_stuff_meta(COCO_CATEGORIES) + + for name, image_dirname, sem_seg_dirname in [ + ("train", "train2017", "stuffthingmaps_detectron2/train2017"), + ]: + image_dir = os.path.join(root, image_dirname) + gt_dir = os.path.join(root, sem_seg_dirname) + all_name = f"coco_2017_{name}_stuff_sem_seg" + DatasetCatalog.register( + all_name, + lambda x=image_dir, y=gt_dir: load_sem_seg( + y, x, gt_ext="png", image_ext="jpg" + ), + ) + MetadataCatalog.get(all_name).set( + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + ignore_label=255, + **meta, + ) + + +_root = os.getenv("DETECTRON2_DATASETS", "datasets") +register_all_coco_stuff_10k(_root) +register_all_coco_stuff(_root) diff --git a/open_vocab_seg/data/datasets/register_pascal_context.py b/open_vocab_seg/data/datasets/register_pascal_context.py new file mode 100644 index 0000000000000000000000000000000000000000..e40f87c945da20e78c0a3ea230bc9f36d1800071 --- /dev/null +++ b/open_vocab_seg/data/datasets/register_pascal_context.py @@ -0,0 +1,588 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets import load_sem_seg + +PASCALCONTEX59_NAMES = ( + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "table", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", + "bag", + "bed", + "bench", + "book", + "building", + "cabinet", + "ceiling", + "cloth", + "computer", + "cup", + "door", + "fence", + "floor", + "flower", + "food", + "grass", + "ground", + "keyboard", + "light", + "mountain", + "mouse", + "curtain", + "platform", + "sign", + "plate", + "road", + "rock", + "shelves", + "sidewalk", + "sky", + "snow", + "bedclothes", + "track", + "tree", + "truck", + "wall", + "water", + "window", + "wood", +) + +PASCALCONTEX459_NAMES = ( + "accordion", + "aeroplane", + "air conditioner", + "antenna", + "artillery", + "ashtray", + "atrium", + "baby carriage", + "bag", + "ball", + "balloon", + "bamboo weaving", + "barrel", + "baseball bat", + "basket", + "basketball backboard", + "bathtub", + "bed", + "bedclothes", + "beer", + "bell", + "bench", + "bicycle", + "binoculars", + "bird", + "bird cage", + "bird feeder", + "bird nest", + "blackboard", + "board", + "boat", + "bone", + "book", + "bottle", + "bottle opener", + "bowl", + "box", + "bracelet", + "brick", + "bridge", + "broom", + "brush", + "bucket", + "building", + "bus", + "cabinet", + "cabinet door", + "cage", + "cake", + "calculator", + "calendar", + "camel", + "camera", + "camera lens", + "can", + "candle", + "candle holder", + "cap", + "car", + "card", + "cart", + "case", + "casette recorder", + "cash register", + "cat", + "cd", + "cd player", + "ceiling", + "cell phone", + "cello", + "chain", + "chair", + "chessboard", + "chicken", + "chopstick", + "clip", + "clippers", + "clock", + "closet", + "cloth", + "clothes tree", + "coffee", + "coffee machine", + "comb", + "computer", + "concrete", + "cone", + "container", + "control booth", + "controller", + "cooker", + "copying machine", + "coral", + "cork", + "corkscrew", + "counter", + "court", + "cow", + "crabstick", + "crane", + "crate", + "cross", + "crutch", + "cup", + "curtain", + "cushion", + "cutting board", + "dais", + "disc", + "disc case", + "dishwasher", + "dock", + "dog", + "dolphin", + "door", + "drainer", + "dray", + "drink dispenser", + "drinking machine", + "drop", + "drug", + "drum", + "drum kit", + "duck", + "dumbbell", + "earphone", + "earrings", + "egg", + "electric fan", + "electric iron", + "electric pot", + "electric saw", + "electronic keyboard", + "engine", + "envelope", + "equipment", + "escalator", + "exhibition booth", + "extinguisher", + "eyeglass", + "fan", + "faucet", + "fax machine", + "fence", + "ferris wheel", + "fire extinguisher", + "fire hydrant", + "fire place", + "fish", + "fish tank", + "fishbowl", + "fishing net", + "fishing pole", + "flag", + "flagstaff", + "flame", + "flashlight", + "floor", + "flower", + "fly", + "foam", + "food", + "footbridge", + "forceps", + "fork", + "forklift", + "fountain", + "fox", + "frame", + "fridge", + "frog", + "fruit", + "funnel", + "furnace", + "game controller", + "game machine", + "gas cylinder", + "gas hood", + "gas stove", + "gift box", + "glass", + "glass marble", + "globe", + "glove", + "goal", + "grandstand", + "grass", + "gravestone", + "ground", + "guardrail", + "guitar", + "gun", + "hammer", + "hand cart", + "handle", + "handrail", + "hanger", + "hard disk drive", + "hat", + "hay", + 
"headphone", + "heater", + "helicopter", + "helmet", + "holder", + "hook", + "horse", + "horse-drawn carriage", + "hot-air balloon", + "hydrovalve", + "ice", + "inflator pump", + "ipod", + "iron", + "ironing board", + "jar", + "kart", + "kettle", + "key", + "keyboard", + "kitchen range", + "kite", + "knife", + "knife block", + "ladder", + "ladder truck", + "ladle", + "laptop", + "leaves", + "lid", + "life buoy", + "light", + "light bulb", + "lighter", + "line", + "lion", + "lobster", + "lock", + "machine", + "mailbox", + "mannequin", + "map", + "mask", + "mat", + "match book", + "mattress", + "menu", + "metal", + "meter box", + "microphone", + "microwave", + "mirror", + "missile", + "model", + "money", + "monkey", + "mop", + "motorbike", + "mountain", + "mouse", + "mouse pad", + "musical instrument", + "napkin", + "net", + "newspaper", + "oar", + "ornament", + "outlet", + "oven", + "oxygen bottle", + "pack", + "pan", + "paper", + "paper box", + "paper cutter", + "parachute", + "parasol", + "parterre", + "patio", + "pelage", + "pen", + "pen container", + "pencil", + "person", + "photo", + "piano", + "picture", + "pig", + "pillar", + "pillow", + "pipe", + "pitcher", + "plant", + "plastic", + "plate", + "platform", + "player", + "playground", + "pliers", + "plume", + "poker", + "poker chip", + "pole", + "pool table", + "postcard", + "poster", + "pot", + "pottedplant", + "printer", + "projector", + "pumpkin", + "rabbit", + "racket", + "radiator", + "radio", + "rail", + "rake", + "ramp", + "range hood", + "receiver", + "recorder", + "recreational machines", + "remote control", + "road", + "robot", + "rock", + "rocket", + "rocking horse", + "rope", + "rug", + "ruler", + "runway", + "saddle", + "sand", + "saw", + "scale", + "scanner", + "scissors", + "scoop", + "screen", + "screwdriver", + "sculpture", + "scythe", + "sewer", + "sewing machine", + "shed", + "sheep", + "shell", + "shelves", + "shoe", + "shopping cart", + "shovel", + "sidecar", + "sidewalk", + "sign", + "signal light", + "sink", + "skateboard", + "ski", + "sky", + "sled", + "slippers", + "smoke", + "snail", + "snake", + "snow", + "snowmobiles", + "sofa", + "spanner", + "spatula", + "speaker", + "speed bump", + "spice container", + "spoon", + "sprayer", + "squirrel", + "stage", + "stair", + "stapler", + "stick", + "sticky note", + "stone", + "stool", + "stove", + "straw", + "stretcher", + "sun", + "sunglass", + "sunshade", + "surveillance camera", + "swan", + "sweeper", + "swim ring", + "swimming pool", + "swing", + "switch", + "table", + "tableware", + "tank", + "tap", + "tape", + "tarp", + "telephone", + "telephone booth", + "tent", + "tire", + "toaster", + "toilet", + "tong", + "tool", + "toothbrush", + "towel", + "toy", + "toy car", + "track", + "train", + "trampoline", + "trash bin", + "tray", + "tree", + "tricycle", + "tripod", + "trophy", + "truck", + "tube", + "turtle", + "tvmonitor", + "tweezers", + "typewriter", + "umbrella", + "unknown", + "vacuum cleaner", + "vending machine", + "video camera", + "video game console", + "video player", + "video tape", + "violin", + "wakeboard", + "wall", + "wallet", + "wardrobe", + "washing machine", + "watch", + "water", + "water dispenser", + "water pipe", + "water skate board", + "watermelon", + "whale", + "wharf", + "wheel", + "wheelchair", + "window", + "window blinds", + "wineglass", + "wire", + "wood", + "wool", + +) + + +def _get_voc_meta(cat_list): + ret = { + "stuff_classes": cat_list, + } + return ret + + +def register_pascal_context_59(root): + root = os.path.join(root, 
"VOCdevkit/VOC2010") + meta = _get_voc_meta(PASCALCONTEX59_NAMES) + for name, image_dirname, sem_seg_dirname in [ + ("val", "JPEGImages", "annotations_detectron2/pc59_val"), + ]: + image_dir = os.path.join(root, image_dirname) + gt_dir = os.path.join(root, sem_seg_dirname) + all_name = f"pascal_context_59_sem_seg_{name}" + DatasetCatalog.register( + all_name, + lambda x=image_dir, y=gt_dir: load_sem_seg( + y, x, gt_ext="png", image_ext="jpg" + ), + ) + MetadataCatalog.get(all_name).set( + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + ignore_label=255, + **meta, + ) + +def register_pascal_context_459(root): + root = os.path.join(root, "VOCdevkit/VOC2010") + meta = _get_voc_meta(PASCALCONTEX459_NAMES) + for name, image_dirname, sem_seg_dirname in [ + ("val", "JPEGImages", "annotations_detectron2/pc459_val"), + ]: + image_dir = os.path.join(root, image_dirname) + gt_dir = os.path.join(root, sem_seg_dirname) + all_name = f"pascal_context_459_sem_seg_{name}" + DatasetCatalog.register( + all_name, + lambda x=image_dir, y=gt_dir: load_sem_seg( + y, x, gt_ext="tif", image_ext="jpg" + ), + ) + MetadataCatalog.get(all_name).set( + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + ignore_label=65535, # NOTE: gt is saved in 16-bit TIFF images + **meta, + ) + +_root = os.getenv("DETECTRON2_DATASETS", "datasets") +register_pascal_context_59(_root) +register_pascal_context_459(_root) diff --git a/open_vocab_seg/data/datasets/register_voc_seg.py b/open_vocab_seg/data/datasets/register_voc_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c2be16f4bb5348de8f1051f3579e02e362488f --- /dev/null +++ b/open_vocab_seg/data/datasets/register_voc_seg.py @@ -0,0 +1,62 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets import load_sem_seg + +PASCALVOC20_NAMES = ( + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", +) + +def _get_voc_meta(cat_list): + ret = { + "stuff_classes": cat_list, + } + return ret + + +def register_pascalvoc(root): + root = os.path.join(root, "VOCdevkit/VOC2012") + meta = _get_voc_meta(PASCALVOC20_NAMES) + + for name, image_dirname, sem_seg_dirname in [ + ("val", "JPEGImages", "annotations_detectron2/val"), + ]: + image_dir = os.path.join(root, image_dirname) + gt_dir = os.path.join(root, sem_seg_dirname) + all_name = f"pascalvoc20_sem_seg_{name}" + DatasetCatalog.register( + all_name, + lambda x=image_dir, y=gt_dir: load_sem_seg( + y, x, gt_ext="png", image_ext="jpg" + ), + ) + MetadataCatalog.get(all_name).set( + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + ignore_label=255, + **meta, + ) + +_root = os.getenv("DETECTRON2_DATASETS", "datasets") +register_pascalvoc(_root) diff --git a/open_vocab_seg/evaluation/__init__.py b/open_vocab_seg/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d36d8e9659a1d31471273a6a0f82c2642ea982 --- /dev/null +++ b/open_vocab_seg/evaluation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +from .generalized_sem_seg_evaluation import GeneralizedSemSegEvaluator diff --git a/open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py b/open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..ce960ae7cbffde4a981be941ed03a8fc7025ed80 --- /dev/null +++ b/open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py @@ -0,0 +1,159 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import itertools +import json +import numpy as np +import os +from collections import OrderedDict +import PIL.Image as Image +import torch + +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.utils.comm import all_gather, is_main_process, synchronize +from detectron2.utils.file_io import PathManager + +from detectron2.evaluation import SemSegEvaluator + + +class GeneralizedSemSegEvaluator(SemSegEvaluator): + """ + Evaluate semantic segmentation metrics. + """ + + def __init__( + self, + dataset_name, + distributed=True, + output_dir=None, + *, + num_classes=None, + ignore_label=None, + post_process_func=None, + ): + super().__init__( + dataset_name, + distributed=distributed, + output_dir=output_dir, + num_classes=num_classes, + ignore_label=ignore_label, + ) + meta = MetadataCatalog.get(dataset_name) + try: + self._evaluation_set = meta.evaluation_set + except AttributeError: + self._evaluation_set = None + self.post_process_func = ( + post_process_func + if post_process_func is not None + else lambda x, **kwargs: x + ) + + def process(self, inputs, outputs): + """ + Args: + inputs: the inputs to a model. + It is a list of dicts. Each dict corresponds to an image and + contains keys like "height", "width", "file_name". + outputs: the outputs of a model. It is either list of semantic segmentation predictions + (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic + segmentation prediction in the same format. 
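A minimal sketch of the confusion-matrix update performed below, using toy arrays and K in place of self._num_classes (illustrative only):

    import numpy as np

    K = 3                                  # toy number of classes
    pred = np.array([0, 1, 2, 2])          # per-pixel predictions
    gt = np.array([0, 1, 1, K])            # ground truth; ignore_label already remapped to K
    conf = np.bincount(
        (K + 1) * pred + gt, minlength=(K + 1) ** 2
    ).reshape(K + 1, K + 1)
    # conf[i, j] counts pixels predicted as class i whose ground truth is class j;
    # ignored pixels land in the last column, which evaluate() drops via [:-1, :-1].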
+ """ + for input, output in zip(inputs, outputs): + output = self.post_process_func( + output["sem_seg"], image=np.array(Image.open(input["file_name"])) + ) + output = output.argmax(dim=0).to(self._cpu_device) + pred = np.array(output, dtype=np.int) + with PathManager.open( + self.input_file_to_gt_file[input["file_name"]], "rb" + ) as f: + gt = np.array(Image.open(f), dtype=np.int) + + gt[gt == self._ignore_label] = self._num_classes + + self._conf_matrix += np.bincount( + (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1), + minlength=self._conf_matrix.size, + ).reshape(self._conf_matrix.shape) + + self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"])) + + def evaluate(self): + """ + Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval): + + * Mean intersection-over-union averaged across classes (mIoU) + * Frequency Weighted IoU (fwIoU) + * Mean pixel accuracy averaged across classes (mACC) + * Pixel Accuracy (pACC) + """ + if self._distributed: + synchronize() + conf_matrix_list = all_gather(self._conf_matrix) + self._predictions = all_gather(self._predictions) + self._predictions = list(itertools.chain(*self._predictions)) + if not is_main_process(): + return + + self._conf_matrix = np.zeros_like(self._conf_matrix) + for conf_matrix in conf_matrix_list: + self._conf_matrix += conf_matrix + + if self._output_dir: + PathManager.mkdirs(self._output_dir) + file_path = os.path.join(self._output_dir, "sem_seg_predictions.json") + with PathManager.open(file_path, "w") as f: + f.write(json.dumps(self._predictions)) + + acc = np.full(self._num_classes, np.nan, dtype=np.float) + iou = np.full(self._num_classes, np.nan, dtype=np.float) + tp = self._conf_matrix.diagonal()[:-1].astype(np.float) + pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float) + class_weights = pos_gt / np.sum(pos_gt) + pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float) + acc_valid = pos_gt > 0 + acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid] + iou_valid = (pos_gt + pos_pred) > 0 + union = pos_gt + pos_pred - tp + iou[acc_valid] = tp[acc_valid] / union[acc_valid] + macc = np.sum(acc[acc_valid]) / np.sum(acc_valid) + miou = np.sum(iou[acc_valid]) / np.sum(iou_valid) + fiou = np.sum(iou[acc_valid] * class_weights[acc_valid]) + pacc = np.sum(tp) / np.sum(pos_gt) + + res = {} + res["mIoU"] = 100 * miou + res["fwIoU"] = 100 * fiou + for i, name in enumerate(self._class_names): + res["IoU-{}".format(name)] = 100 * iou[i] + res["mACC"] = 100 * macc + res["pACC"] = 100 * pacc + for i, name in enumerate(self._class_names): + res["ACC-{}".format(name)] = 100 * acc[i] + if self._evaluation_set is not None: + for set_name, set_inds in self._evaluation_set.items(): + iou_list = [] + set_inds = np.array(set_inds, np.int) + mask = np.zeros((len(iou),)).astype(np.bool) + mask[set_inds] = 1 + miou = np.sum(iou[mask][acc_valid[mask]]) / np.sum(iou_valid[mask]) + pacc = np.sum(tp[mask]) / np.sum(pos_gt[mask]) + res["mIoU-{}".format(set_name)] = 100 * miou + res["pAcc-{}".format(set_name)] = 100 * pacc + iou_list.append(miou) + miou = np.sum(iou[~mask][acc_valid[~mask]]) / np.sum(iou_valid[~mask]) + pacc = np.sum(tp[~mask]) / np.sum(pos_gt[~mask]) + res["mIoU-un{}".format(set_name)] = 100 * miou + res["pAcc-un{}".format(set_name)] = 100 * pacc + iou_list.append(miou) + res["hIoU-{}".format(set_name)] = ( + 100 * len(iou_list) / sum([1 / iou for iou in iou_list]) + ) + if self._output_dir: + file_path = os.path.join(self._output_dir, 
"sem_seg_evaluation.pth") + with PathManager.open(file_path, "wb") as f: + torch.save(res, f) + results = OrderedDict({"sem_seg": res}) + self._logger.info(results) + return results diff --git a/open_vocab_seg/mask_former_model.py b/open_vocab_seg/mask_former_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3708d65de4695368b1d088abde4bdf4a9fa39b2b --- /dev/null +++ b/open_vocab_seg/mask_former_model.py @@ -0,0 +1,254 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from typing import Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.data import MetadataCatalog +from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head +from detectron2.modeling.backbone import Backbone +from detectron2.modeling.postprocessing import sem_seg_postprocess +from detectron2.structures import ImageList + +from .modeling.criterion import SetCriterion +from .modeling.matcher import HungarianMatcher + + +@META_ARCH_REGISTRY.register() +class MaskFormer(nn.Module): + """ + Main class for mask classification semantic segmentation architectures. + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + sem_seg_head: nn.Module, + criterion: nn.Module, + num_queries: int, + panoptic_on: bool, + object_mask_threshold: float, + overlap_threshold: float, + metadata, + size_divisibility: int, + sem_seg_postprocess_before_inference: bool, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + sem_seg_head: a module that predicts semantic segmentation from backbone features + criterion: a module that defines the loss + num_queries: int, number of queries + panoptic_on: bool, whether to output panoptic segmentation prediction + object_mask_threshold: float, threshold to filter query based on classification score + for panoptic segmentation inference + overlap_threshold: overlap threshold used in general inference for panoptic segmentation + metadata: dataset meta, get `thing` and `stuff` category names for panoptic + segmentation inference + size_divisibility: Some backbones require the input height and width to be divisible by a + specific integer. We can use this to override such requirement. + sem_seg_postprocess_before_inference: whether to resize the prediction back + to original input size before semantic segmentation inference or after. + For high-resolution dataset like Mapillary, resizing predictions before + inference will cause OOM error. 
+ pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + """ + super().__init__() + self.backbone = backbone + self.sem_seg_head = sem_seg_head + self.criterion = criterion + self.num_queries = num_queries + self.overlap_threshold = overlap_threshold + self.panoptic_on = panoptic_on + self.object_mask_threshold = object_mask_threshold + self.metadata = metadata + if size_divisibility < 0: + # use backbone size_divisibility if not set + size_divisibility = self.backbone.size_divisibility + self.size_divisibility = size_divisibility + self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference + self.register_buffer( + "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False + ) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape()) + + # Loss parameters: + deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION + no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT + dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT + mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT + + # building criterion + matcher = HungarianMatcher( + cost_class=1, + cost_mask=mask_weight, + cost_dice=dice_weight, + ) + + weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight} + if deep_supervision: + dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS + aux_weight_dict = {} + for i in range(dec_layers - 1): + aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ["labels", "masks"] + + criterion = SetCriterion( + sem_seg_head.num_classes, + matcher=matcher, + weight_dict=weight_dict, + eos_coef=no_object_weight, + losses=losses, + ) + + return { + "backbone": backbone, + "sem_seg_head": sem_seg_head, + "criterion": criterion, + "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES, + "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON, + "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD, + "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD, + "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), + "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, + "sem_seg_postprocess_before_inference": ( + cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE + or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON + ), + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + } + + @property + def device(self): + return self.pixel_mean.device + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. + Each item in the list contains the inputs for one image. + For now, each item in the list is a dict that contains: + * "image": Tensor, image in (C, H, W) format. + * "instances": per-region ground truth + * Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model (may be different + from input resolution), used in inference. + Returns: + list[dict]: + each dict has the results for one image. The dict contains the following keys: + + * "sem_seg": + A Tensor that represents the + per-pixel segmentation prediced by the head. + The prediction has shape KxHxW that represents the logits of + each class for each pixel. 
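To make the shape bookkeeping of that KxHxW output concrete, here is a sketch of what semantic_inference (defined at the end of this class) does with the per-query predictions; the sizes are hypothetical:

    import torch
    import torch.nn.functional as F

    Q, K, H, W = 100, 171, 128, 128           # hypothetical: queries, classes, spatial size
    mask_cls = torch.randn(Q, K + 1)          # per-query class logits, last column = "no object"
    mask_pred = torch.randn(Q, H, W)          # per-query mask logits

    probs = F.softmax(mask_cls, dim=-1)[..., :-1]        # (Q, K), no-object column dropped
    masks = mask_pred.sigmoid()                          # (Q, H, W)
    semseg = torch.einsum("qc,qhw->chw", probs, masks)   # (K, H, W) per-pixel class scores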
+ * "panoptic_seg": + A tuple that represent panoptic output + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. + segments_info (list[dict]): Describe each segment in `panoptic_seg`. + Each dict contains keys "id", "category_id", "isthing". + """ + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + + features = self.backbone(images.tensor) + outputs = self.sem_seg_head(features) + + if self.training: + # mask classification target + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + targets = self.prepare_targets(gt_instances, images) + else: + targets = None + + # bipartite matching-based loss + losses = self.criterion(outputs, targets) + + for k in list(losses.keys()): + if k in self.criterion.weight_dict: + losses[k] *= self.criterion.weight_dict[k] + else: + # remove this loss if not specified in `weight_dict` + losses.pop(k) + + return losses + else: + mask_cls_results = outputs["pred_logits"] + mask_pred_results = outputs["pred_masks"] + # upsample masks + mask_pred_results = F.interpolate( + mask_pred_results, + size=(images.tensor.shape[-2], images.tensor.shape[-1]), + mode="bilinear", + align_corners=False, + ) + + processed_results = [] + for mask_cls_result, mask_pred_result, input_per_image, image_size in zip( + mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + + if self.sem_seg_postprocess_before_inference: + mask_pred_result = sem_seg_postprocess( + mask_pred_result, image_size, height, width + ) + + # semantic segmentation inference + r = self.semantic_inference(mask_cls_result, mask_pred_result) + if not self.sem_seg_postprocess_before_inference: + r = sem_seg_postprocess(r, image_size, height, width) + processed_results.append({"sem_seg": r}) + + # panoptic segmentation inference + if self.panoptic_on: + panoptic_r = self.panoptic_inference( + mask_cls_result, mask_pred_result + ) + processed_results[-1]["panoptic_seg"] = panoptic_r + + return processed_results + + def prepare_targets(self, targets, images): + h, w = images.tensor.shape[-2:] + new_targets = [] + for targets_per_image in targets: + # pad gt + gt_masks = targets_per_image.gt_masks + padded_masks = torch.zeros( + (gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device + ) + padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks + new_targets.append( + { + "labels": targets_per_image.gt_classes, + "masks": padded_masks, + } + ) + return new_targets + + def semantic_inference(self, mask_cls, mask_pred): + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) + return semseg diff --git a/open_vocab_seg/modeling/.DS_Store b/open_vocab_seg/modeling/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..22e04d81d5a0756554382506a89d270c4397faa6 Binary files /dev/null and b/open_vocab_seg/modeling/.DS_Store differ diff --git a/open_vocab_seg/modeling/__init__.py b/open_vocab_seg/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4dd2628880e93338b39b0b6562b2a5838692b5 --- /dev/null +++ b/open_vocab_seg/modeling/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 
Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from .backbone.swin import D2SwinTransformer +from .backbone.clip_resnet import D2ModifiedResNet +from .heads.mask_former_head import MaskFormerHead +from .heads.open_vocab_mask_former_head import OpenVocabMaskFormerHead +from .heads.pixel_decoder import BasePixelDecoder diff --git a/open_vocab_seg/modeling/backbone/__init__.py b/open_vocab_seg/modeling/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49f9003b7a688f5396170dd89c26ef335a2c201f --- /dev/null +++ b/open_vocab_seg/modeling/backbone/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved diff --git a/open_vocab_seg/modeling/backbone/clip_resnet.py b/open_vocab_seg/modeling/backbone/clip_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7d40d88c1eac79a873a1396f7203b3555c68a364 --- /dev/null +++ b/open_vocab_seg/modeling/backbone/clip_resnet.py @@ -0,0 +1,206 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from collections import OrderedDict +import torch +import torch.nn as nn +from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, dilation=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d( + planes, planes, 3, padding=1 * dilation, bias=False, dilation=dilation + ) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, width=64, strides=[2, 1, 2, 2, 2], multi_grid=[1, 1, 1]): + super().__init__() + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(strides[0]) if strides[0] > 1 else nn.Identity() + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0], stride=strides[1]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=strides[2]) + self.layer3 = self._make_layer(width * 4, layers[2], stride=strides[3]) + self.layer4 = self._make_layer( + width * 8, layers[3], stride=strides[4], dilations=multi_grid + ) + self.num_features = [width * 4, width * 8, width * 16, width * 32] + + def _make_layer(self, planes, blocks, stride=1, dilations=None): + if dilations is None: + dilations = [1] * blocks + layers = [Bottleneck(self._inplanes, planes, stride, dilation=dilations[0])] + self._inplanes = planes * Bottleneck.expansion + + for i in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes, dilation=dilations[i])) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [ + (self.conv1, self.bn1), + (self.conv2, self.bn2), + (self.conv3, self.bn3), + ]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + output = {} + x = x.type(self.conv1.weight.dtype) + x = stem(x) # 1/4,1/4 + x = self.layer1(x) + output["res2"] = x + x = self.layer2(x) # 1/8,1/8 + output["res3"] = x + x = self.layer3(x) # 1/16,1/16 + output["res4"] = x + x = self.layer4(x) # 1/32,1/32 + output["res5"] = x + return output + + +@BACKBONE_REGISTRY.register() +class D2ModifiedResNet(ModifiedResNet, Backbone): + def __init__(self, cfg, input_shape): + depth = cfg.MODEL.RESNETS.DEPTH + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + strides = [2, 1, 2, 2, 2] + multi_grid = cfg.MODEL.RESNETS.RES5_MULTI_GRID + if cfg.MODEL.RESNETS.STEM_TYPE == "deeplab": + strides = [1, 1, 2, 2, 2] + super().__init__( + num_blocks_per_stage, + bottleneck_channels, + strides=strides, + multi_grid=multi_grid, + ) + self._out_features = cfg.MODEL.RESNETS.OUT_FEATURES + + self._out_feature_strides = { + "res2": 4, + "res3": 8, + "res4": 16, + "res5": 32, + } + self._out_feature_channels = { + "res2": self.num_features[0], + "res3": self.num_features[1], + "res4": self.num_features[2], + "res5": self.num_features[3], + } + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. 
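Concretely, for a hypothetical 3x224x224 input with the default strides [2, 1, 2, 2, 2] and width 64, the returned pyramid looks like:

    # "res2": 256 x 56 x 56    (stride 4:  stem conv1 /2 then avgpool /2; layer1 keeps it)
    # "res3": 512 x 28 x 28    (stride 8:  layer2 halves the resolution)
    # "res4": 1024 x 14 x 14   (stride 16: layer3)
    # "res5": 2048 x 7 x 7     (stride 32: layer4)
    # matching self._out_feature_strides and self.num_features = [width*4, ..., width*32].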
+ Returns: + dict[str->Tensor]: names and the corresponding features + """ + outputs = {} + y = super().forward(x) + for k in y.keys(): + if k in self._out_features: + outputs[k] = y[k] + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], + stride=self._out_feature_strides[name], + ) + for name in self._out_features + } + + @property + def size_divisibility(self): + return 32 diff --git a/open_vocab_seg/modeling/backbone/swin.py b/open_vocab_seg/modeling/backbone/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..aa651bdab51bb353e3be4b5554f41e251803d5cb --- /dev/null +++ b/open_vocab_seg/modeling/backbone/swin.py @@ -0,0 +1,832 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec + + +class Mlp(nn.Module): + """Multilayer perceptron.""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
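# The relative position bias table above has (2*Wh-1)*(2*Ww-1) rows; repeating the index
# computation for a toy 2x2 window makes the lookup pattern easy to inspect.
import torch

Wh, Ww = 2, 2
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))   # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                                    # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += Wh - 1
relative_coords[:, :, 1] += Ww - 1
relative_coords[:, :, 0] *= 2 * Ww - 1
print(relative_coords.sum(-1))
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])  -> indices into the (2*2-1)*(2*2-1) = 9-row bias table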
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
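# The shifted-window path above only rolls the feature map by -shift_size before windowing
# and rolls it back afterwards; a toy standalone example with an assumed shift of 2.
import torch

x = torch.arange(8 * 8).reshape(1, 8, 8, 1)                  # (B, H, W, C) toy map
shifted = torch.roll(x, shifts=(-2, -2), dims=(1, 2))
print(shifted[0, :3, :3, 0])                                  # content moved up and left by 2
restored = torch.roll(shifted, shifts=(2, 2), dims=(1, 2))
print(torch.equal(restored, x))                               # True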
Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
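# PatchMerging above stacks the four 2x2 sub-grids into 4C channels and projects back to 2C,
# so each stage halves H and W while doubling C; a standalone shape check with assumed sizes.
import torch
import torch.nn as nn

B, H, W, C = 2, 8, 8, 96
x = torch.randn(B, H, W, C)
merged = torch.cat(
    [x[:, 0::2, 0::2, :], x[:, 1::2, 0::2, :], x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]], -1
).view(B, -1, 4 * C)
out = nn.Linear(4 * C, 2 * C, bias=False)(nn.LayerNorm(4 * C)(merged))
print(merged.shape, out.shape)  # torch.Size([2, 16, 384]) torch.Size([2, 16, 192])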
+ """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. 
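# The img_mask built above labels the (up to) 9 regions created by the cyclic shift; windows
# that mix labels get -100 on cross-region entries so softmax ignores them. Toy repetition
# with Hp=Wp=8, window_size=4, shift_size=2; window_partition is the helper defined earlier
# in this file.
import torch

Hp = Wp = 8
window_size, shift_size = 4, 2
img_mask = torch.zeros((1, Hp, Wp, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(img_mask.unique().numel())                                            # 9 distinct regions
print((attn_mask[0] != 0).any().item(), (attn_mask[-1] != 0).any().item())  # False True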
+ attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + norm_indices=None, + frozen_stages=-1, + use_checkpoint=False, + projection=False, + project_dim=256, + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.norm_indices = norm_indices if norm_indices is not None else out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1], + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in self.norm_indices: + if i_layer >= len(self.num_features): + continue + layer = norm_layer(num_features[i_layer]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + # add projector head + self.projection = projection + if projection: + self.project_dim = project_dim + self.norm = norm_layer(self.num_features[-1]) + self.projector = nn.Linear(self.num_features[-1], project_dim, bias=False) + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in 
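# The drop-path rates above grow linearly across all blocks of all stages; for the default
# Swin-T depths and an assumed drop_path_rate of 0.2, each block receives:
import torch

depths, drop_path_rate = [2, 2, 6, 2], 0.2
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
for i in range(len(depths)):
    print(i, [round(r, 3) for r in dpr[sum(depths[:i]): sum(depths[:i + 1])]])
# 0 [0.0, 0.018]
# 1 [0.036, 0.055]
# 2 [0.073, 0.091, 0.109, 0.127, 0.145, 0.164]
# 3 [0.182, 0.2]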
self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" + ) + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = {} + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + if i in self.norm_indices: + norm_layer = getattr(self, f"norm{i}") + x_out = norm_layer(x_out) + out = ( + x_out.view(-1, H, W, self.num_features[i]) + .permute(0, 3, 1, 2) + .contiguous() + ) + outs["res{}".format(i + 2)] = out + if self.projection: + x_out = self.norm(x_out) + x_out = x_out.view(-1, H, W, self.num_features[-1]).contiguous() + outs["fc"] = self.projector(x_out).permute(0, 3, 1, 2) + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +@BACKBONE_REGISTRY.register() +class D2SwinTransformer(SwinTransformer, Backbone): + def __init__(self, cfg, input_shape): + + pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE + patch_size = cfg.MODEL.SWIN.PATCH_SIZE + in_chans = 3 + embed_dim = cfg.MODEL.SWIN.EMBED_DIM + depths = cfg.MODEL.SWIN.DEPTHS + num_heads = cfg.MODEL.SWIN.NUM_HEADS + window_size = cfg.MODEL.SWIN.WINDOW_SIZE + mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO + qkv_bias = cfg.MODEL.SWIN.QKV_BIAS + qk_scale = cfg.MODEL.SWIN.QK_SCALE + drop_rate = cfg.MODEL.SWIN.DROP_RATE + attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE + drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE + norm_layer = nn.LayerNorm + ape = cfg.MODEL.SWIN.APE + patch_norm = cfg.MODEL.SWIN.PATCH_NORM + norm_indices = cfg.MODEL.SWIN.NORM_INDICES + projection = cfg.MODEL.SWIN.PROJECTION + project_dim = cfg.MODEL.SWIN.PROJECT_DIM + super().__init__( + pretrain_img_size, + patch_size, + in_chans, + embed_dim, + depths, + num_heads, + window_size, + mlp_ratio, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + drop_path_rate, + norm_layer, + ape, + patch_norm, + norm_indices=norm_indices, + projection=projection, + project_dim=project_dim, + ) + + self._out_features = cfg.MODEL.SWIN.OUT_FEATURES + + self._out_feature_strides = { + "res2": 4, + "res3": 8, + "res4": 16, + "res5": 32, + "fc": 32, + } + self._out_feature_channels = { + "res2": self.num_features[0], + "res3": self.num_features[1], + "res4": self.num_features[2], + "res5": self.num_features[3], + "fc": self.num_features[3], + } + + def forward(self, x): + """ + 
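# Each stage output above is a token sequence (B, H*W, C) that is reshaped to an NCHW map
# before being exposed to detectron2 as res2..res5; a standalone version of that reshape
# (sizes correspond to the first stage of a Swin-B at 224x224 and are illustrative).
import torch

B, H, W, C = 2, 56, 56, 128
x_out = torch.randn(B, H * W, C)
out = x_out.view(-1, H, W, C).permute(0, 3, 1, 2).contiguous()
print(out.shape)  # torch.Size([2, 128, 56, 56]) -> "res2", stride 4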
Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert ( + x.dim() == 4 + ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" + outputs = {} + y = super().forward(x) + for k in y.keys(): + if k in self._out_features: + outputs[k] = y[k] + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], + stride=self._out_feature_strides[name], + ) + for name in self._out_features + } + + @property + def size_divisibility(self): + return 32 diff --git a/open_vocab_seg/modeling/clip_adapter/__init__.py b/open_vocab_seg/modeling/clip_adapter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d925b068328373b26352f9a82895d197b47455c --- /dev/null +++ b/open_vocab_seg/modeling/clip_adapter/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from .text_template import ( + PredefinedPromptExtractor, + ImageNetPromptExtractor, + VILDPromptExtractor, +) +from .adapter import ClipAdapter, MaskFormerClipAdapter + + +def build_text_prompt(cfg): + if cfg.TEXT_TEMPLATES == "predefined": + text_templates = PredefinedPromptExtractor(cfg.PREDEFINED_PROMPT_TEMPLATES) + elif cfg.TEXT_TEMPLATES == "imagenet": + text_templates = ImageNetPromptExtractor() + elif cfg.TEXT_TEMPLATES == "vild": + text_templates = VILDPromptExtractor() + else: + raise NotImplementedError( + "Prompt learner {} is not supported".format(cfg.TEXT_TEMPLATES) + ) + return text_templates + +from .clip import tokenize \ No newline at end of file diff --git a/open_vocab_seg/modeling/clip_adapter/adapter.py b/open_vocab_seg/modeling/clip_adapter/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..864d20b160714865b4130fab8714f323aaae2572 --- /dev/null +++ b/open_vocab_seg/modeling/clip_adapter/adapter.py @@ -0,0 +1,206 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved +# Modified by Feng Liang from +# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/adapter.py + +from typing import List +import torch +from torch import nn +from torch.nn import functional as F +from detectron2.structures import BitMasks +from .utils import build_clip_model, crop_with_mask +from .text_template import PromptExtractor + + +PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073) +PIXEL_STD = (0.26862954, 0.26130258, 0.27577711) + + +class ClipAdapter(nn.Module): + def __init__(self, clip_model_name: str, mask_prompt_depth: int, text_templates: PromptExtractor): + super().__init__() + self.clip_model = build_clip_model(clip_model_name, mask_prompt_depth) + self.text_templates = text_templates + self.text_templates.init_buffer(self.clip_model) + self.text_feature_buffer = {} + + def forward(self, image: torch.Tensor, text: List[str], **kwargs): + image = self._preprocess_image(image, **kwargs) + text_feature = self.get_text_features(text) # k,feat_dim + image_features = self.get_image_features(image) + return self.get_sim_logits(text_feature, image_features) + + def _preprocess_image(self, image: torch.Tensor): + return image + + def _get_text_features(self, noun_list: List[str]): + left_noun_list = [ + noun for noun in noun_list if noun not in self.text_feature_buffer + ] + if len(left_noun_list) > 0: + left_text_features = self.text_templates( + left_noun_list, self.clip_model + ) + self.text_feature_buffer.update( + { + noun: text_feature + for noun, text_feature in zip( + left_noun_list, left_text_features + ) + } + ) + return torch.stack([self.text_feature_buffer[noun] for noun in noun_list]) + + + def get_text_features(self, noun_list: List[str]): + return self._get_text_features(noun_list) + + def get_image_features(self, image: torch.Tensor): + image_features = self.clip_model.visual(image) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + return image_features + + def get_sim_logits( + self, + text_features: torch.Tensor, + image_features: torch.Tensor, + temperature: float = 100, + ): + return temperature * image_features @ text_features.T + + def normalize_feature(self, feat: torch.Tensor): + return feat / feat.norm(dim=-1, keepdim=True) + + +class MaskFormerClipAdapter(ClipAdapter): + def __init__( + self, + clip_model_name: str, + text_templates: PromptExtractor, + mask_fill: str = "mean", + mask_expand_ratio: float = 1.0, + mask_thr: float = 0.5, + mask_matting: bool = False, + region_resized: bool = True, + mask_prompt_depth: int = 0, + mask_prompt_fwd: bool = False, + ): + super().__init__(clip_model_name, mask_prompt_depth, text_templates) + self.non_object_embedding = nn.Parameter( + torch.empty(1, self.clip_model.text_projection.shape[-1]) + ) + nn.init.normal_( + self.non_object_embedding.data, + std=self.clip_model.transformer.width ** -0.5, + ) + # for test + self.mask_fill = mask_fill + if self.mask_fill == "zero": + self.mask_fill = (0.0, 0.0, 0.0) + elif self.mask_fill == "mean": + self.mask_fill = [255.0 * c for c in PIXEL_MEAN] + else: + raise NotImplementedError( + "Unknown mask_fill method: {}".format(self.mask_fill) + ) + self.mask_expand_ratio = mask_expand_ratio + self.mask_thr = mask_thr + self.mask_matting = mask_matting + self.region_resized = region_resized + self.mask_prompt_fwd = mask_prompt_fwd + self.register_buffer( + "pixel_mean", torch.Tensor(PIXEL_MEAN).reshape(1, 3, 1, 1) * 255.0 + ) + self.register_buffer( + "pixel_std", 
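# get_sim_logits above is a temperature-scaled cosine similarity between L2-normalized image
# and text features; a standalone sketch with random features (dim 768 matches ViT-L/14 and
# is only illustrative here).
import torch
import torch.nn.functional as F

num_masks, num_classes, dim = 5, 3, 768
image_features = F.normalize(torch.randn(num_masks, dim), dim=-1)
text_features = F.normalize(torch.randn(num_classes, dim), dim=-1)
logits = 100.0 * image_features @ text_features.T          # temperature = 100, as above
print(logits.shape, logits.softmax(dim=-1).sum(dim=-1))     # torch.Size([5, 3]), rows sum to 1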
torch.Tensor(PIXEL_STD).reshape(1, 3, 1, 1) * 255.0 + ) + + def forward( + self, + image: torch.Tensor, + text: List[str], + mask: torch.Tensor, + normalize: bool = True, + fwd_w_region_mask: bool = False, + ): + (regions, unnorm_regions), region_masks, valid_flag = self._preprocess_image(image, mask, normalize=normalize) + if regions is None: + return None, valid_flag + if isinstance(regions, list): + assert NotImplementedError + image_features = torch.cat( + [self.get_image_features(image_i) for image_i in regions], dim=0 + ) + else: + if self.mask_prompt_fwd: + image_features = self.get_image_features(regions, region_masks) + else: + image_features = self.get_image_features(regions) + text_feature = self.get_text_features(text) # k,feat_dim + return self.get_sim_logits(text_feature, image_features), unnorm_regions, valid_flag + + def get_image_features(self, image, region_masks=None): + image_features = self.clip_model.visual(image, region_masks) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + return image_features + + def _preprocess_image( + self, image: torch.Tensor, mask: torch.Tensor, normalize: bool = True + ): + """crop, mask and normalize the image + + Args: + image ([type]): [C,H,W] + mask ([type]): [K,H,W + normalize (bool, optional): [description]. Defaults to True. + """ + dtype = mask.dtype + bin_mask = mask > self.mask_thr + valid = bin_mask.sum(dim=(-1, -2)) > 0 + bin_mask = bin_mask[valid] + mask = mask[valid] + if not self.mask_matting: + mask = bin_mask + bin_mask = BitMasks(bin_mask) + bboxes = bin_mask.get_bounding_boxes() + # crop,mask + regions = [] + region_masks = [] + for bbox, single_mask in zip(bboxes, mask): + region, region_mask = crop_with_mask( + image.type(dtype), + single_mask.type(dtype), + bbox, + fill=self.mask_fill, + expand_ratio=self.mask_expand_ratio, + ) + regions.append(region.unsqueeze(0)) + region_masks.append(region_mask.unsqueeze(0)) + if len(regions) == 0: + return None, valid + unnorm_regions = regions + if normalize: + regions = [(r - self.pixel_mean) / self.pixel_std for r in regions] + # resize + if self.region_resized: + regions = [ + F.interpolate(r, size=(224, 224), mode="bicubic") for r in regions + ] + regions = torch.cat(regions) + region_masks = [ + F.interpolate(r, size=(224, 224), mode="nearest") for r in region_masks + ] + region_masks = torch.cat(region_masks) + unnorm_regions = [ + F.interpolate(r, size=(224, 224), mode="bicubic") for r in unnorm_regions + ] + unnorm_regions = torch.cat(unnorm_regions) + return (regions, unnorm_regions), region_masks, valid + + def get_text_features(self, noun_list: List[str]): + object_text_features = self._get_text_features(noun_list) + non_object_text_features = ( + self.non_object_embedding + / self.non_object_embedding.norm(dim=-1, keepdim=True) + ) + return torch.cat([object_text_features, non_object_text_features], dim=0) diff --git a/open_vocab_seg/modeling/clip_adapter/clip/__init__.py b/open_vocab_seg/modeling/clip_adapter/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9 --- /dev/null +++ b/open_vocab_seg/modeling/clip_adapter/clip/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/open_vocab_seg/modeling/clip_adapter/clip/bpe_simple_vocab_16e6.txt.gz b/open_vocab_seg/modeling/clip_adapter/clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ 
b/open_vocab_seg/modeling/clip_adapter/clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/open_vocab_seg/modeling/clip_adapter/clip/clip.py b/open_vocab_seg/modeling/clip_adapter/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..6d733edfac02d81ba3e402eb7e702764728bdaa2 --- /dev/null +++ b/open_vocab_seg/modeling/clip_adapter/clip/clip.py @@ -0,0 +1,285 @@ +import hashlib +import os +import urllib +import warnings +from collections import OrderedDict +from typing import Union, List + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if torch.__version__.split(".") < ["1", "7", "1"]: + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if ( + hashlib.sha256(open(download_target, "rb").read()).hexdigest() + == expected_sha256 + ): + return download_target + else: + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if ( + hashlib.sha256(open(download_target, "rb").read()).hexdigest() + != expected_sha256 + ): + raise RuntimeError( + f"Model has been 
downloaded but the SHA256 checksum does not not match" + ) + + return download_target + + +def _transform(n_px): + return Compose( + [ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + lambda image: image.convert("RGB"), + ToTensor(), + Normalize( + (0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711), + ), + ] + ) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load( + name: str, + mask_prompt_depth: int = 0, + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit=False, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name]) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError( + f"Model {name} not found; available models = {available_models()}" + ) + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn( + f"File {model_path} is not a JIT archive. Loading as a state dict instead" + ) + jit = False + state_dict = torch.load(model_path, map_location="cpu") + if 'state_dict' in state_dict: + new_state_dict = OrderedDict() + for k, v in state_dict['state_dict'].items(): + if k.startswith('module.'): + name = k[7:] # remove `module.` + new_state_dict[name] = v + state_dict = new_state_dict + + if not jit: + model = build_model(state_dict or model.state_dict(), mask_prompt_depth).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] + ) + device_node = [ + n + for n in device_holder.graph.findAllNodes("prim::Constant") + if "Device" in repr(n) + ][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith( + "cuda" + ): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[] + ) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs 
= list(node.inputs()) + for i in [ + 1, + 2, + ]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize( + texts: Union[str, List[str]], + context_length: int = 77, + truncate: bool = False, + return_length: bool = False, +) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + length = [] + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + length.append(context_length) + else: + raise RuntimeError( + f"Input {texts[i]} is too long for context length {context_length}" + ) + else: + length.append(len(tokens)) + result[i, : len(tokens)] = torch.tensor(tokens) + if return_length: + return result, length + return result diff --git a/open_vocab_seg/modeling/clip_adapter/clip/model.py b/open_vocab_seg/modeling/clip_adapter/clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea730a2cc8a992f9180428bd1fec7fc96aa89dd --- /dev/null +++ b/open_vocab_seg/modeling/clip_adapter/clip/model.py @@ -0,0 +1,613 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved +# Modified by Feng Liang from https://github.com/openai/CLIP/blob/main/clip/model.py + +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
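# A minimal usage sketch of the local load()/tokenize() defined above; the model name and
# prompts are illustrative, and the checkpoint is downloaded on first use.
import torch
from open_vocab_seg.modeling.clip_adapter.clip import load, tokenize

model, preprocess = load("ViT-L/14", device="cpu")
tokens = tokenize(["a photo of a cat", "a photo of a dog"])   # (2, 77) LongTensor
with torch.no_grad():
    text_features = model.encode_text(tokens)
print(tokens.shape, text_features.shape)   # torch.Size([2, 77]) torch.Size([2, 768])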
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + self.grid_size = spacial_dim + + def forward(self, x, mask=None, return_cls=True): + b, c, gh, gw = x.shape + # remove irrelated feature + if mask is not None: + mask = F.interpolate(mask[:, None, ...], size=(gh, gw)).squeeze( + 1 + ) # [N,H,W] -> [N,grid,grid] + mask = (mask > 0.5).reshape(mask.shape[0], -1) + mask = torch.cat([mask, mask.new_ones(mask.shape[0], 1)], dim=1) + if x.size()[0] == 1: + x = x.expand(mask.shape[0], c, gh, gw) + + x = x.reshape(x.shape[0], c, gh * gw).permute(2, 0, 1) # NCHW -> (HW)NC + + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + positional_embedding = self.positional_embedding + if not (self.positional_embedding.shape[0] == x.shape[0]): + cls_pos = positional_embedding[0:1, :] + per_pos_embedding = ( + F.interpolate( + positional_embedding[1:, :] + .permute(1, 0) + .view(1, -1, self.grid_size, self.grid_size), + size=(gh, gw), + mode="bicubic", + ) + .reshape(-1, gh * gw) + .permute(1, 0) + ) + positional_embedding = torch.cat([cls_pos, per_pos_embedding]) + + x = x + positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + key_padding_mask=mask, + ) + + if return_cls: + return x[0] + else: + 
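# A rough standalone sketch of the anti-aliased downsampling above: when stride > 1 the 3x3
# conv keeps stride 1 and an AvgPool2d does the downsampling, in both the main branch and
# the shortcut (conv1/bn/relu omitted here; channel sizes are assumed).
import torch
import torch.nn as nn

stride, inplanes, planes = 2, 256, 128
main = nn.Sequential(
    nn.Conv2d(planes, planes, 3, padding=1, bias=False),      # stride stays 1
    nn.AvgPool2d(stride),                                      # downsampling happens here
    nn.Conv2d(planes, planes * 4, 1, bias=False),
)
shortcut = nn.Sequential(
    nn.AvgPool2d(stride),                                      # avgpool prepended to the 1x1 conv
    nn.Conv2d(inplanes, planes * 4, 1, stride=1, bias=False),
)
print(main(torch.randn(1, planes, 56, 56)).shape,              # torch.Size([1, 512, 28, 28])
      shortcut(torch.randn(1, inplanes, 56, 56)).shape)        # torch.Size([1, 512, 28, 28])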
return x + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d( + input_resolution // 32, embed_dim, heads, output_dim + ) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, mask: torch.Tensor = None, return_cls=True): + def stem(x): + for conv, bn in [ + (self.conv1, self.bn1), + (self.conv2, self.bn2), + (self.conv3, self.bn3), + ]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) # 1/4,1/4 + x = self.layer1(x) + x = self.layer2(x) # 1/8,1/8 + x = self.layer3(x) # 1/16,1/16 + x = self.layer4(x) # 1/32,1/32 + b, c, gh, gw = x.shape + x = self.attnpool(x, mask, return_cls) + if not return_cls: + return x[1:].permute(1, 0, 2).reshape(b, gh, gw, x.shape[-1]) # N,L,C + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor, **kwargs): + self.attn_mask = ( + self.attn_mask.to(dtype=x.dtype, device=x.device) + if self.attn_mask is not None + else None + ) + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask, **kwargs + )[0] + + def 
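# QuickGELU above is the sigmoid approximation x * sigmoid(1.702 * x) used by the released
# CLIP weights; it tracks the exact GELU closely but is not identical.
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
quick = x * torch.sigmoid(1.702 * x)
print(torch.max(torch.abs(quick - F.gelu(x))).item())  # ~0.02: close, but not the exact GELU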
forward(self, x: torch.Tensor, **kwargs): + x = x + self.attention(self.ln_1(x), **kwargs) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential( + *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)] + ) + + def forward(self, x: torch.Tensor, **kwargs): + for block in self.resblocks: + x = block(x, **kwargs) + return x + + +class VisionTransformer(nn.Module): + def __init__( + self, + input_resolution: int, + patch_size: int, + mask_prompt_depth: int, + width: int, + layers: int, + heads: int, + output_dim: int, + ): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width) + ) + self.grid_size = input_resolution // patch_size + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.mask_pool = nn.AvgPool2d(patch_size, stride=patch_size) + self.mask_prompt_depth = mask_prompt_depth + self.mask_embedding = nn.Parameter(torch.zeros(self.mask_prompt_depth, self.grid_size * self.grid_size, width)) + + def forward(self, x: torch.Tensor, m: torch.Tensor = None): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + if m is not None: + m = self.mask_pool(m.to(torch.float).squeeze()).reshape(m.shape[0], -1).unsqueeze(-1) + m = torch.ceil(m) + if self.mask_embedding.shape[1] == 1: + mask_embedding = self.mask_embedding.to(x.dtype).repeat(1, x.shape[1], 1) + else: + mask_embedding = self.mask_embedding.to(x.dtype) + x = x * m + mask_embedding[0].unsqueeze(0) * (1 - m) + + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + if m is not None: + for i, blk in enumerate(self.transformer.resblocks): + d = i + 1 + x = blk(x) + if d < self.mask_prompt_depth: + masked_x = x[1:, :, :] * m.permute(1, 0, 2) + \ + mask_embedding[d].unsqueeze(0).permute(1, 0, 2) * (1 - m.permute(1, 0, 2)) + x = torch.cat([x[:1, :, :], masked_x], dim=0) + else: + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + + +class CLIP(nn.Module): + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + mask_prompt_depth: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 
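# The mask prompt above average-pools the region mask to the patch grid, binarizes it with
# ceil, and blends tokens as x*m + mask_embedding*(1-m), so patches outside the mask are
# replaced by a learned embedding. Toy standalone version with assumed sizes (the real
# ViT-L/14 uses width 1024 and a 16x16 grid).
import torch
import torch.nn as nn

patch_size, grid_size, width = 16, 14, 8
x = torch.randn(2, grid_size * grid_size, width)               # patch tokens after conv1 + flatten
mask = torch.zeros(2, grid_size * patch_size, grid_size * patch_size)
mask[:, :100, :100] = 1.0                                      # dummy region mask at image resolution

m = nn.AvgPool2d(patch_size, stride=patch_size)(mask).reshape(2, -1).unsqueeze(-1)
m = torch.ceil(m)                                              # any overlap keeps the original token
mask_embedding = torch.zeros(1, grid_size * grid_size, width)  # learned (zero-initialized) in the model
x_prompted = x * m + mask_embedding * (1 - m)
print(m.sum(dim=1).squeeze(-1))                                # tensor([49., 49.]) kept tokens per image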
// 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width, + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + mask_prompt_depth=mask_prompt_depth, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, transformer_width) + ) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [ + self.visual.layer1, + self.visual.layer2, + self.visual.layer3, + self.visual.layer4, + ]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ( + (2 * self.transformer.layers) ** -0.5 + ) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image, **kwargs): + return self.visual(image.type(self.dtype), **kwargs) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features 
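# build_attention_mask above is the standard causal mask for the text transformer: zeros on
# and below the diagonal, -inf above, added to the attention logits. Shown here for an
# assumed context length of 5 (the released models use 77).
import torch

mask = torch.empty(5, 5)
mask.fill_(float("-inf"))
mask.triu_(1)          # zero on and below the diagonal, -inf strictly above
print(mask)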
= self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logit_scale * text_features @ image_features.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [ + *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], + "in_proj_bias", + "bias_k", + "bias_v", + ]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict, mask_prompt_depth: int = 0): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [ + k + for k in state_dict.keys() + if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") + ] + ) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round( + (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 + ) + image_resolution = vision_patch_size * grid_size + else: + assert mask_prompt_depth == 0, 'ResNets do not support mask prompt tuning' + counts: list = [ + len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"visual.layer{b}") + ) + ) + for b in [1, 2, 3, 4] + ] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round( + (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 + ) + vision_patch_size = None + assert ( + output_width ** 2 + 1 + == state_dict["visual.attnpool.positional_embedding"].shape[0] + ) + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"transformer.resblocks") + ) + ) + + model = CLIP( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + mask_prompt_depth, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict, strict=False) + return model.eval() diff --git a/open_vocab_seg/modeling/clip_adapter/clip/simple_tokenizer.py b/open_vocab_seg/modeling/clip_adapter/clip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..56d17512b06afb700e7834e4f3f6515c315ebb74 --- /dev/null +++ b/open_vocab_seg/modeling/clip_adapter/clip/simple_tokenizer.py @@ 
-0,0 +1,150 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" + ) + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2 ** 8): + if b not in bs: + bs.append(b) + cs.append(2 ** 8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", + } + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = 
whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text diff --git a/open_vocab_seg/modeling/clip_adapter/text_template.py b/open_vocab_seg/modeling/clip_adapter/text_template.py new file mode 100644 index 0000000000000000000000000000000000000000..724bbef34c6bd74b0d7ead336d6b06d145bbee2d --- /dev/null +++ b/open_vocab_seg/modeling/clip_adapter/text_template.py @@ -0,0 +1,156 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved +# Modified by Feng Liang from +# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/text_prompt.py +# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/utils.py + +from typing import List + +# import clip +from .clip import tokenize +import torch +from torch import nn + +IMAGENET_PROMPT = [ + "a bad photo of a {}.", + "a photo of many {}.", + "a sculpture of a {}.", + "a photo of the hard to see {}.", + "a low resolution photo of the {}.", + "a rendering of a {}.", + "graffiti of a {}.", + "a bad photo of the {}.", + "a cropped photo of the {}.", + "a tattoo of a {}.", + "the embroidered {}.", + "a photo of a hard to see {}.", + "a bright photo of a {}.", + "a photo of a clean {}.", + "a photo of a dirty {}.", + "a dark photo of the {}.", + "a drawing of a {}.", + "a photo of my {}.", + "the plastic {}.", + "a photo of the cool {}.", + "a close-up photo of a {}.", + "a black and white photo of the {}.", + "a painting of the {}.", + "a painting of a {}.", + "a pixelated photo of the {}.", + "a sculpture of the {}.", + "a bright photo of the {}.", + "a cropped photo of a {}.", + "a plastic {}.", + "a photo of the dirty {}.", + "a jpeg corrupted photo of a {}.", + "a blurry photo of the {}.", + "a photo of the {}.", + "a good photo of the {}.", + "a rendering of the {}.", + "a {} in a video game.", + "a photo of one {}.", + "a doodle of a {}.", + "a close-up photo of the {}.", + "a photo of a {}.", + "the origami {}.", + "the {} in a video game.", + "a sketch of a {}.", + "a doodle of the {}.", + "a origami {}.", + "a low resolution photo of a {}.", + "the toy {}.", + "a rendition of the {}.", + "a photo of the clean {}.", + "a photo of a large {}.", + "a rendition of a {}.", + "a photo of a nice {}.", + "a photo of a weird {}.", + "a blurry photo of a {}.", + "a cartoon {}.", + "art of a {}.", + "a sketch of the {}.", + "a embroidered {}.", + "a pixelated photo of a {}.", + "itap of the {}.", + "a jpeg corrupted photo of the {}.", + "a good photo of a {}.", + "a plushie {}.", + "a photo of the nice {}.", + "a photo of the small {}.", + "a photo of the weird {}.", + "the cartoon {}.", + "art of the {}.", + "a drawing of the {}.", + "a photo of the large {}.", + "a black and white photo of a {}.", + "the plushie {}.", + "a dark photo of a {}.", + "itap of a {}.", + "graffiti of the {}.", + "a toy {}.", + "itap of my {}.", + "a photo of a cool {}.", + "a photo of a small {}.", + "a tattoo of the {}.", +] + +VILD_PROMPT = [ + "a photo of a {}.", + "This is a photo of a {}", + "There is a {} in the 
scene", + "There is the {} in the scene", + "a photo of a {} in the scene", + "a photo of a small {}.", + "a photo of a medium {}.", + "a photo of a large {}.", + "This is a photo of a small {}.", + "This is a photo of a medium {}.", + "This is a photo of a large {}.", + "There is a small {} in the scene.", + "There is a medium {} in the scene.", + "There is a large {} in the scene.", +] + +class PromptExtractor(nn.Module): + def __init__(self): + super().__init__() + self._buffer_init = False + + def init_buffer(self, clip_model): + self._buffer_init = True + + def forward(self, noun_list: List[str], clip_model: nn.Module): + raise NotImplementedError() + + +class PredefinedPromptExtractor(PromptExtractor): + def __init__(self, templates: List[str]): + super().__init__() + self.templates = templates + + def forward(self, noun_list: List[str], clip_model: nn.Module): + text_features_bucket = [] + for template in self.templates: + noun_tokens = [tokenize(template.format(noun)) for noun in noun_list] + text_inputs = torch.cat(noun_tokens).to( + clip_model.text_projection.data.device + ) + text_features = clip_model.encode_text(text_inputs) + text_features /= text_features.norm(dim=-1, keepdim=True) + text_features_bucket.append(text_features) + del text_inputs + # ensemble by averaging + text_features = torch.stack(text_features_bucket).mean(dim=0) + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + + return text_features + + +class ImageNetPromptExtractor(PredefinedPromptExtractor): + def __init__(self): + super().__init__(IMAGENET_PROMPT) + + +class VILDPromptExtractor(PredefinedPromptExtractor): + def __init__(self): + super().__init__(VILD_PROMPT) diff --git a/open_vocab_seg/modeling/clip_adapter/utils.py b/open_vocab_seg/modeling/clip_adapter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe5d9d5284597cca444287f6bae38e37549bde0 --- /dev/null +++ b/open_vocab_seg/modeling/clip_adapter/utils.py @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +from typing import Tuple +import numpy as np +import torch +from .clip import load as clip_load +from detectron2.utils.comm import get_local_rank, synchronize + + +def expand_box( + x1: float, + y1: float, + x2: float, + y2: float, + expand_ratio: float = 1.0, + max_h: int = None, + max_w: int = None, +): + cx = 0.5 * (x1 + x2) + cy = 0.5 * (y1 + y2) + w = x2 - x1 + h = y2 - y1 + w = w * expand_ratio + h = h * expand_ratio + box = [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h] + if max_h is not None: + box[1] = max(0, box[1]) + box[3] = min(max_h - 1, box[3]) + if max_w is not None: + box[0] = max(0, box[0]) + box[2] = min(max_w - 1, box[2]) + return [int(b) for b in box] + + +def mask2box(mask: torch.Tensor): + # use naive way + row = torch.nonzero(mask.sum(dim=0))[:, 0] + if len(row) == 0: + return None + x1 = row.min() + x2 = row.max() + col = np.nonzero(mask.sum(dim=1))[:, 0] + y1 = col.min() + y2 = col.max() + return x1, y1, x2 + 1, y2 + 1 + + +def crop_with_mask( + image: torch.Tensor, + mask: torch.Tensor, + bbox: torch.Tensor, + fill: Tuple[float, float, float] = (0, 0, 0), + expand_ratio: float = 1.0, +): + l, t, r, b = expand_box(*bbox, expand_ratio) + _, h, w = image.shape + l = max(l, 0) + t = max(t, 0) + r = min(r, w) + b = min(b, h) + new_image = torch.cat( + [image.new_full((1, b - t, r - l), fill_value=val) for val in fill] + ) + mask_bool = mask.bool() + return image[:, t:b, l:r] * mask[None, t:b, l:r] + (~ mask_bool[None, t:b, l:r]) * new_image, mask[None, t:b, l:r] + + +def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True): + rank = get_local_rank() + if rank == 0: + # download on rank 0 only + model, _ = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu") + synchronize() + if rank != 0: + model, _ = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu") + synchronize() + if frozen: + for param in model.parameters(): + param.requires_grad = False + return model diff --git a/open_vocab_seg/modeling/criterion.py b/open_vocab_seg/modeling/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d5b71242f87c6f67463f9c31f873a742f3e5c7 --- /dev/null +++ b/open_vocab_seg/modeling/criterion.py @@ -0,0 +1,229 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +""" +MaskFormer criterion. +""" +import torch +import torch.nn.functional as F +from torch import nn + +from detectron2.utils.comm import get_world_size + +from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list + + +def dice_loss(inputs, targets, num_masks): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(-1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_masks + + +def sigmoid_focal_loss( + inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2 +): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 
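+    Per element this computes FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t), where p_t is the predicted probability of the ground-truth label; gamma down-weights easy examples and alpha balances positive vs. negative pixels. The pixel-wise mean per mask is then summed over masks and divided by num_masks.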
+ Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_masks + + +class SetCriterion(nn.Module): + """This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): + """Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + self.losses = losses + if eos_coef > 0: + + empty_weight = torch.ones(self.num_classes + 1) + + empty_weight[-1] = self.eos_coef + self.register_buffer("empty_weight", empty_weight) + self.use_ignore_idx = False + else: + self.use_ignore_idx = True + self.cur_target = [] + + def loss_labels(self, outputs, targets, indices, num_masks): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert "pred_logits" in outputs + src_logits = outputs["pred_logits"] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat( + [t["labels"][J] for t, (_, J) in zip(targets, indices)] + ) + target_classes = torch.full( + src_logits.shape[:2], + self.num_classes, + dtype=torch.int64, + device=src_logits.device, + ) + target_classes[idx] = target_classes_o + if self.use_ignore_idx: + loss_ce = F.cross_entropy( + src_logits.transpose(1, 2), + target_classes, + ignore_index=self.num_classes, + ) + else: + if "empty_weight" in outputs: + empty_weight = torch.cat( + [outputs["empty_weight"], self.empty_weight[-1:]] + ).detach() + else: + empty_weight = self.empty_weight + loss_ce = F.cross_entropy( + src_logits.transpose(1, 2), target_classes, empty_weight + ) + losses = {"loss_ce": loss_ce} + return losses + + def loss_masks(self, outputs, targets, indices, num_masks): + """Compute the losses related to the masks: the focal loss and the dice loss. 
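+        Predictions are upsampled to the ground-truth mask resolution before both losses; the dice loss is 1 - (2 * (p * t).sum() + 1) / (p.sum() + t.sum() + 1) on the sigmoided, flattened masks, and both terms are normalized by num_masks.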
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + src_masks = outputs["pred_masks"] + src_masks = src_masks[src_idx] + masks = [t["masks"] for t in targets] + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(src_masks) + target_masks = target_masks[tgt_idx] + + # upsample predictions to the target size + src_masks = F.interpolate( + src_masks[:, None], + size=target_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(src_masks.shape) + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks), + "loss_dice": dice_loss(src_masks, target_masks, num_masks), + } + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat( + [torch.full_like(src, i) for i, (src, _) in enumerate(indices)] + ) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat( + [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)] + ) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_masks): + loss_map = {"labels": self.loss_labels, "masks": self.loss_masks} + assert loss in loss_map, f"do you really want to compute {loss} loss?" + return loss_map[loss](outputs, targets, indices, num_masks) + + def forward(self, outputs, targets): + """This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_masks = sum(len(t["labels"]) for t in targets) + num_masks = torch.as_tensor( + [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device + ) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_masks) + num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_masks)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
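+        # Each entry of outputs["aux_outputs"] carries the same "pred_logits"/"pred_masks" keys as the final output, so the matcher and losses are simply re-run on it; the resulting loss keys get a decoder-layer suffix (e.g. "loss_ce_0") so they can be weighted separately.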
+ if "aux_outputs" in outputs: + for i, aux_outputs in enumerate(outputs["aux_outputs"]): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + l_dict = self.get_loss( + loss, aux_outputs, targets, indices, num_masks + ) + l_dict = {k + f"_{i}": v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + def clean_buffer(self): + self.cur_target = [] diff --git a/open_vocab_seg/modeling/heads/__init__.py b/open_vocab_seg/modeling/heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52db7cce67b1686f7cab3698f15b8f309c897918 --- /dev/null +++ b/open_vocab_seg/modeling/heads/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved \ No newline at end of file diff --git a/open_vocab_seg/modeling/heads/mask_former_head.py b/open_vocab_seg/modeling/heads/mask_former_head.py new file mode 100644 index 0000000000000000000000000000000000000000..5f592662f92d1b0862a3ef76304e7b28b46ecf80 --- /dev/null +++ b/open_vocab_seg/modeling/heads/mask_former_head.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import logging +from copy import deepcopy +from typing import Callable, Dict, List, Optional, Tuple, Union + +import fvcore.nn.weight_init as weight_init +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY + +from ..transformer.transformer_predictor import TransformerPredictor +from .pixel_decoder import build_pixel_decoder + + +@SEM_SEG_HEADS_REGISTRY.register() +class MaskFormerHead(nn.Module): + + _version = 2 + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + if version is None or version < 2: + # Do not warn if train from scratch + scratch = True + logger = logging.getLogger(__name__) + for k in list(state_dict.keys()): + newk = k + if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): + newk = k.replace(prefix, prefix + "pixel_decoder.") + # logger.debug(f"{k} ==> {newk}") + if newk != k: + state_dict[newk] = state_dict[k] + del state_dict[k] + scratch = False + + if not scratch: + logger.warning( + f"Weight format of {self.__class__.__name__} have changed! " + "Please upgrade your models. Applying automatic conversion now ..." + ) + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + num_classes: int, + pixel_decoder: nn.Module, + loss_weight: float = 1.0, + ignore_value: int = -1, + # extra parameters + transformer_predictor: nn.Module, + transformer_in_feature: str, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + num_classes: number of classes to predict + pixel_decoder: the pixel decoder module + loss_weight: loss weight + ignore_value: category id to be ignored during training. 
+ transformer_predictor: the transformer decoder that makes prediction + transformer_in_feature: input feature name to the transformer_predictor + """ + super().__init__() + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + self.ignore_value = ignore_value + self.common_stride = 4 + self.loss_weight = loss_weight + + self.pixel_decoder = pixel_decoder + self.predictor = transformer_predictor + self.transformer_in_feature = transformer_in_feature + + self.num_classes = num_classes + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + return { + "input_shape": { + k: v + for k, v in input_shape.items() + if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + }, + "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, + "pixel_decoder": build_pixel_decoder(cfg, input_shape), + "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, + "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE, + "transformer_predictor": TransformerPredictor( + cfg, + cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM + if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder" + else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels, + mask_classification=True, + ), + } + + def forward(self, features): + return self.layers(features) + + def layers(self, features): + ( + mask_features, + transformer_encoder_features, + ) = self.pixel_decoder.forward_features(features) + if self.transformer_in_feature == "transformer_encoder": + assert ( + transformer_encoder_features is not None + ), "Please use the TransformerEncoderPixelDecoder." + predictions = self.predictor(transformer_encoder_features, mask_features) + else: + predictions = self.predictor( + features[self.transformer_in_feature], mask_features + ) + return predictions diff --git a/open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py b/open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed84f9a44d24415b3334fdf2ea8e1188de32de6 --- /dev/null +++ b/open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved +# Modified by Feng Liang from +# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/heads/zero_shot_mask_former_head.py + +import logging +from copy import deepcopy +from typing import Callable, Dict, List, Optional, Tuple, Union + +import fvcore.nn.weight_init as weight_init +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY + +from ..transformer.open_vocab_transformer_predictor import OpenVocabTransformerPredictor +from .pixel_decoder import build_pixel_decoder + + +@SEM_SEG_HEADS_REGISTRY.register() +class OpenVocabMaskFormerHead(nn.Module): + + _version = 2 + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + version = local_metadata.get("version", None) + if version is None or version < 2: + # Do not warn if train from scratch + scratch = True + logger = logging.getLogger(__name__) + for k in list(state_dict.keys()): + newk = k + if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): + newk = k.replace(prefix, prefix + "pixel_decoder.") + # logger.debug(f"{k} ==> {newk}") + if newk != k: + state_dict[newk] = state_dict[k] + del state_dict[k] + scratch = False + + if not scratch: + logger.warning( + f"Weight format of {self.__class__.__name__} have changed! " + "Please upgrade your models. Applying automatic conversion now ..." + ) + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + num_classes: int, + pixel_decoder: nn.Module, + loss_weight: float = 1.0, + ignore_value: int = -1, + # extra parameters + transformer_predictor: nn.Module, + transformer_in_feature: str, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + num_classes: number of classes to predict + pixel_decoder: the pixel decoder module + loss_weight: loss weight + ignore_value: category id to be ignored during training. 
+ transformer_predictor: the transformer decoder that makes prediction + transformer_in_feature: input feature name to the transformer_predictor + """ + super().__init__() + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + self.ignore_value = ignore_value + self.common_stride = 4 + self.loss_weight = loss_weight + + self.pixel_decoder = pixel_decoder + self.predictor = transformer_predictor + self.transformer_in_feature = transformer_in_feature + + self.num_classes = num_classes + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + return { + "input_shape": { + k: v + for k, v in input_shape.items() + if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + }, + "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, + "pixel_decoder": build_pixel_decoder(cfg, input_shape), + "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, + "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE, + "transformer_predictor": OpenVocabTransformerPredictor( + cfg, + cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM + if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder" + else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels, + mask_classification=True, + ), + } + + def forward(self, features): + return self.layers(features) + + def layers(self, features): + ( + mask_features, + transformer_encoder_features, + ) = self.pixel_decoder.forward_features(features) + if self.transformer_in_feature == "transformer_encoder": + assert ( + transformer_encoder_features is not None + ), "Please use the TransformerEncoderPixelDecoder." + predictions = self.predictor(transformer_encoder_features, mask_features) + else: + predictions = self.predictor( + features[self.transformer_in_feature], mask_features + ) + return predictions + + def freeze_pretrained(self): + for name, module in self.named_children(): + if name not in ["predictor"]: + for param in module.parameters(): + param.requires_grad = False + else: + module.freeze_pretrained() diff --git a/open_vocab_seg/modeling/heads/pixel_decoder.py b/open_vocab_seg/modeling/heads/pixel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6b10089331785e937b79cf82af6d8fba55519082 --- /dev/null +++ b/open_vocab_seg/modeling/heads/pixel_decoder.py @@ -0,0 +1,308 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import logging +from typing import Callable, Dict, List, Optional, Tuple, Union + +import fvcore.nn.weight_init as weight_init +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d, ShapeSpec, get_norm +from detectron2.modeling import SEM_SEG_HEADS_REGISTRY + +from ..transformer.position_encoding import PositionEmbeddingSine +from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer + + +def build_pixel_decoder(cfg, input_shape): + """ + Build a pixel decoder from `cfg.MODEL.MASK_FORMER.PIXEL_DECODER_NAME`. 
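+    (Note: the registry name is read from cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME below, and the selected decoder must expose a callable forward_features.)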
+ """ + name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME + model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape) + forward_features = getattr(model, "forward_features", None) + if not callable(forward_features): + raise ValueError( + "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. " + f"Please implement forward_features for {name} to only return mask features." + ) + return model + + +@SEM_SEG_HEADS_REGISTRY.register() +class BasePixelDecoder(nn.Module): + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + conv_dim: int, + mask_dim: int, + norm: Optional[Union[str, Callable]] = None, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + conv_dims: number of output channels for the intermediate conv layers. + mask_dim: number of output channels for the final conv layer. + norm (str or callable): normalization for all conv layers + """ + super().__init__() + + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + feature_channels = [v.channels for k, v in input_shape] + + lateral_convs = [] + output_convs = [] + + use_bias = norm == "" + for idx, in_channels in enumerate(feature_channels): + if idx == len(self.in_features) - 1: + output_norm = get_norm(norm, conv_dim) + output_conv = Conv2d( + in_channels, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(output_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(None) + output_convs.append(output_conv) + else: + lateral_norm = get_norm(norm, conv_dim) + output_norm = get_norm(norm, conv_dim) + + lateral_conv = Conv2d( + in_channels, + conv_dim, + kernel_size=1, + bias=use_bias, + norm=lateral_norm, + ) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(lateral_conv) + weight_init.c2_xavier_fill(output_conv) + self.add_module("adapter_{}".format(idx + 1), lateral_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. 
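+        # self.in_features is sorted by stride (res2 -> res5); reversing the conv lists lets forward_features walk from the lowest-resolution map (res5) back down to res2, adding lateral connections FPN-style.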
+ self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + + self.mask_dim = mask_dim + self.mask_features = Conv2d( + conv_dim, + mask_dim, + kernel_size=3, + stride=1, + padding=1, + ) + weight_init.c2_xavier_fill(self.mask_features) + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + ret = {} + ret["input_shape"] = { + k: v + for k, v in input_shape.items() + if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + } + ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM + ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM + ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM + return ret + + def forward_features(self, features): + # Reverse feature maps into top-down order (from low to high resolution) + for idx, f in enumerate(self.in_features[::-1]): + x = features[f] + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + if lateral_conv is None: + y = output_conv(x) + else: + cur_fpn = lateral_conv(x) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") + y = output_conv(y) + return self.mask_features(y), None + + def forward(self, features, targets=None): + logger = logging.getLogger(__name__) + logger.warning( + "Calling forward() may cause unpredicted behavior of PixelDecoder module." + ) + return self.forward_features(features) + + +class TransformerEncoderOnly(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + if mask is not None: + mask = mask.flatten(1) + + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + return memory.permute(1, 2, 0).view(bs, c, h, w) + + +@SEM_SEG_HEADS_REGISTRY.register() +class TransformerEncoderPixelDecoder(BasePixelDecoder): + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + transformer_dropout: float, + transformer_nheads: int, + transformer_dim_feedforward: int, + transformer_enc_layers: int, + transformer_pre_norm: bool, + conv_dim: int, + mask_dim: int, + norm: Optional[Union[str, Callable]] = None, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + transformer_dropout: dropout probability in transformer + transformer_nheads: number of heads in transformer + transformer_dim_feedforward: dimension of feedforward network + transformer_enc_layers: number of transformer encoder layers + transformer_pre_norm: whether to use pre-layernorm or not + conv_dims: number of output channels for the intermediate conv layers. + mask_dim: number of output channels for the final conv layer. 
+ norm (str or callable): normalization for all conv layers + """ + super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm) + + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + in_channels = feature_channels[len(self.in_features) - 1] + self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1) + weight_init.c2_xavier_fill(self.input_proj) + self.transformer = TransformerEncoderOnly( + d_model=conv_dim, + dropout=transformer_dropout, + nhead=transformer_nheads, + dim_feedforward=transformer_dim_feedforward, + num_encoder_layers=transformer_enc_layers, + normalize_before=transformer_pre_norm, + ) + N_steps = conv_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + # update layer + use_bias = norm == "" + output_norm = get_norm(norm, conv_dim) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(output_conv) + delattr(self, "layer_{}".format(len(self.in_features))) + self.add_module("layer_{}".format(len(self.in_features)), output_conv) + self.output_convs[0] = output_conv + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + ret = super().from_config(cfg, input_shape) + ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT + ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS + ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD + ret[ + "transformer_enc_layers" + ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS # a separate config + ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM + return ret + + def forward_features(self, features): + # Reverse feature maps into top-down order (from low to high resolution) + for idx, f in enumerate(self.in_features[::-1]): + x = features[f] + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + if lateral_conv is None: + transformer = self.input_proj(x) + pos = self.pe_layer(x) + transformer = self.transformer(transformer, None, pos) + y = output_conv(transformer) + # save intermediate feature as input to Transformer decoder + transformer_encoder_features = transformer + else: + cur_fpn = lateral_conv(x) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") + y = output_conv(y) + return self.mask_features(y), transformer_encoder_features + + def forward(self, features, targets=None): + logger = logging.getLogger(__name__) + logger.warning( + "Calling forward() may cause unpredicted behavior of PixelDecoder module." + ) + return self.forward_features(features) diff --git a/open_vocab_seg/modeling/matcher.py b/open_vocab_seg/modeling/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..a72ba671ad60db078e08046357a6aa0e5e9bd5dc --- /dev/null +++ b/open_vocab_seg/modeling/matcher.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +""" +Modules to compute the matching cost and solve the corresponding LSAP. 
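+The LSAP (linear sum assignment problem) is solved per image with scipy.optimize.linear_sum_assignment on a [num_queries, num_targets] cost matrix built from classification, focal-mask and dice costs.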
+""" +import torch +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment +from torch import nn + + +def batch_dice_loss(inputs, targets): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) + denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = -1 (no weighting). + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. + Returns: + Loss tensor + """ + hw = inputs.shape[1] + + prob = inputs.sigmoid() + focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction="none" + ) + focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction="none" + ) + if alpha >= 0: + focal_pos = focal_pos * alpha + focal_neg = focal_neg * (1 - alpha) + + loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum( + "nc,mc->nm", focal_neg, (1 - targets) + ) + + return loss / hw + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). 
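+    The per-pair cost is cost_class * (-prob[target_class]) + cost_mask * focal_cost + cost_dice * dice_cost; see memory_efficient_forward for how each term is computed on masks downsampled to the prediction resolution.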
+ """ + + def __init__( + self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1 + ): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost + cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + assert ( + cost_class != 0 or cost_mask != 0 or cost_dice != 0 + ), "all costs cant be 0" + + @torch.no_grad() + def memory_efficient_forward(self, outputs, targets): + """More memory-friendly matching""" + bs, num_queries = outputs["pred_logits"].shape[:2] + + # Work out the mask padding size + masks = [v["masks"] for v in targets] + h_max = max([m.shape[1] for m in masks]) + w_max = max([m.shape[2] for m in masks]) + + indices = [] + + # Iterate through batch size + for b in range(bs): + + out_prob = outputs["pred_logits"][b].softmax( + -1 + ) # [num_queries, num_classes] + out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred] + + tgt_ids = targets[b]["labels"] + # gt masks are already padded when preparing target + tgt_mask = targets[b]["masks"].to(out_mask) + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + cost_class = -out_prob[:, tgt_ids] + + # Downsample gt masks to save memory + tgt_mask = F.interpolate( + tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest" + ) + + # Flatten spatial dimension + out_mask = out_mask.flatten(1) # [batch_size * num_queries, H*W] + tgt_mask = tgt_mask[:, 0].flatten(1) # [num_total_targets, H*W] + + # Compute the focal loss between masks + cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask) + + # Compute the dice loss betwen masks + cost_dice = batch_dice_loss(out_mask, tgt_mask) + + # Final cost matrix + C = ( + self.cost_mask * cost_mask + + self.cost_class * cost_class + + self.cost_dice * cost_dice + ) + C = C.reshape(num_queries, -1).cpu() + + indices.append(linear_sum_assignment(C)) + return [ + ( + torch.as_tensor(i, dtype=torch.int64), + torch.as_tensor(j, dtype=torch.int64), + ) + for i, j in indices + ] + + @torch.no_grad() + def forward(self, outputs, targets): + """Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + return self.memory_efficient_forward(outputs, targets) + + def __repr__(self): + head = 
"Matcher " + self.__class__.__name__ + body = [ + "cost_class: {}".format(self.cost_class), + "cost_mask: {}".format(self.cost_mask), + "cost_dice: {}".format(self.cost_dice), + ] + _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) diff --git a/open_vocab_seg/modeling/transformer/__init__.py b/open_vocab_seg/modeling/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49f9003b7a688f5396170dd89c26ef335a2c201f --- /dev/null +++ b/open_vocab_seg/modeling/transformer/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved diff --git a/open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py b/open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..0efee3e14c71400a1cc5a55ea6c21b6876189aaa --- /dev/null +++ b/open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py @@ -0,0 +1,84 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +from torch import nn +from detectron2.config import configurable +from .transformer_predictor import TransformerPredictor, MLP + + +class OpenVocabTransformerPredictor(TransformerPredictor): + @configurable + def __init__( + self, + in_channels, + mask_classification=True, + *, + embedding_dim: int, + embed_hidden_dim: int, + embed_layers: int, + hidden_dim: int, + num_queries: int, + nheads: int, + dropout: float, + dim_feedforward: int, + enc_layers: int, + dec_layers: int, + pre_norm: bool, + deep_supervision: bool, + mask_dim: int, + enforce_input_project: bool, + ): + super().__init__( + in_channels, + False, + num_classes=embedding_dim, + hidden_dim=hidden_dim, + num_queries=num_queries, + nheads=nheads, + dropout=dropout, + dim_feedforward=dim_feedforward, + enc_layers=enc_layers, + dec_layers=dec_layers, + pre_norm=pre_norm, + deep_supervision=deep_supervision, + mask_dim=mask_dim, + enforce_input_project=enforce_input_project, + ) + self.mask_classification = mask_classification + # output FFNs + if self.mask_classification: + self.class_embed = MLP( + hidden_dim, embed_hidden_dim, embedding_dim, embed_layers + ) + + def freeze_pretrained(self): + for name, module in self.named_children(): + if name not in ["class_embed"]: + for param in module.parameters(): + param.requires_grad = False + + @classmethod + def from_config(cls, cfg, in_channels, mask_classification): + ret = {} + ret["in_channels"] = in_channels + ret["mask_classification"] = mask_classification + + ret["embedding_dim"] = cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM + ret["embed_hidden_dim"] = cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM + ret["embed_layers"] = cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS + ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM + ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES + # Transformer parameters: + ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS + ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT + ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD + ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS + ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS + ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM + ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION + ret["enforce_input_project"] = 
cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ + + ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM + + return ret diff --git a/open_vocab_seg/modeling/transformer/position_encoding.py b/open_vocab_seg/modeling/transformer/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..db236c5b36cbc4f4435a83b542bdc242cbb441c3 --- /dev/null +++ b/open_vocab_seg/modeling/transformer/position_encoding.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +""" +Various positional encodings for the transformer. +""" +import math + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros( + (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool + ) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/open_vocab_seg/modeling/transformer/transformer.py b/open_vocab_seg/modeling/transformer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..76d1003b3852ce72c6ad5c3c23705f380197362f --- /dev/null +++ b/open_vocab_seg/modeling/transformer/transformer.py @@ -0,0 +1,380 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +""" +Transformer class. 
+ +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class Transformer(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder( + encoder_layer, num_encoder_layers, encoder_norm + ) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + if mask is not None: + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder( + tgt, + memory, + memory_key_padding_mask=mask, + pos=pos_embed, + query_pos=query_embed, + ) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + pos=pos, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + 
intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn( + q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + 
tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") diff --git a/open_vocab_seg/modeling/transformer/transformer_predictor.py b/open_vocab_seg/modeling/transformer/transformer_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..72378abe29c01809a00fa1b87d275258ee9c91fa --- /dev/null +++ b/open_vocab_seg/modeling/transformer/transformer_predictor.py @@ -0,0 +1,179 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.layers import Conv2d + +from .position_encoding import PositionEmbeddingSine +from .transformer import Transformer + + +class TransformerPredictor(nn.Module): + @configurable + def __init__( + self, + in_channels, + mask_classification=True, + *, + num_classes: int, + hidden_dim: int, + num_queries: int, + nheads: int, + dropout: float, + dim_feedforward: int, + enc_layers: int, + dec_layers: int, + pre_norm: bool, + deep_supervision: bool, + mask_dim: int, + enforce_input_project: bool, + ): + """ + NOTE: this interface is experimental. + Args: + in_channels: channels of the input features + mask_classification: whether to add mask classifier or not + num_classes: number of classes + hidden_dim: Transformer feature dimension + num_queries: number of queries + nheads: number of heads + dropout: dropout in Transformer + dim_feedforward: feature dimension in feedforward network + enc_layers: number of Transformer encoder layers + dec_layers: number of Transformer decoder layers + pre_norm: whether to use pre-LayerNorm or not + deep_supervision: whether to add supervision to every decoder layers + mask_dim: mask feature dimension + enforce_input_project: add input project 1x1 conv even if input + channels and hidden dim is identical + """ + super().__init__() + + self.mask_classification = mask_classification + + # positional encoding + N_steps = hidden_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + transformer = Transformer( + d_model=hidden_dim, + dropout=dropout, + nhead=nheads, + dim_feedforward=dim_feedforward, + num_encoder_layers=enc_layers, + num_decoder_layers=dec_layers, + normalize_before=pre_norm, + return_intermediate_dec=deep_supervision, + ) + + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + if in_channels != hidden_dim or enforce_input_project: + self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1) + weight_init.c2_xavier_fill(self.input_proj) + else: + self.input_proj = nn.Sequential() + self.aux_loss = deep_supervision + + # output FFNs + if self.mask_classification: + self.class_embed = nn.Linear(hidden_dim, num_classes + 1) + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + + @classmethod + def from_config(cls, cfg, in_channels, mask_classification): + ret = {} + ret["in_channels"] = in_channels + ret["mask_classification"] = mask_classification + + ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES + ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM + ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES + # Transformer parameters: + ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS + ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT + ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD + ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS + ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS + ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM + ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION + ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ + + ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM + + return ret + + def forward(self, x, mask_features): + pos = self.pe_layer(x) + + src = x + mask = None + hs, memory = self.transformer( + 
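+            # Arguments: the (optionally 1x1-projected) backbone feature map (B, C, H, W),
+            # no padding mask, the learned query embeddings used as query positional
+            # encodings, and the sine positional encoding computed above.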
self.input_proj(src), mask, self.query_embed.weight, pos + ) + + if self.mask_classification: + outputs_class = self.class_embed(hs) + out = {"pred_logits": outputs_class[-1]} + else: + out = {} + + if self.aux_loss: + # [l, bs, queries, embed] + mask_embed = self.mask_embed(hs) + outputs_seg_masks = torch.einsum( + "lbqc,bchw->lbqhw", mask_embed, mask_features + ) + out["pred_masks"] = outputs_seg_masks[-1] + out["aux_outputs"] = self._set_aux_loss( + outputs_class if self.mask_classification else None, outputs_seg_masks + ) + else: + # FIXME h_boxes takes the last one computed, keep this in mind + # [bs, queries, embed] + mask_embed = self.mask_embed(hs[-1]) + outputs_seg_masks = torch.einsum( + "bqc,bchw->bqhw", mask_embed, mask_features + ) + out["pred_masks"] = outputs_seg_masks + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_seg_masks): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + if self.mask_classification: + return [ + {"pred_logits": a, "pred_masks": b} + for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) + ] + else: + return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x diff --git a/open_vocab_seg/ovseg_model.py b/open_vocab_seg/ovseg_model.py new file mode 100644 index 0000000000000000000000000000000000000000..48df93168a5bdfb831715f7b8c008d7b7a5d3814 --- /dev/null +++ b/open_vocab_seg/ovseg_model.py @@ -0,0 +1,460 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved +# Modified by Feng Liang from +# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/zero_shot_mask_former_model.py + +import logging +from typing import Tuple + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from detectron2.config import configurable +from detectron2.data import MetadataCatalog +from detectron2.modeling import META_ARCH_REGISTRY +from detectron2.modeling.backbone import Backbone +from detectron2.modeling.postprocessing import sem_seg_postprocess +from detectron2.structures import ImageList +from detectron2.utils.logger import log_first_n +from .modeling.clip_adapter import ( + ClipAdapter, + MaskFormerClipAdapter, + build_text_prompt, +) +from .mask_former_model import MaskFormer +from .utils.misc import get_gt_binary_masks + +@META_ARCH_REGISTRY.register() +class OVSeg(MaskFormer): + """ + Main class for zero shot mask classification semantic segmentation architectures. 
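+    When `clip_ensemble` is enabled, mask proposals from the MaskFormer head are
+    classified by combining the head's class logits with CLIP scores computed on the
+    masked image regions (see `semantic_inference` and `clip_ensemble_weight`).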
+ """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + sem_seg_head: nn.Module, + clip_adapter: nn.Module, + criterion: nn.Module, + num_queries: int, + panoptic_on: bool, + object_mask_threshold: float, + overlap_threshold: float, + metadata, + size_divisibility: int, + sem_seg_postprocess_before_inference: bool, + clip_ensemble: bool, + clip_ensemble_weight: float, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + sem_seg_head: a module that predicts semantic segmentation from backbone features + criterion: a module that defines the loss + clip_adapter: adapter for clip-based mask classification + num_queries: int, number of queries + panoptic_on: bool, whether to output panoptic segmentation prediction + object_mask_threshold: float, threshold to filter query based on classification score + for panoptic segmentation inference + overlap_threshold: overlap threshold used in general inference for panoptic segmentation + metadata: dataset meta, get `thing` and `stuff` category names for panoptic + segmentation inference + size_divisibility: Some backbones require the input height and width to be divisible by a + specific integer. We can use this to override such requirement. + sem_seg_postprocess_before_inference: whether to resize the prediction back + to original input size before semantic segmentation inference or after. + For high-resolution dataset like Mapillary, resizing predictions before + inference will cause OOM error. + pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + """ + super().__init__( + backbone=backbone, + sem_seg_head=sem_seg_head, + criterion=criterion, + num_queries=num_queries, + panoptic_on=panoptic_on, + object_mask_threshold=object_mask_threshold, + overlap_threshold=overlap_threshold, + metadata=metadata, + size_divisibility=size_divisibility, + sem_seg_postprocess_before_inference=sem_seg_postprocess_before_inference, + pixel_mean=pixel_mean, + pixel_std=pixel_std, + ) + self.clip_adapter: ClipAdapter = clip_adapter + + self.clip_ensemble: bool = clip_ensemble + self.clip_ensemble_weight: float = clip_ensemble_weight + + @classmethod + def from_config(cls, cfg): + init_kwargs = MaskFormer.from_config(cfg) + text_templates = build_text_prompt(cfg.MODEL.CLIP_ADAPTER) + + clip_adapter = MaskFormerClipAdapter( + cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME, + text_templates, + mask_fill=cfg.MODEL.CLIP_ADAPTER.MASK_FILL, + mask_expand_ratio=cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO, + mask_thr=cfg.MODEL.CLIP_ADAPTER.MASK_THR, + mask_matting=cfg.MODEL.CLIP_ADAPTER.MASK_MATTING, + region_resized=cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED, + mask_prompt_depth=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH, + mask_prompt_fwd=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD, + ) + init_kwargs["clip_adapter"] = clip_adapter + init_kwargs["clip_ensemble"] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE + init_kwargs[ + "clip_ensemble_weight" + ] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT + + return init_kwargs + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. + Each item in the list contains the inputs for one image. + For now, each item in the list is a dict that contains: + * "image": Tensor, image in (C, H, W) format. 
+ * "instances": per-region ground truth + * Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model (may be different + from input resolution), used in inference. + Returns: + list[dict]: + each dict has the results for one image. The dict contains the following keys: + + * "sem_seg": + A Tensor that represents the + per-pixel segmentation prediced by the head. + The prediction has shape KxHxW that represents the logits of + each class for each pixel. + * "panoptic_seg": + A tuple that represent panoptic output + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. + segments_info (list[dict]): Describe each segment in `panoptic_seg`. + Each dict contains keys "id", "category_id", "isthing". + """ + dataset_name = [x["meta"]["dataset_name"] for x in batched_inputs] + assert len(set(dataset_name)) == 1 + dataset_name = dataset_name[0] + + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + + features = self.backbone(images.tensor) + outputs = self.sem_seg_head(features) + class_names = self.get_class_name_list(dataset_name) + text_features = self.clip_adapter.get_text_features(class_names) + outputs["pred_logits"] = self.clip_adapter.get_sim_logits( + text_features, self.clip_adapter.normalize_feature(outputs["pred_logits"]) + ) + if self.training: + if "aux_outputs" in outputs.keys(): + for i in range(len(outputs["aux_outputs"])): + outputs["aux_outputs"][i][ + "pred_logits" + ] = self.clip_adapter.get_sim_logits( + text_features, + self.clip_adapter.normalize_feature( + outputs["aux_outputs"][i]["pred_logits"] + ), + ) + # mask classification target + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + targets = self.prepare_targets(gt_instances, images) + else: + targets = None + + # bipartite matching-based loss + losses = self.criterion(outputs, targets) + + for k in list(losses.keys()): + if k in self.criterion.weight_dict: + losses[k] *= self.criterion.weight_dict[k] + else: + # remove this loss if not specified in `weight_dict` + losses.pop(k) + + return losses + else: + mask_cls_results = outputs["pred_logits"] + mask_pred_results = outputs["pred_masks"] + # upsample masks + mask_pred_results = F.interpolate( + mask_pred_results, + size=(images.tensor.shape[-2], images.tensor.shape[-1]), + mode="bilinear", + align_corners=False, + ) + + processed_results = [] + for mask_cls_result, mask_pred_result, input_per_image, image_size in zip( + mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes + ): + height = image_size[0] + width = image_size[1] + mask_pred_result = sem_seg_postprocess( + mask_pred_result, image_size, height, width + ) + image = input_per_image["image"].to(self.device) + + r, regions = self.semantic_inference( + mask_cls_result, mask_pred_result, image, class_names + ) + + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = sem_seg_postprocess(r, image_size, height, width) + processed_results.append({"sem_seg": r}) + + # panoptic segmentation inference + if self.panoptic_on: + panoptic_r = self.panoptic_inference( + mask_cls_result, mask_pred_result + ) + processed_results[-1]["panoptic_seg"] = panoptic_r + + return processed_results + + + def semantic_inference(self, mask_cls, 
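+                           # mask_cls: (num_queries, K + 1) classification logits, whose
+                           # last (no-object) column is dropped below;
+                           # mask_pred: (num_queries, H, W) mask logits. Shapes are
+                           # inferred from how the tensors are used in this method.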
mask_pred, image, class_names): + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + + regions = None + if self.clip_ensemble: + clip_cls, regions, valid_flag = self.clip_adapter( + image, class_names, mask_pred, normalize=True + ) + if clip_cls is None: + clip_cls = torch.empty(0, mask_cls.shape[-1] + 1, device=self.device) + # softmax before index or after? + clip_cls = F.softmax(clip_cls[:, :-1], dim=-1) + if self.clip_ensemble_weight > 0: + map_back_clip_cls = mask_cls.new_ones(mask_cls.shape) + map_back_clip_cls[valid_flag] = clip_cls + mask_cls = torch.pow(mask_cls, 1 - self.clip_ensemble_weight) * \ + torch.pow(map_back_clip_cls, self.clip_ensemble_weight) + + + else: + # only clip model predictions are used + mask_cls = clip_cls + mask_pred = mask_pred[valid_flag] + semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) + return semseg, regions + + def get_class_name_list(self, dataset_name): + class_names = [ + c.strip() for c in MetadataCatalog.get(dataset_name).stuff_classes + ] + return class_names + + +@META_ARCH_REGISTRY.register() +class OVSegDEMO(MaskFormer): + """ + Main class for zero shot mask classification semantic segmentation architectures. + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + sem_seg_head: nn.Module, + clip_adapter: nn.Module, + criterion: nn.Module, + num_queries: int, + panoptic_on: bool, + object_mask_threshold: float, + overlap_threshold: float, + metadata, + size_divisibility: int, + sem_seg_postprocess_before_inference: bool, + clip_ensemble: bool, + clip_ensemble_weight: float, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + sem_seg_head: a module that predicts semantic segmentation from backbone features + criterion: a module that defines the loss + clip_adapter: adapter for clip-based mask classification + num_queries: int, number of queries + panoptic_on: bool, whether to output panoptic segmentation prediction + object_mask_threshold: float, threshold to filter query based on classification score + for panoptic segmentation inference + overlap_threshold: overlap threshold used in general inference for panoptic segmentation + metadata: dataset meta, get `thing` and `stuff` category names for panoptic + segmentation inference + size_divisibility: Some backbones require the input height and width to be divisible by a + specific integer. We can use this to override such requirement. + sem_seg_postprocess_before_inference: whether to resize the prediction back + to original input size before semantic segmentation inference or after. + For high-resolution dataset like Mapillary, resizing predictions before + inference will cause OOM error. 
+ pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + """ + super().__init__( + backbone=backbone, + sem_seg_head=sem_seg_head, + criterion=criterion, + num_queries=num_queries, + panoptic_on=panoptic_on, + object_mask_threshold=object_mask_threshold, + overlap_threshold=overlap_threshold, + metadata=metadata, + size_divisibility=size_divisibility, + sem_seg_postprocess_before_inference=sem_seg_postprocess_before_inference, + pixel_mean=pixel_mean, + pixel_std=pixel_std, + ) + self.clip_adapter: ClipAdapter = clip_adapter + + self.clip_ensemble: bool = clip_ensemble + self.clip_ensemble_weight: float = clip_ensemble_weight + + @classmethod + def from_config(cls, cfg): + init_kwargs = MaskFormer.from_config(cfg) + text_templates = build_text_prompt(cfg.MODEL.CLIP_ADAPTER) + + clip_adapter = MaskFormerClipAdapter( + cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME, + text_templates, + mask_fill=cfg.MODEL.CLIP_ADAPTER.MASK_FILL, + mask_expand_ratio=cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO, + mask_thr=cfg.MODEL.CLIP_ADAPTER.MASK_THR, + mask_matting=cfg.MODEL.CLIP_ADAPTER.MASK_MATTING, + region_resized=cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED, + mask_prompt_depth=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH, + mask_prompt_fwd=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD, + ) + init_kwargs["clip_adapter"] = clip_adapter + init_kwargs["clip_ensemble"] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE + init_kwargs[ + "clip_ensemble_weight" + ] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT + + return init_kwargs + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. + Each item in the list contains the inputs for one image. + For now, each item in the list is a dict that contains: + * "image": Tensor, image in (C, H, W) format. + * "instances": per-region ground truth + * Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model (may be different + from input resolution), used in inference. + Returns: + list[dict]: + each dict has the results for one image. The dict contains the following keys: + + * "sem_seg": + A Tensor that represents the + per-pixel segmentation prediced by the head. + The prediction has shape KxHxW that represents the logits of + each class for each pixel. + * "panoptic_seg": + A tuple that represent panoptic output + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. + segments_info (list[dict]): Describe each segment in `panoptic_seg`. + Each dict contains keys "id", "category_id", "isthing". 
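+        Note: unlike `OVSeg`, class names are read from `batched_inputs[0]["class_names"]`
+        rather than from dataset metadata, and this demo architecture only runs inference
+        (there is no training branch in this `forward`).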
+ """ + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + + features = self.backbone(images.tensor) + outputs = self.sem_seg_head(features) + class_names = batched_inputs[0]["class_names"] + if len(class_names) == 1: + # Because classification is performed in a 'contrastive' manner, adding others to represent other concepts + class_names.append('others') + text_features = self.clip_adapter.get_text_features(class_names) + outputs["pred_logits"] = self.clip_adapter.get_sim_logits( + text_features, self.clip_adapter.normalize_feature(outputs["pred_logits"]) + ) + mask_cls_results = outputs["pred_logits"] + mask_pred_results = outputs["pred_masks"] + # upsample masks + mask_pred_results = F.interpolate( + mask_pred_results, + size=(images.tensor.shape[-2], images.tensor.shape[-1]), + mode="bilinear", + align_corners=False, + ) + + processed_results = [] + for mask_cls_result, mask_pred_result, input_per_image, image_size in zip( + mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes + ): + height = image_size[0] + width = image_size[1] + mask_pred_result = sem_seg_postprocess( + mask_pred_result, image_size, height, width + ) + image = input_per_image["image"].to(self.device) + + r, regions = self.demo_inference(mask_cls_result, mask_pred_result, image, class_names) + + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = sem_seg_postprocess(r, image_size, height, width) + processed_results.append({"sem_seg": r}) + + return processed_results + + + + + def demo_inference(self, mask_cls, mask_pred, image, class_names): + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + + regions = None + if self.clip_ensemble: + clip_cls, regions, valid_flag = self.clip_adapter( + image, class_names, mask_pred, normalize=True + ) + if clip_cls is None: + clip_cls = torch.empty(0, mask_cls.shape[-1] + 1, device=self.device) + # softmax before index or after? + clip_cls = F.softmax(clip_cls[:, :-1], dim=-1) + if self.clip_ensemble_weight > 0: + map_back_clip_cls = mask_cls.new_ones(mask_cls.shape) + map_back_clip_cls[valid_flag] = clip_cls + mask_cls = torch.pow(mask_cls, 1 - self.clip_ensemble_weight) * \ + torch.pow(map_back_clip_cls, self.clip_ensemble_weight) + + else: + # only clip model predictions are used + mask_cls = clip_cls + mask_pred = mask_pred[valid_flag] + bin_mask = mask_pred > self.clip_adapter.mask_thr + select_cls = torch.zeros(sum(valid_flag), mask_cls.shape[-1], device=self.device) + select_mask = torch.argmax(mask_cls, dim=0) + if len(class_names) == 2 and class_names[-1] == 'others': + select_mask = select_mask[:-1] + for idx in select_mask: + select_cls[idx] = mask_cls[idx] + semseg = torch.einsum("qc,qhw->chw", select_cls, bin_mask.float()) + return semseg, regions diff --git a/open_vocab_seg/test_time_augmentation.py b/open_vocab_seg/test_time_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..bb7a51f28419c59775013c74fdee49e5166bde51 --- /dev/null +++ b/open_vocab_seg/test_time_augmentation.py @@ -0,0 +1,217 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import copy +from itertools import count +import math +import numpy as np +import torch +from fvcore.transforms import HFlipTransform +from torch import nn +from torch.nn.parallel import DistributedDataParallel + +from detectron2.data.detection_utils import read_image +from detectron2.modeling import DatasetMapperTTA +from detectron2.modeling.postprocessing import sem_seg_postprocess +import logging +from detectron2.utils.logger import log_every_n, log_first_n + +__all__ = [ + "SemanticSegmentorWithTTA", +] + + +class SemanticSegmentorWithTTA(nn.Module): + """ + A SemanticSegmentor with test-time augmentation enabled. + Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`. + """ + + def __init__(self, cfg, model, tta_mapper=None, batch_size=1): + """ + Args: + cfg (CfgNode): + model (SemanticSegmentor): a SemanticSegmentor to apply TTA on. + tta_mapper (callable): takes a dataset dict and returns a list of + augmented versions of the dataset dict. Defaults to + `DatasetMapperTTA(cfg)`. + batch_size (int): batch the augmented images into this batch size for inference. + """ + super().__init__() + if isinstance(model, DistributedDataParallel): + model = model.module + self.cfg = cfg.clone() + + self.model = model + + if tta_mapper is None: + tta_mapper = DatasetMapperTTA(cfg) + self.tta_mapper = tta_mapper + self.batch_size = batch_size + + def _inference_with_model(self, inputs): + if self.cfg.TEST.SLIDING_WINDOW: + log_first_n(logging.INFO, "Using sliding window to test") + + outputs = [] + + for input in inputs: + image_size = input["image"].shape[1:] # h,w + if self.cfg.TEST.SLIDING_TILE_SIZE > 0: + tile_size = ( + self.cfg.TEST.SLIDING_TILE_SIZE, + self.cfg.TEST.SLIDING_TILE_SIZE, + ) + else: + selected_mapping = {256: 224, 512: 256, 768: 512, 896: 512} + tile_size = min(image_size) + tile_size = selected_mapping[tile_size] + tile_size = (tile_size, tile_size) + extra_info = { + k: v + for k, v in input.items() + if k not in ["image", "height", "width"] + } + log_every_n( + logging.INFO, "split {} to {}".format(image_size, tile_size) + ) + overlap = self.cfg.TEST.SLIDING_OVERLAP + stride = math.ceil(tile_size[0] * (1 - overlap)) + tile_rows = int( + math.ceil((image_size[0] - tile_size[0]) / stride) + 1 + ) # strided convolution formula + tile_cols = int(math.ceil((image_size[1] - tile_size[1]) / stride) + 1) + full_probs = None + count_predictions = None + tile_counter = 0 + + for row in range(tile_rows): + for col in range(tile_cols): + x1 = int(col * stride) + y1 = int(row * stride) + x2 = min(x1 + tile_size[1], image_size[1]) + y2 = min(y1 + tile_size[0], image_size[0]) + x1 = max( + int(x2 - tile_size[1]), 0 + ) # for portrait images the x1 underflows sometimes + y1 = max( + int(y2 - tile_size[0]), 0 + ) # for very few rows y1 underflows + + img = input["image"][:, y1:y2, x1:x2] + padded_img = nn.functional.pad( + img, + ( + 0, + tile_size[1] - img.shape[-1], + 0, + tile_size[0] - img.shape[-2], + ), + ) + tile_counter += 1 + padded_input = {"image": padded_img} + padded_input.update(extra_info) + padded_prediction = self.model([padded_input])[0]["sem_seg"] + prediction = padded_prediction[ + :, 0 : img.shape[1], 0 : img.shape[2] + ] + if full_probs is None: + full_probs = prediction.new_zeros( + prediction.shape[0], image_size[0], image_size[1] + ) + if count_predictions is None: + count_predictions = prediction.new_zeros( + prediction.shape[0], image_size[0], image_size[1] + ) + count_predictions[:, y1:y2, x1:x2] += 1 + 
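+                        # Tile logits are summed into `full_probs` while `count_predictions`
+                        # tracks how many tiles cover each pixel; dividing the two after the
+                        # loop averages the overlapping regions (tile overlap is set by
+                        # TEST.SLIDING_OVERLAP via stride = ceil(tile_size * (1 - overlap))).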
full_probs[ + :, y1:y2, x1:x2 + ] += prediction # accumulate the predictions also in the overlapping regions + + full_probs /= count_predictions + full_probs = sem_seg_postprocess( + full_probs, + image_size, + input.get("height", image_size[0]), + input.get("width", image_size[1]), + ) + outputs.append({"sem_seg": full_probs}) + + return outputs + else: + log_first_n(logging.INFO, "Using whole image to test") + return self.model(inputs) + + def _batch_inference(self, batched_inputs): + """ + Execute inference on a list of inputs, + using batch size = self.batch_size, instead of the length of the list. + Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward` + """ + outputs = [] + inputs = [] + for idx, input in zip(count(), batched_inputs): + inputs.append(input) + if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1: + with torch.no_grad(): + outputs.extend(self._inference_with_model(inputs)) + inputs = [] + return outputs + + def __call__(self, batched_inputs): + """ + Same input/output format as :meth:`SemanticSegmentor.forward` + """ + + def _maybe_read_image(dataset_dict): + ret = copy.copy(dataset_dict) + if "image" not in ret: + image = read_image(ret.pop("file_name"), self.model.input_format) + image = torch.from_numpy( + np.ascontiguousarray(image.transpose(2, 0, 1)) + ) # CHW + ret["image"] = image + if "height" not in ret and "width" not in ret: + ret["height"] = image.shape[1] + ret["width"] = image.shape[2] + return ret + + return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs] + + def _inference_one_image(self, input): + """ + Args: + input (dict): one dataset dict with "image" field being a CHW tensor + Returns: + dict: one output dict + """ + augmented_inputs, tfms = self._get_augmented_inputs(input) + # 1: forward with all augmented images + outputs = self._batch_inference(augmented_inputs) + # Delete now useless variables to avoid being out of memory + del augmented_inputs + # 2: merge the results + # handle flip specially + # outputs = [output.detach() for output in outputs] + return self._merge_auged_output(outputs, tfms) + + def _merge_auged_output(self, outputs, tfms): + new_outputs = [] + for output, tfm in zip(outputs, tfms): + if any(isinstance(t, HFlipTransform) for t in tfm.transforms): + new_outputs.append(output["sem_seg"].flip(dims=[2])) + else: + new_outputs.append(output["sem_seg"]) + del outputs + # to avoid OOM with torch.stack + final_predictions = new_outputs[0] + for i in range(1, len(new_outputs)): + final_predictions += new_outputs[i] + final_predictions = final_predictions / len(new_outputs) + del new_outputs + return {"sem_seg": final_predictions} + + def _get_augmented_inputs(self, input): + augmented_inputs = self.tta_mapper(input) + tfms = [x.pop("transforms") for x in augmented_inputs] + return augmented_inputs, tfms diff --git a/open_vocab_seg/utils/__init__.py b/open_vocab_seg/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf832dce405bbdcf45f2534a782494b37760cd9 --- /dev/null +++ b/open_vocab_seg/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +from .events import setup_wandb, WandbWriter +from .predictor import VisualizationDemo, SAMVisualizationDemo \ No newline at end of file diff --git a/open_vocab_seg/utils/events.py b/open_vocab_seg/utils/events.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe82ce80a7110a1018167763ba3adc90f58faa0 --- /dev/null +++ b/open_vocab_seg/utils/events.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import os +import wandb +from detectron2.utils import comm +from detectron2.utils.events import EventWriter, get_event_storage + + +def setup_wandb(cfg, args): + if comm.is_main_process(): + init_args = { + k.lower(): v + for k, v in cfg.WANDB.items() + if isinstance(k, str) and k not in ["config", "name"] + } + # only include most related part to avoid too big table + # TODO: add configurable params to select which part of `cfg` should be saved in config + if "config_exclude_keys" in init_args: + init_args["config"] = cfg + init_args["config"]["cfg_file"] = args.config_file + else: + init_args["config"] = { + "model": cfg.MODEL, + "solver": cfg.SOLVER, + "cfg_file": args.config_file, + } + if ("name" not in init_args) or (init_args["name"] is None): + init_args["name"] = os.path.basename(args.config_file) + wandb.init(**init_args) + + +class BaseRule(object): + def __call__(self, target): + return target + + +class IsIn(BaseRule): + def __init__(self, keyword: str): + self.keyword = keyword + + def __call__(self, target): + return self.keyword in target + + +class Prefix(BaseRule): + def __init__(self, keyword: str): + self.keyword = keyword + + def __call__(self, target): + return "/".join([self.keyword, target]) + + +class WandbWriter(EventWriter): + """ + Write all scalars to a tensorboard file. + """ + + def __init__(self): + """ + Args: + log_dir (str): the directory to save the output events + kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)` + """ + self._last_write = -1 + self._group_rules = [ + (IsIn("/"), BaseRule()), + (IsIn("loss"), Prefix("train")), + ] + + def write(self): + + storage = get_event_storage() + + def _group_name(scalar_name): + for (rule, op) in self._group_rules: + if rule(scalar_name): + return op(scalar_name) + return scalar_name + + stats = { + _group_name(name): scalars[0] + for name, scalars in storage.latest().items() + if scalars[1] > self._last_write + } + if len(stats) > 0: + self._last_write = max([v[1] for k, v in storage.latest().items()]) + + # storage.put_{image,histogram} is only meant to be used by + # tensorboard writer. So we access its internal fields directly from here. + if len(storage._vis_data) >= 1: + stats["image"] = [ + wandb.Image(img, caption=img_name) + for img_name, img, step_num in storage._vis_data + ] + # Storage stores all image data and rely on this writer to clear them. + # As a result it assumes only one writer will use its image data. + # An alternative design is to let storage store limited recent + # data (e.g. only the most recent image) that all writers can access. + # In that case a writer may not see all image data if its period is long. 
+ storage.clear_images() + + if len(storage._histograms) >= 1: + + def create_bar(tag, bucket_limits, bucket_counts, **kwargs): + data = [ + [label, val] for (label, val) in zip(bucket_limits, bucket_counts) + ] + table = wandb.Table(data=data, columns=["label", "value"]) + return wandb.plot.bar(table, "label", "value", title=tag) + + stats["hist"] = [create_bar(**params) for params in storage._histograms] + + storage.clear_histograms() + + if len(stats) == 0: + return + wandb.log(stats, step=storage.iter) + + def close(self): + wandb.finish() diff --git a/open_vocab_seg/utils/misc.py b/open_vocab_seg/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..a22d0a978c9cd89595c6e7c900885e1c148844b1 --- /dev/null +++ b/open_vocab_seg/utils/misc.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +from typing import List, Optional + +import torch +import torch.distributed as dist +import torchvision +from torch import Tensor + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
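+# Illustrative example of the eager-mode path above (input sizes are made up):
+#   imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
+#   nt = nested_tensor_from_tensor_list(imgs)
+#   nt.tensors has shape (2, 3, 512, 640); nt.mask is True exactly on padded pixels.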
+@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max( + torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) + ).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad( + img, (0, padding[2], 0, padding[1], 0, padding[0]) + ) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad( + m, (0, padding[2], 0, padding[1]), "constant", 1 + ) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + +def get_gt_binary_masks(gt_semseg): + mask_ids = torch.unique(gt_semseg) + gt_masks = [] + for id in mask_ids: + if id != 255: + gt_masks.append(gt_semseg == id) + gt_masks = torch.stack(gt_masks).float() + return gt_masks diff --git a/open_vocab_seg/utils/post_process_utils.py b/open_vocab_seg/utils/post_process_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ed214319d90ceba0b47ef835072102b9ffec5179 --- /dev/null +++ b/open_vocab_seg/utils/post_process_utils.py @@ -0,0 +1,74 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import torch +from torch.nn import functional as F +import numpy as np + +try: + import pydensecrf.densecrf as dcrf + from pydensecrf.utils import ( + unary_from_softmax, + unary_from_labels, + create_pairwise_bilateral, + create_pairwise_gaussian, + ) +except: + dcrf = None + + +def dense_crf_post_process( + logits, + image, + n_labels=None, + max_iters=5, + pos_xy_std=(3, 3), + pos_w=3, + bi_xy_std=(80, 80), + bi_rgb_std=(13, 13, 13), + bi_w=10, +): + """ + logits : [C,H,W] + image : [3,H,W] + """ + if dcrf is None: + raise FileNotFoundError( + "pydensecrf is required to perform dense crf inference." + ) + if isinstance(logits, torch.Tensor): + logits = F.softmax(logits, dim=0).detach().cpu().numpy() + U = unary_from_softmax(logits) + n_labels = logits.shape[0] + elif logits.ndim == 3: + U = unary_from_softmax(logits) + n_labels = logits.shape[0] + else: + assert n_labels is not None + U = unary_from_labels(logits, n_labels, zero_unsure=False) + + d = dcrf.DenseCRF2D(image.shape[1], image.shape[0], n_labels) + + d.setUnaryEnergy(U) + + # This adds the color-independent term, features are the locations only. + d.addPairwiseGaussian( + sxy=pos_xy_std, + compat=pos_w, + kernel=dcrf.DIAG_KERNEL, + normalization=dcrf.NORMALIZE_SYMMETRIC, + ) + + # This adds the color-dependent term, i.e. features are (x,y,r,g,b). + d.addPairwiseBilateral( + sxy=bi_xy_std, + srgb=bi_rgb_std, + rgbim=image, + compat=bi_w, + kernel=dcrf.DIAG_KERNEL, + normalization=dcrf.NORMALIZE_SYMMETRIC, + ) + # Run five inference steps. 
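+    # (i.e. run `max_iters` mean-field iterations; the default above is 5)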
+ logits = d.inference(max_iters) + logits = np.asarray(logits).reshape((n_labels, image.shape[0], image.shape[1])) + return torch.from_numpy(logits) diff --git a/open_vocab_seg/utils/predictor.py b/open_vocab_seg/utils/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6cd8f86af6995768a27849f37baab82d6486cf --- /dev/null +++ b/open_vocab_seg/utils/predictor.py @@ -0,0 +1,232 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import numpy as np +import torch +from torch.nn import functional as F +import cv2 + +from detectron2.data import MetadataCatalog +from detectron2.structures import BitMasks +from detectron2.engine.defaults import DefaultPredictor +from detectron2.utils.visualizer import ColorMode, Visualizer +from detectron2.modeling.postprocessing import sem_seg_postprocess + +import open_clip +from segment_anything import SamAutomaticMaskGenerator, sam_model_registry +from open_vocab_seg.modeling.clip_adapter.adapter import PIXEL_MEAN, PIXEL_STD +from open_vocab_seg.modeling.clip_adapter.utils import crop_with_mask + +class OVSegPredictor(DefaultPredictor): + def __init__(self, cfg): + super().__init__(cfg) + + def __call__(self, original_image, class_names): + """ + Args: + original_image (np.ndarray): an image of shape (H, W, C) (in BGR order). + + Returns: + predictions (dict): + the output of the model for one image only. + See :doc:`/tutorials/models` for details about the format. + """ + with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 + # Apply pre-processing to image. + if self.input_format == "RGB": + # whether the model expects BGR inputs or RGB + original_image = original_image[:, :, ::-1] + height, width = original_image.shape[:2] + image = self.aug.get_transform(original_image).apply_image(original_image) + image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) + + inputs = {"image": image, "height": height, "width": width, "class_names": class_names} + predictions = self.model([inputs])[0] + return predictions + +class OVSegVisualizer(Visualizer): + def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, class_names=None): + super().__init__(img_rgb, metadata, scale, instance_mode) + self.class_names = class_names + + def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8): + """ + Draw semantic segmentation predictions/labels. + + Args: + sem_seg (Tensor or ndarray): the segmentation of shape (H, W). + Each value is the integer label of the pixel. + area_threshold (int): segments with less than `area_threshold` are not drawn. + alpha (float): the larger it is, the more opaque the segmentations are. + + Returns: + output (VisImage): image object with visualizations. 
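+        Note: label indices are mapped to names through `self.class_names` when it is
+        provided, otherwise the metadata's `stuff_classes` list is used.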
+ """ + if isinstance(sem_seg, torch.Tensor): + sem_seg = sem_seg.numpy() + labels, areas = np.unique(sem_seg, return_counts=True) + sorted_idxs = np.argsort(-areas).tolist() + labels = labels[sorted_idxs] + class_names = self.class_names if self.class_names is not None else self.metadata.stuff_classes + + for label in filter(lambda l: l < len(class_names), labels): + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[label]] + except (AttributeError, IndexError): + mask_color = None + + binary_mask = (sem_seg == label).astype(np.uint8) + text = class_names[label] + self.draw_binary_mask( + binary_mask, + color=mask_color, + edge_color=(1.0, 1.0, 240.0 / 255), + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + return self.output + + + +class VisualizationDemo(object): + def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): + """ + Args: + cfg (CfgNode): + instance_mode (ColorMode): + parallel (bool): whether to run the model in different processes from visualization. + Useful since the visualization logic can be slow. + """ + self.metadata = MetadataCatalog.get( + cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" + ) + + self.cpu_device = torch.device("cpu") + self.instance_mode = instance_mode + + self.parallel = parallel + if parallel: + raise NotImplementedError + else: + self.predictor = OVSegPredictor(cfg) + + def run_on_image(self, image, class_names): + """ + Args: + image (np.ndarray): an image of shape (H, W, C) (in BGR order). + This is the format used by OpenCV. + Returns: + predictions (dict): the output of the model. + vis_output (VisImage): the visualized image output. + """ + predictions = self.predictor(image, class_names) + # Convert image from OpenCV BGR format to Matplotlib RGB format. 
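+        # predictions["sem_seg"] is a (num_classes, H, W) score map; below it is argmax-ed
+        # over the class dimension to get per-pixel labels, and pixels whose first-class
+        # score is exactly zero are marked 255 (ignore) before drawing.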
+ image = image[:, :, ::-1] + visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names) + if "sem_seg" in predictions: + r = predictions["sem_seg"] + blank_area = (r[0] == 0) + pred_mask = r.argmax(dim=0).to('cpu') + pred_mask[blank_area] = 255 + pred_mask = np.array(pred_mask, dtype=np.int) + + vis_output = visualizer.draw_sem_seg( + pred_mask + ) + else: + raise NotImplementedError + + return predictions, vis_output + +class SAMVisualizationDemo(object): + def __init__(self, cfg, granularity, sam_path, ovsegclip_path, instance_mode=ColorMode.IMAGE, parallel=False): + self.metadata = MetadataCatalog.get( + cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" + ) + + self.cpu_device = torch.device("cpu") + self.instance_mode = instance_mode + + self.parallel = parallel + self.granularity = granularity + sam = sam_model_registry["vit_l"](checkpoint=sam_path).cuda() + self.predictor = SamAutomaticMaskGenerator(sam, points_per_batch=16) + self.clip_model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=ovsegclip_path) + self.clip_model.cuda() + + def run_on_image(self, image, class_names): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names) + with torch.no_grad(), torch.cuda.amp.autocast(): + masks = self.predictor.generate(image) + pred_masks = [masks[i]['segmentation'][None,:,:] for i in range(len(masks))] + pred_masks = np.row_stack(pred_masks) + pred_masks = BitMasks(pred_masks) + bboxes = pred_masks.get_bounding_boxes() + + mask_fill = [255.0 * c for c in PIXEL_MEAN] + + image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) + + regions = [] + for bbox, mask in zip(bboxes, pred_masks): + region, _ = crop_with_mask( + image, + mask, + bbox, + fill=mask_fill, + ) + regions.append(region.unsqueeze(0)) + regions = [F.interpolate(r.to(torch.float), size=(224, 224), mode="bicubic") for r in regions] + + pixel_mean = torch.tensor(PIXEL_MEAN).reshape(1, -1, 1, 1) + pixel_std = torch.tensor(PIXEL_STD).reshape(1, -1, 1, 1) + imgs = [(r/255.0 - pixel_mean) / pixel_std for r in regions] + imgs = torch.cat(imgs) + if len(class_names) == 1: + class_names.append('others') + txts = [f'a photo of {cls_name}' for cls_name in class_names] + text = open_clip.tokenize(txts) + + img_batches = torch.split(imgs, 32, dim=0) + + with torch.no_grad(), torch.cuda.amp.autocast(): + text_features = self.clip_model.encode_text(text.cuda()) + text_features /= text_features.norm(dim=-1, keepdim=True) + image_features = [] + for img_batch in img_batches: + image_feat = self.clip_model.encode_image(img_batch.cuda().half()) + image_feat /= image_feat.norm(dim=-1, keepdim=True) + image_features.append(image_feat.detach()) + image_features = torch.cat(image_features, dim=0) + class_preds = (100.0 * image_features @ text_features.T).softmax(dim=-1) + select_cls = torch.zeros_like(class_preds) + + max_scores, select_mask = torch.max(class_preds, dim=0) + if len(class_names) == 2 and class_names[-1] == 'others': + select_mask = select_mask[:-1] + if self.granularity < 1: + thr_scores = max_scores * self.granularity + select_mask = [] + if len(class_names) == 2 and class_names[-1] == 'others': + thr_scores = thr_scores[:-1] + for i, thr in enumerate(thr_scores): + cls_pred = class_preds[:,i] + locs = torch.where(cls_pred > thr) + select_mask.extend(locs[0].tolist()) + for idx in select_mask: + select_cls[idx] = class_preds[idx] 
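+            # Each selected SAM proposal contributes its class-score vector to every pixel
+            # it covers, producing a (num_classes, H, W) semantic map; proposals that were
+            # not selected keep all-zero rows in `select_cls` and thus contribute nothing.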
+ semseg = torch.einsum("qc,qhw->chw", select_cls.float(), pred_masks.tensor.float().cuda()) + + r = semseg + blank_area = (r[0] == 0) + pred_mask = r.argmax(dim=0).to('cpu') + pred_mask[blank_area] = 255 + pred_mask = np.array(pred_mask, dtype=np.int) + + vis_output = visualizer.draw_sem_seg( + pred_mask + ) + + return None, vis_output \ No newline at end of file diff --git a/ovseg_clip_l_9a1909.pth b/ovseg_clip_l_9a1909.pth new file mode 100644 index 0000000000000000000000000000000000000000..f88dad1269412ee0c449597330d31dbb6d3e1042 --- /dev/null +++ b/ovseg_clip_l_9a1909.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb5cbc83b922e18241654a19ad4cb836cf4f00169cd5684a4932d8a6f825dd36 +size 1710616901 diff --git a/ovseg_swinB_vitL_demo.yaml b/ovseg_swinB_vitL_demo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aaeb0e2cc02a9a7bee7c02f6cfb313add8da9794 --- /dev/null +++ b/ovseg_swinB_vitL_demo.yaml @@ -0,0 +1,99 @@ +MODEL: + META_ARCHITECTURE: "OVSegDEMO" + BACKBONE: + FREEZE_AT: 0 + NAME: "D2SwinTransformer" + SWIN: + EMBED_DIM: 128 + DEPTHS: [2, 2, 18, 2] + NUM_HEADS: [4, 8, 16, 32] + WINDOW_SIZE: 12 + APE: False + DROP_PATH_RATE: 0.3 + PATCH_NORM: True + PRETRAIN_IMG_SIZE: 384 + WEIGHTS: "./ovseg_swinbase_vitL14_ft_mpt.pth" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + SEM_SEG_HEAD: + NAME: "OpenVocabMaskFormerHead" + IN_FEATURES: ["res2", "res3", "res4", "res5"] + IGNORE_VALUE: 255 + NUM_CLASSES: 171 # number of categories in training set + EMBEDDING_DIM: 768 + EMBED_LAYERS: 2 + COMMON_STRIDE: 4 # not used, hard-coded + LOSS_WEIGHT: 1.0 + CONVS_DIM: 256 + MASK_DIM: 256 + NORM: "GN" + MASK_FORMER: + TRANSFORMER_IN_FEATURE: "res5" + DEEP_SUPERVISION: True + NO_OBJECT_WEIGHT: 0.1 + DICE_WEIGHT: 1.0 + MASK_WEIGHT: 20.0 + HIDDEN_DIM: 256 + NUM_OBJECT_QUERIES: 100 + NHEADS: 8 + DROPOUT: 0.1 + DIM_FEEDFORWARD: 2048 + ENC_LAYERS: 0 + DEC_LAYERS: 6 + PRE_NORM: False + CLIP_ADAPTER: + TEXT_TEMPLATES: "vild" + CLIP_MODEL_NAME: "ViT-L/14" + MASK_FILL: "mean" + MASK_EXPAND_RATIO: 1.0 + MASK_THR: 0.35 # choose the foreground objects + MASK_MATTING: False # use soft background, default not used + MASK_PROMPT_DEPTH: 3 + MASK_PROMPT_FWD: True # use mask prompt during forward + REGION_RESIZED: True # resize to the input of clip, e.g., 224 + CLIP_ENSEMBLE: True # use ensemble of two classification branches + CLIP_ENSEMBLE_WEIGHT: 0.0 +DATASETS: + TRAIN: ("coco_2017_train_stuff_sem_seg",) + TEST: ("ade20k_sem_seg_val",) +SOLVER: + IMS_PER_BATCH: 32 + BASE_LR: 0.00006 + MAX_ITER: 120000 + WARMUP_FACTOR: 1e-6 + WARMUP_ITERS: 1500 + WEIGHT_DECAY: 0.01 + WEIGHT_DECAY_NORM: 0.0 + WEIGHT_DECAY_EMBED: 0.0 + BACKBONE_MULTIPLIER: 1.0 + TEST_IMS_PER_BATCH: 1 + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "full_model" + CLIP_VALUE: 0.01 + NORM_TYPE: 2.0 +INPUT: + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"] + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 640 + MAX_SIZE_TRAIN: 2560 + MAX_SIZE_TEST: 2560 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (640, 640) + SINGLE_CATEGORY_MAX_AREA: 1.0 + COLOR_AUG_SSD: True + SIZE_DIVISIBILITY: 640 # used in dataset mapper + FORMAT: "RGB" +TEST: + EVAL_PERIOD: 5000 + AUG: + ENABLED: False + MIN_SIZES: [256, 384, 512, 640, 768, 896] + MAX_SIZE: 3584 + FLIP: True +DATALOADER: + FILTER_EMPTY_ANNOTATIONS: True + NUM_WORKERS: 4 +VERSION: 2 \ No newline at end of file diff --git a/ovseg_swinbase_vitL14_ft_mpt.pth 
b/ovseg_swinbase_vitL14_ft_mpt.pth new file mode 100644 index 0000000000000000000000000000000000000000..0d2dcc4c4e721b187574f4c3829c58236713037a --- /dev/null +++ b/ovseg_swinbase_vitL14_ft_mpt.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd3731dde48d96654aba63e5a93753dc837d6889162a18ddf0877f5463d94c90 +size 2129343629 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e60263206e63eced27619991b26cd842f6bc5649 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,28 @@ +numpy>=1.18.5,<1.24.0 +cython +scipy +shapely +timm +h5py +wandb +fire +opencv-python +pandas +ftfy +regex +tqdm +gdown +# Torch +--find-links https://download.pytorch.org/whl/cu113/torch_stable.html +torch==1.10.1+cu113 +torchvision==0.11.2+cu113 + +# Detectron +--find-links https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html +detectron2 + +# Segment-anything +git+https://github.com/facebookresearch/segment-anything.git + +# open_clip +open_clip_torch==1.3.0 diff --git a/resources/demo_samples/sample_01.jpeg b/resources/demo_samples/sample_01.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..b3e0245ab25117d340bf23059af6b5dcbfc8a811 --- /dev/null +++ b/resources/demo_samples/sample_01.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:154943906b5ed394b620da62124c4421dfa96f858f014839eb346678aaa71fc3 +size 4323630 diff --git a/resources/demo_samples/sample_02.jpeg b/resources/demo_samples/sample_02.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..e5489289e1128c33060f1d30a353452907f0a1d8 --- /dev/null +++ b/resources/demo_samples/sample_02.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:591c2bf26a843a62881d89dbd7f4e9a6f90dda9fb8786c9b6e5172a28623d1b0 +size 1840881 diff --git a/resources/demo_samples/sample_03.jpeg b/resources/demo_samples/sample_03.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..d34147db0c2888b6c77e92202efd8acb1e8d0f36 --- /dev/null +++ b/resources/demo_samples/sample_03.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33e5c7054300d5cf1871a33972416504c05dbddff238b32ff884525bcbfca695 +size 7324740 diff --git a/resources/demo_samples/sample_04.png b/resources/demo_samples/sample_04.png new file mode 100644 index 0000000000000000000000000000000000000000..fda5012edd98754db1e8a3cfbfbea287a41a9299 --- /dev/null +++ b/resources/demo_samples/sample_04.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:072e7e119d437c0c1fa731c40c737780b7725bd49041ef7466153e1ee7045920 +size 7903759 diff --git a/resources/demo_samples/sample_05.png b/resources/demo_samples/sample_05.png new file mode 100644 index 0000000000000000000000000000000000000000..18bd12d1862f474cf0c9fc26edd11f55e83f34d8 --- /dev/null +++ b/resources/demo_samples/sample_05.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1331dfcef69066c225d34c659f756a92ce3dc71965978db67814eda36b1cdc5f +size 2645089 diff --git a/resources/ovseg.gif b/resources/ovseg.gif new file mode 100644 index 0000000000000000000000000000000000000000..9d77dbd81f06dc65346cfd2d7a1b4742ff0597f8 --- /dev/null +++ b/resources/ovseg.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:828e23424f7b494c6dad079d52551a4bb9a4cfb292dcec0acd376b89c5944128 +size 3789733 diff --git a/resources/proposal.png b/resources/proposal.png new file mode 100644 index 
0000000000000000000000000000000000000000..4ebf9a6ae0163ad1b733fe9cd15537ea7a016c72 --- /dev/null +++ b/resources/proposal.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8cacaabbc9dd7b7b5a7f974975128a2ea604759606dacc45b90a3d67b18d8e8 +size 194338 diff --git a/resources/pytorch-logo-dark.png b/resources/pytorch-logo-dark.png new file mode 100644 index 0000000000000000000000000000000000000000..8cef4518ca3ea7cda3b046082a6035541e0f07fd --- /dev/null +++ b/resources/pytorch-logo-dark.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8211f1b771de99ae379db83350327139a597c0f99d6b6312e81e977d4d413c44 +size 15625 diff --git a/sam_vit_h_4b8939.pth b/sam_vit_h_4b8939.pth new file mode 100644 index 0000000000000000000000000000000000000000..8523acce9ddab1cf7e355628a08b1aab8ce08a72 --- /dev/null +++ b/sam_vit_h_4b8939.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e +size 2564550879 diff --git a/sam_vit_l_0b3195.pth b/sam_vit_l_0b3195.pth new file mode 100644 index 0000000000000000000000000000000000000000..87a638d6b789dd2b10fc7414a88dacc34a50769a --- /dev/null +++ b/sam_vit_l_0b3195.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622 +size 1249524607