diff --git a/INSTALL.md b/INSTALL.md
new file mode 100644
index 0000000000000000000000000000000000000000..684c21171f6fc40b5febd995d45604643374c540
--- /dev/null
+++ b/INSTALL.md
@@ -0,0 +1,20 @@
+## Installation
+
+### Requirements
+- Linux or macOS with Python ≥ 3.6
+- PyTorch ≥ 1.7 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
+ Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Note, please check
+ PyTorch version matches that is required by Detectron2.
+- Detectron2: follow [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html).
+- OpenCV is optional but needed by demo and visualization
+- `pip install -r requirements.txt`
+
+An example of installation is shown below:
+
+```
+git clone https://github.com/~~~/CAT-Seg.git
+cd CAT-Seg
+conda create -n catseg python=3.8
+conda activate catseg
+pip install -r requirements.txt
+```
\ No newline at end of file
diff --git a/R-101.pkl b/R-101.pkl
new file mode 100755
index 0000000000000000000000000000000000000000..0df951db75f8313ee09a474d13bda301b3d8d25c
--- /dev/null
+++ b/R-101.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1156c77bff95ecb027060b5c83391b45bf159acd7f5bf7eacb656be0c1f0ab55
+size 178666803
diff --git a/README.md b/README.md
index 77feaeb4c36fe78354e1df55a2f113a829424985..d66bb3afddea8db8c0a527d06a3fac9e60f19f5f 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,48 @@
----
-title: SAM CAT Seg
-emoji: 📚
-colorFrom: purple
-colorTo: blue
-sdk: gradio
-sdk_version: 3.29.0
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# CAT-Seg🐱: Cost Aggregation for Open-Vocabulary Semantic Segmentation
+
+This is our official implementation of CAT-Seg🐱!
+
+[[arXiv](#)] [[Project](#)]
+by [Seokju Cho](https://seokju-cho.github.io/)\*, [Heeseong Shin](https://github.com/hsshin98)\*, [Sunghwan Hong](https://sunghwanhong.github.io), Seungjun An, Seungjun Lee, [Anurag Arnab](https://anuragarnab.github.io), [Paul Hongsuck Seo](https://phseo.github.io), [Seungryong Kim](https://cvlab.korea.ac.kr)
+
+
+## Introduction
+![](assets/fig1.png)
+We introduce cost aggregation to open-vocabulary semantic segmentation, which jointly aggregates both image and text modalities within the matching cost.
+
+## Installation
+Install required packages.
+
+```bash
+conda create --name catseg python=3.8
+conda activate catseg
+conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
+pip install -r requirements.txt
+```
+
+## Data Preparation
+
+
+## Training
+### Preparation
+you have to blah
+### Training script
+```bash
+python train.py --config configs/eval/{a847 | pc459 | a150 | pc59 | pas20 | pas20b}.yaml
+```
+
+## Evaluation
+```bash
+python eval.py --config configs/eval/{a847 | pc459 | a150 | pc59 | pas20 | pas20b}.yaml
+```
+
+## Citing CAT-Seg🐱 :pray:
+
+```BibTeX
+@article{liang2022open,
+ title={Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP},
+ author={Liang, Feng and Wu, Bichen and Dai, Xiaoliang and Li, Kunpeng and Zhao, Yinan and Zhang, Hang and Zhang, Peizhao and Vajda, Peter and Marculescu, Diana},
+ journal={arXiv preprint arXiv:2210.04150},
+ year={2022}
+}
+```
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b6b5db9996bcd9c4b6d0ac502082b6786cc9cc2
--- /dev/null
+++ b/app.py
@@ -0,0 +1,130 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
+import argparse
+import glob
+import multiprocessing as mp
+import os
+#os.environ["CUDA_VISIBLE_DEVICES"] = ""
+try:
+ import detectron2
+except ModuleNotFoundError:
+ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
+
+try:
+ import segment_anything
+except ModuleNotFoundError:
+ os.system('pip install git+https://github.com/facebookresearch/segment-anything.git')
+
+# fmt: off
+import sys
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+# fmt: on
+
+import tempfile
+import time
+import warnings
+
+import cv2
+import numpy as np
+import tqdm
+
+from detectron2.config import get_cfg
+from detectron2.data.detection_utils import read_image
+from detectron2.projects.deeplab import add_deeplab_config
+from detectron2.utils.logger import setup_logger
+
+from cat_seg import add_cat_seg_config
+from demo.predictor import VisualizationDemo
+import gradio as gr
+import torch
+from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
+
+# constants
+WINDOW_NAME = "MaskFormer demo"
+
+
+def setup_cfg(args):
+ # load config from file and command-line arguments
+ cfg = get_cfg()
+ add_deeplab_config(cfg)
+ add_cat_seg_config(cfg)
+ cfg.merge_from_file(args.config_file)
+ cfg.merge_from_list(args.opts)
+ if torch.cuda.is_available():
+ cfg.MODEL.DEVICE = "cuda"
+ cfg.freeze()
+ return cfg
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
+ parser.add_argument(
+ "--config-file",
+ default="configs/vitl_swinb_384.yaml",
+ metavar="FILE",
+ help="path to config file",
+ )
+ parser.add_argument(
+ "--input",
+ nargs="+",
+ help="A list of space separated input images; "
+ "or a single glob pattern such as 'directory/*.jpg'",
+ )
+ parser.add_argument(
+ "--opts",
+ help="Modify config options using the command-line 'KEY VALUE' pairs",
+ default=(
+ [
+ "MODEL.WEIGHTS", "model_final_cls.pth",
+ "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
+ "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
+ "TEST.SLIDING_WINDOW", "True",
+ "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
+ "MODEL.PROMPT_ENSEMBLE_TYPE", "single",
+ "MODEL.DEVICE", "cpu",
+ ]),
+ nargs=argparse.REMAINDER,
+ )
+ return parser
+
+def save_masks(preds, text):
+ preds = preds['sem_seg'].argmax(dim=0).cpu().numpy() # C H W
+ for i, t in enumerate(text):
+ dir = f"mask_{t}.png"
+ mask = preds == i
+ cv2.imwrite(dir, mask * 255)
+
+def predict(image, text, model_type):
+ #import pdb; pdb.set_trace()
+ #use_sam = True #
+ use_sam = model_type != "CAT-Seg"
+
+ predictions, visualized_output = demo.run_on_image(image, text, use_sam)
+ #save_masks(predictions, text.split(','))
+ canvas = fc(visualized_output.fig)
+ canvas.draw()
+ out = np.frombuffer(canvas.tostring_rgb(), dtype='uint8').reshape(canvas.get_width_height()[::-1] + (3,))
+
+ return out[..., ::-1]
+
+if __name__ == "__main__":
+ args = get_parser().parse_args()
+ cfg = setup_cfg(args)
+ global demo
+ demo = VisualizationDemo(cfg)
+
+ iface = gr.Interface(
+ fn=predict,
+ inputs=[gr.Image(), gr.Textbox(placeholder='background, cat, person'), ], #gr.Radio(["CAT-Seg", "Segment Anycat"], value="CAT-Seg")],
+ outputs="image",
+ description="""## Segment Anything with CAT-Seg!
+Welcome to the Segment Anything with CAT-Seg!
+
+In this demo, we combine state-of-the-art open-vocabulary semantic segmentation model, CAT-Seg with SAM(Segment Anything) for semantically labelling mask predictions from SAM.
+
+Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model.
+
+Also, the demo might run on a CPU depending on the demand, so it may take a little time to process your image.
+
+To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!""")
+ iface.launch()
diff --git a/assets/fig1.png b/assets/fig1.png
new file mode 100644
index 0000000000000000000000000000000000000000..0b0be9bf81880bf8404a6ebd017058c0de58636c
Binary files /dev/null and b/assets/fig1.png differ
diff --git a/cat_seg/__init__.py b/cat_seg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e095a29ff5b655d58af6ac7ef920d4089f465f6
--- /dev/null
+++ b/cat_seg/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from . import data # register all new datasets
+from . import modeling
+
+# config
+from .config import add_cat_seg_config
+
+# dataset loading
+from .data.dataset_mappers.detr_panoptic_dataset_mapper import DETRPanopticDatasetMapper
+from .data.dataset_mappers.mask_former_panoptic_dataset_mapper import (
+ MaskFormerPanopticDatasetMapper,
+)
+from .data.dataset_mappers.mask_former_semantic_dataset_mapper import (
+ MaskFormerSemanticDatasetMapper,
+)
+
+# models
+from .cat_seg_model import CATSeg
+from .test_time_augmentation import SemanticSegmentorWithTTA
\ No newline at end of file
diff --git a/cat_seg/__pycache__/__init__.cpython-38.pyc b/cat_seg/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0512f914d218655af766ee8a3d2e987c3a9fd0f
Binary files /dev/null and b/cat_seg/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cat_seg/__pycache__/cat_sam_model.cpython-38.pyc b/cat_seg/__pycache__/cat_sam_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7597263e4a0206e7661c98608b131c870e6eff29
Binary files /dev/null and b/cat_seg/__pycache__/cat_sam_model.cpython-38.pyc differ
diff --git a/cat_seg/__pycache__/cat_seg_model.cpython-38.pyc b/cat_seg/__pycache__/cat_seg_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15658e9ec1c5cb56fbd5de9697171c6d121c698e
Binary files /dev/null and b/cat_seg/__pycache__/cat_seg_model.cpython-38.pyc differ
diff --git a/cat_seg/__pycache__/cat_seg_panoptic.cpython-38.pyc b/cat_seg/__pycache__/cat_seg_panoptic.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d377feeda92cc8513e7b7d83cbdfe55f044e4779
Binary files /dev/null and b/cat_seg/__pycache__/cat_seg_panoptic.cpython-38.pyc differ
diff --git a/cat_seg/__pycache__/config.cpython-38.pyc b/cat_seg/__pycache__/config.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1834fef7d3edc0d927ddd6c9cb5beb7fa77dd931
Binary files /dev/null and b/cat_seg/__pycache__/config.cpython-38.pyc differ
diff --git a/cat_seg/__pycache__/pancat_model.cpython-38.pyc b/cat_seg/__pycache__/pancat_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68a8ff2ed5be529d516e3505351773178baf9f60
Binary files /dev/null and b/cat_seg/__pycache__/pancat_model.cpython-38.pyc differ
diff --git a/cat_seg/__pycache__/test_time_augmentation.cpython-38.pyc b/cat_seg/__pycache__/test_time_augmentation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e576a3fae904d38494a8919a0d47ee94151310aa
Binary files /dev/null and b/cat_seg/__pycache__/test_time_augmentation.cpython-38.pyc differ
diff --git a/cat_seg/cat_seg_model.py b/cat_seg/cat_seg_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..38bb78723c97dc38b3a597e3b344b37b3a36753f
--- /dev/null
+++ b/cat_seg/cat_seg_model.py
@@ -0,0 +1,386 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.data import MetadataCatalog
+from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
+from detectron2.modeling.backbone import Backbone
+from detectron2.modeling.postprocessing import sem_seg_postprocess
+from detectron2.structures import ImageList
+from detectron2.utils.memory import _ignore_torch_cuda_oom
+
+import numpy as np
+from einops import rearrange
+from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
+
+@META_ARCH_REGISTRY.register()
+class CATSeg(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ *,
+ backbone: Backbone,
+ sem_seg_head: nn.Module,
+ size_divisibility: int,
+ pixel_mean: Tuple[float],
+ pixel_std: Tuple[float],
+ clip_pixel_mean: Tuple[float],
+ clip_pixel_std: Tuple[float],
+ train_class_json: str,
+ test_class_json: str,
+ sliding_window: bool,
+ clip_finetune: str,
+ backbone_multiplier: float,
+ clip_pretrained: str,
+ ):
+ """
+ Args:
+ backbone: a backbone module, must follow detectron2's backbone interface
+ sem_seg_head: a module that predicts semantic segmentation from backbone features
+ """
+ super().__init__()
+ self.backbone = backbone
+ self.sem_seg_head = sem_seg_head
+ if size_divisibility < 0:
+ size_divisibility = self.backbone.size_divisibility
+ self.size_divisibility = size_divisibility
+
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+ self.register_buffer("clip_pixel_mean", torch.Tensor(clip_pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("clip_pixel_std", torch.Tensor(clip_pixel_std).view(-1, 1, 1), False)
+
+ self.train_class_json = train_class_json
+ self.test_class_json = test_class_json
+
+ self.clip_finetune = clip_finetune
+ for name, params in self.sem_seg_head.predictor.clip_model.named_parameters():
+ if "visual" in name:
+ if clip_finetune == "prompt":
+ params.requires_grad = True if "prompt" in name else False
+ elif clip_finetune == "attention":
+ params.requires_grad = True if "attn" in name or "position" in name else False
+ elif clip_finetune == "full":
+ params.requires_grad = True
+ else:
+ params.requires_grad = False
+ else:
+ params.requires_grad = False
+
+ finetune_backbone = backbone_multiplier > 0.
+ for name, params in self.backbone.named_parameters():
+ if "norm0" in name:
+ params.requires_grad = False
+ else:
+ params.requires_grad = finetune_backbone
+
+ self.sliding_window = sliding_window
+ self.clip_resolution = (384, 384) if clip_pretrained == "ViT-B/16" else (336, 336)
+ self.sequential = False
+
+ self.use_sam = False
+ self.sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(self.device)
+
+ amg_kwargs = {
+ "points_per_side": 32,
+ "points_per_batch": None,
+ #"pred_iou_thresh": 0.0,
+ #"stability_score_thresh": 0.0,
+ "stability_score_offset": None,
+ "box_nms_thresh": None,
+ "crop_n_layers": None,
+ "crop_nms_thresh": None,
+ "crop_overlap_ratio": None,
+ "crop_n_points_downscale_factor": None,
+ "min_mask_region_area": None,
+ }
+ amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
+ self.mask = SamAutomaticMaskGenerator(self.sam, output_mode="binary_mask", **amg_kwargs)
+ self.overlap_threshold = 0.8
+ self.panoptic_on = False
+
+ @classmethod
+ def from_config(cls, cfg):
+ backbone = build_backbone(cfg)
+ sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
+
+ return {
+ "backbone": backbone,
+ "sem_seg_head": sem_seg_head,
+ "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
+ "pixel_mean": cfg.MODEL.PIXEL_MEAN,
+ "pixel_std": cfg.MODEL.PIXEL_STD,
+ "clip_pixel_mean": cfg.MODEL.CLIP_PIXEL_MEAN,
+ "clip_pixel_std": cfg.MODEL.CLIP_PIXEL_STD,
+ "train_class_json": cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON,
+ "test_class_json": cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON,
+ "sliding_window": cfg.TEST.SLIDING_WINDOW,
+ "clip_finetune": cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE,
+ "backbone_multiplier": cfg.SOLVER.BACKBONE_MULTIPLIER,
+ "clip_pretrained": cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED,
+ }
+
+ @property
+ def device(self):
+ return self.pixel_mean.device
+
+ def forward(self, batched_inputs):
+ """
+ Args:
+ batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
+ Each item in the list contains the inputs for one image.
+ For now, each item in the list is a dict that contains:
+ * "image": Tensor, image in (C, H, W) format.
+ * "instances": per-region ground truth
+ * Other information that's included in the original dicts, such as:
+ "height", "width" (int): the output resolution of the model (may be different
+ from input resolution), used in inference.
+ Returns:
+ list[dict]:
+ each dict has the results for one image. The dict contains the following keys:
+
+ * "sem_seg":
+ A Tensor that represents the
+ per-pixel segmentation prediced by the head.
+ The prediction has shape KxHxW that represents the logits of
+ each class for each pixel.
+ """
+ images = [x["image"].to(self.device) for x in batched_inputs]
+ sam_images = images
+ if not self.training and self.sliding_window:
+ if not self.sequential:
+ with _ignore_torch_cuda_oom():
+ return self.inference_sliding_window(batched_inputs)
+ self.sequential = True
+ return self.inference_sliding_window(batched_inputs)
+
+ clip_images = [(x - self.clip_pixel_mean) / self.clip_pixel_std for x in images]
+ clip_images = ImageList.from_tensors(clip_images, self.size_divisibility)
+
+ images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+ images = ImageList.from_tensors(images, self.size_divisibility)
+
+ clip_images = F.interpolate(clip_images.tensor, size=self.clip_resolution, mode='bilinear', align_corners=False, )
+ clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)
+
+ images_resized = F.interpolate(images.tensor, size=(384, 384), mode='bilinear', align_corners=False,)
+ features = self.backbone(images_resized)
+
+ outputs = self.sem_seg_head(clip_features, features)
+
+ if self.training:
+ targets = torch.stack([x["sem_seg"].to(self.device) for x in batched_inputs], dim=0)
+ outputs = F.interpolate(outputs, size=(targets.shape[-2], targets.shape[-1]), mode="bilinear", align_corners=False)
+
+ num_classes = outputs.shape[1]
+ mask = targets != self.sem_seg_head.ignore_value
+
+ outputs = outputs.permute(0,2,3,1)
+ _targets = torch.zeros(outputs.shape, device=self.device)
+ _onehot = F.one_hot(targets[mask], num_classes=num_classes).float()
+ _targets[mask] = _onehot
+
+ loss = F.binary_cross_entropy_with_logits(outputs, _targets)
+ losses = {"loss_sem_seg" : loss}
+ return losses
+ else:
+ #outputs = outputs.sigmoid()
+ image_size = images.image_sizes[0]
+ if self.use_sam:
+ masks = self.mask.generate(np.uint8(sam_images[0].permute(1, 2, 0).cpu().numpy()))
+ outputs, sam_cls = self.discrete_semantic_inference(outputs, masks, image_size)
+ #outputs, sam_cls = self.continuous_semantic_inference(outputs, masks, image_size)
+ #outputs, sam_cls = self.continuous_semantic_inference2(outputs, masks, image_size, img=img, text=text)
+ height = batched_inputs[0].get("height", image_size[0])
+ width = batched_inputs[0].get("width", image_size[1])
+
+ output = sem_seg_postprocess(outputs[0], image_size, height, width)
+ processed_results = [{'sem_seg': output}]
+ return processed_results
+
+
+ @torch.no_grad()
+ def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333, out_res=[640, 640]):
+
+ images = [x["image"].to(self.device, dtype=torch.float32) for x in batched_inputs]
+ stride = int(kernel * (1 - overlap))
+ unfold = nn.Unfold(kernel_size=kernel, stride=stride)
+ fold = nn.Fold(out_res, kernel_size=kernel, stride=stride)
+
+ image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode='bilinear', align_corners=False).squeeze()
+ sam_images = [image]
+ image = rearrange(unfold(image), "(C H W) L-> L C H W", C=3, H=kernel)
+ global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode='bilinear', align_corners=False)
+ image = torch.cat((image, global_image), dim=0)
+
+ images = (image - self.pixel_mean) / self.pixel_std
+ clip_images = (image - self.clip_pixel_mean) / self.clip_pixel_std
+ clip_images = F.interpolate(clip_images, size=self.clip_resolution, mode='bilinear', align_corners=False, )
+ clip_features = self.sem_seg_head.predictor.clip_model.encode_image(clip_images, dense=True)
+
+ if self.sequential:
+ outputs = []
+ for clip_feat, image in zip(clip_features, images):
+ feature = self.backbone(image.unsqueeze(0))
+ output = self.sem_seg_head(clip_feat.unsqueeze(0), feature)
+ outputs.append(output[0])
+ outputs = torch.stack(outputs, dim=0)
+ else:
+ features = self.backbone(images)
+ outputs = self.sem_seg_head(clip_features, features)
+
+ outputs = F.interpolate(outputs, size=kernel, mode="bilinear", align_corners=False)
+ outputs = outputs.sigmoid()
+
+ global_output = outputs[-1:]
+ global_output = F.interpolate(global_output, size=out_res, mode='bilinear', align_corners=False,)
+ outputs = outputs[:-1]
+ outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1] + out_res, device=self.device)))
+ outputs = (outputs + global_output) / 2.
+
+ height = batched_inputs[0].get("height", out_res[0])
+ width = batched_inputs[0].get("width", out_res[1])
+ catseg_outputs = sem_seg_postprocess(outputs[0], out_res, height, width)
+ #catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
+
+ masks = self.mask.generate(np.uint8(sam_images[0].permute(1, 2, 0).cpu().numpy()))
+ if self.use_sam:
+ outputs, sam_cls = self.discrete_semantic_inference(outputs, masks, out_res)
+ #outputs, sam_cls = self.continuous_semantic_inference(outputs, masks, out_res)
+
+ output = sem_seg_postprocess(outputs[0], out_res, height, width)
+
+ ret = [{'sem_seg': output}]
+ if self.panoptic_on:
+ panoptic_r = self.panoptic_inference(catseg_outputs, masks, sam_cls, size=output.shape[-2:])
+ ret[0]['panoptic_seg'] = panoptic_r
+
+ return ret
+
+ def discrete_semantic_inference(self, outputs, masks, image_size):
+ catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True) #.argmax(dim=1)[0].cpu()
+ sam_outputs = torch.zeros_like(catseg_outputs).cpu()
+ catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
+ sam_classes = torch.zeros(len(masks))
+ for i in range(len(masks)):
+ m = masks[i]['segmentation']
+ s = masks[i]['stability_score']
+ idx = catseg_outputs[m].bincount().argmax()
+ sam_outputs[0, idx][m] = s
+ sam_classes[i] = idx
+
+ return sam_outputs, sam_classes
+
+ def continuous_semantic_inference(self, outputs, masks, image_size, scale=100/7.):
+ #import pdb; pdb.set_trace()
+ catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
+ sam_outputs = torch.zeros_like(catseg_outputs)
+ #catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
+ sam_classes = torch.zeros(len(masks))
+ #import pdb; pdb.set_trace()
+ mask_pred = torch.tensor(np.asarray([x['segmentation'] for x in masks]), dtype=torch.float32) # N H W
+ mask_score = torch.tensor(np.asarray([x['predicted_iou'] for x in masks]), dtype=torch.float32) # N
+
+ mask_cls = torch.einsum("nhw, chw -> nc", mask_pred, catseg_outputs)
+ mask_norm = mask_pred.sum(-1).sum(-1)
+ mask_cls = mask_cls / mask_norm[:, None]
+ mask_cls = mask_cls / mask_cls.norm(p=1, dim=1)[:, None]
+
+ mask_logits = mask_pred * mask_score[:, None, None]
+ output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
+
+ return output.unsqueeze(0), mask_cls
+
+ def continuous_semantic_inference2(self, outputs, masks, image_size, scale=100/7., img=None, text=None):
+ assert img is not None and text is not None
+ import pdb; pdb.set_trace()
+ #catseg_outputs = F.interpolate(outputs, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
+ img = F.interpolate(img, size=image_size, mode="bilinear", align_corners=True)[0].cpu()
+ img = img.permute(1, 2, 0)
+
+ #sam_outputs = torch.zeros_like(catseg_outputs)
+ #catseg_outputs = catseg_outputs.argmax(dim=1)[0].cpu()
+ sam_classes = torch.zeros(len(masks))
+ #import pdb; pdb.set_trace()
+ mask_pred = torch.tensor(np.asarray([x['segmentation'] for x in masks]), dtype=torch.float32) # N H W
+ mask_score = torch.tensor(np.asarray([x['predicted_iou'] for x in masks]), dtype=torch.float32) # N
+
+ mask_pool = torch.einsum("nhw, hwd -> nd ", mask_pred, img)
+ mask_pool = mask_pool / mask_pool.norm(dim=1, keepdim=True)
+ mask_cls = torch.einsum("nd, cd -> nc", 100 * mask_pool, text.cpu())
+ mask_cls = mask_cls.softmax(dim=1)
+
+ #mask_cls = torch.einsum("nhw, chw -> nc", mask_pred, catseg_outputs)
+ mask_norm = mask_pred.sum(-1).sum(-1)
+ mask_cls = mask_cls / mask_norm[:, None]
+ mask_cls = mask_cls / mask_cls.norm(p=1, dim=1)[:, None]
+
+ mask_logits = mask_pred * mask_score[:, None, None]
+ output = torch.einsum("nhw, nc -> chw", mask_logits, mask_cls)
+
+ return output.unsqueeze(0), sam_classes
+
+ def panoptic_inference(self, outputs, masks, sam_classes, size=None):
+ #import pdb; pdb.set_trace()
+ scores = np.asarray([x['predicted_iou'] for x in masks])
+ mask_pred = np.asarray([x['segmentation'] for x in masks])
+
+ #keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
+ cur_scores = torch.tensor(scores)
+ cur_masks = torch.tensor(mask_pred)
+ cur_masks = F.interpolate(cur_masks.unsqueeze(0).float(), size=outputs.shape[-2:], mode="nearest")[0]
+ cur_classes = sam_classes.argmax(dim=-1)
+ #cur_mask_cls = mask_cls#[keep]
+ #cur_mask_cls = cur_mask_cls[:, :-1]
+
+ #import pdb; pdb.set_trace()
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
+
+ h, w = cur_masks.shape[-2:]
+ panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
+ segments_info = []
+
+ current_segment_id = 0
+ if cur_masks.shape[0] == 0:
+ # We didn't detect any mask :(
+ return panoptic_seg, segments_info
+ else:
+ # take argmax
+ cur_mask_ids = cur_prob_masks.argmax(0)
+ stuff_memory_list = {}
+ for k in range(cur_classes.shape[0]):
+ pred_class = cur_classes[k].item()
+ #isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
+ isthing = pred_class in [3, 6] #[i for i in range(10)]#self.metadata.thing_dataset_id_to_contiguous_id.values()
+ mask = cur_mask_ids == k
+ mask_area = mask.sum().item()
+ original_area = (cur_masks[k] >= 0.5).sum().item()
+
+ if mask_area > 0 and original_area > 0:
+ if mask_area / original_area < self.overlap_threshold:
+ continue
+
+ # merge stuff regions
+ if not isthing:
+ if int(pred_class) in stuff_memory_list.keys():
+ panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
+ continue
+ else:
+ stuff_memory_list[int(pred_class)] = current_segment_id + 1
+
+ current_segment_id += 1
+ panoptic_seg[mask] = current_segment_id
+
+ segments_info.append(
+ {
+ "id": current_segment_id,
+ "isthing": bool(isthing),
+ "category_id": int(pred_class),
+ }
+ )
+
+ return panoptic_seg, segments_info
\ No newline at end of file
diff --git a/cat_seg/config.py b/cat_seg/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..41720bb3971e1070ae9a0e9fc0a2026d8d19126f
--- /dev/null
+++ b/cat_seg/config.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+from detectron2.config import CfgNode as CN
+
+
+def add_cat_seg_config(cfg):
+ """
+ Add config for MASK_FORMER.
+ """
+ # data config
+ # select the dataset mapper
+ cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
+
+ cfg.DATASETS.VAL_ALL = ("coco_2017_val_all_stuff_sem_seg",)
+
+ # Color augmentation
+ cfg.INPUT.COLOR_AUG_SSD = False
+ # We retry random cropping until no single category in semantic segmentation GT occupies more
+ # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
+ # Pad image and segmentation GT in dataset mapper.
+ cfg.INPUT.SIZE_DIVISIBILITY = -1
+
+ # solver config
+ # weight decay on embedding
+ cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
+ # optimizer
+ cfg.SOLVER.OPTIMIZER = "ADAMW"
+ cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
+
+ # mask_former model config
+ cfg.MODEL.MASK_FORMER = CN()
+
+ # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
+ # you can use this config to override
+ cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
+
+ # swin transformer backbone
+ cfg.MODEL.SWIN = CN()
+ cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
+ cfg.MODEL.SWIN.PATCH_SIZE = 4
+ cfg.MODEL.SWIN.EMBED_DIM = 96
+ cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+ cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+ cfg.MODEL.SWIN.WINDOW_SIZE = 7
+ cfg.MODEL.SWIN.MLP_RATIO = 4.0
+ cfg.MODEL.SWIN.QKV_BIAS = True
+ cfg.MODEL.SWIN.QK_SCALE = None
+ cfg.MODEL.SWIN.DROP_RATE = 0.0
+ cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
+ cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
+ cfg.MODEL.SWIN.APE = False
+ cfg.MODEL.SWIN.PATCH_NORM = True
+ cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
+
+ # zero shot config
+ cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON = "datasets/ADE20K_2021_17_01/ADE20K_847.json"
+ cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON = "datasets/ADE20K_2021_17_01/ADE20K_847.json"
+ cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_INDEXES = "datasets/coco/coco_stuff/split/seen_indexes.json"
+ cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_INDEXES = "datasets/coco/coco_stuff/split/unseen_indexes.json"
+
+ cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED = "ViT-B/16"
+
+ cfg.MODEL.PROMPT_ENSEMBLE = False
+ cfg.MODEL.PROMPT_ENSEMBLE_TYPE = "single"
+
+ cfg.MODEL.CLIP_PIXEL_MEAN = [122.7709383, 116.7460125, 104.09373615]
+ cfg.MODEL.CLIP_PIXEL_STD = [68.5005327, 66.6321579, 70.3231630]
+ # three styles for clip classification, crop, mask, cropmask
+
+ cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_DIM = 512
+ cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_PROJ_DIM = 128
+ cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_DIM = 512
+ cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_PROJ_DIM = 128
+
+ cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS = [64, 32]
+ cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_DIMS = [256, 128]
+ cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_PROJ_DIMS = [32, 16]
+
+ cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS = 4
+ cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS = 4
+ cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS = 128
+ cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES = [6, 6]
+ cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION = [24, 24]
+ cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES = 12
+ cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE = "linear"
+
+ cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH = 0
+ cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH = 0
+ cfg.SOLVER.CLIP_MULTIPLIER = 0.01
+
+ cfg.MODEL.SEM_SEG_HEAD.CLIP_FINETUNE = "attention"
+ cfg.TEST.SLIDING_WINDOW = False
\ No newline at end of file
diff --git a/cat_seg/data/__init__.py b/cat_seg/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..63ba265b1effc69f1eef16e57a04db8902ee347e
--- /dev/null
+++ b/cat_seg/data/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from . import datasets
diff --git a/cat_seg/data/__pycache__/__init__.cpython-38.pyc b/cat_seg/data/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9dd0f90fab9c6891906bd034dfb992c6e44b13ed
Binary files /dev/null and b/cat_seg/data/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cat_seg/data/dataset_mappers/__init__.py b/cat_seg/data/dataset_mappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/cat_seg/data/dataset_mappers/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/cat_seg/data/dataset_mappers/__pycache__/__init__.cpython-38.pyc b/cat_seg/data/dataset_mappers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd18469e8728a1bfadc6b83ba40a6581f9487604
Binary files /dev/null and b/cat_seg/data/dataset_mappers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cat_seg/data/dataset_mappers/__pycache__/detr_panoptic_dataset_mapper.cpython-38.pyc b/cat_seg/data/dataset_mappers/__pycache__/detr_panoptic_dataset_mapper.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15d4ff32d999726ff661a6477e7c5389709bb55f
Binary files /dev/null and b/cat_seg/data/dataset_mappers/__pycache__/detr_panoptic_dataset_mapper.cpython-38.pyc differ
diff --git a/cat_seg/data/dataset_mappers/__pycache__/mask_former_panoptic_dataset_mapper.cpython-38.pyc b/cat_seg/data/dataset_mappers/__pycache__/mask_former_panoptic_dataset_mapper.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..59850ea023365536b7e311bb84db09de93876939
Binary files /dev/null and b/cat_seg/data/dataset_mappers/__pycache__/mask_former_panoptic_dataset_mapper.cpython-38.pyc differ
diff --git a/cat_seg/data/dataset_mappers/__pycache__/mask_former_semantic_dataset_mapper.cpython-38.pyc b/cat_seg/data/dataset_mappers/__pycache__/mask_former_semantic_dataset_mapper.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d31161e12090a8a3b763096caa94a4e64524b95
Binary files /dev/null and b/cat_seg/data/dataset_mappers/__pycache__/mask_former_semantic_dataset_mapper.cpython-38.pyc differ
diff --git a/cat_seg/data/dataset_mappers/detr_panoptic_dataset_mapper.py b/cat_seg/data/dataset_mappers/detr_panoptic_dataset_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a296f2fbbd24b190b312b464ce2d4c1957b221c
--- /dev/null
+++ b/cat_seg/data/dataset_mappers/detr_panoptic_dataset_mapper.py
@@ -0,0 +1,180 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
+import copy
+import logging
+
+import numpy as np
+import torch
+
+from detectron2.config import configurable
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+from detectron2.data.transforms import TransformGen
+from detectron2.structures import BitMasks, Instances
+
+__all__ = ["DETRPanopticDatasetMapper"]
+
+
+def build_transform_gen(cfg, is_train):
+ """
+ Create a list of :class:`TransformGen` from config.
+ Returns:
+ list[TransformGen]
+ """
+ if is_train:
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN
+ max_size = cfg.INPUT.MAX_SIZE_TRAIN
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
+ else:
+ min_size = cfg.INPUT.MIN_SIZE_TEST
+ max_size = cfg.INPUT.MAX_SIZE_TEST
+ sample_style = "choice"
+ if sample_style == "range":
+ assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(
+ len(min_size)
+ )
+
+ logger = logging.getLogger(__name__)
+ tfm_gens = []
+ if is_train:
+ tfm_gens.append(T.RandomFlip())
+ tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
+ if is_train:
+ logger.info("TransformGens used in training: " + str(tfm_gens))
+ return tfm_gens
+
+
+# This is specifically designed for the COCO dataset.
+class DETRPanopticDatasetMapper:
+ """
+ A callable which takes a dataset dict in Detectron2 Dataset format,
+ and map it into a format used by MaskFormer.
+
+ This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
+
+ The callable currently does the following:
+
+ 1. Read the image from "file_name"
+ 2. Applies geometric transforms to the image and annotation
+ 3. Find and applies suitable cropping to the image and annotation
+ 4. Prepare image and annotation to Tensors
+ """
+
+ @configurable
+ def __init__(
+ self,
+ is_train=True,
+ *,
+ crop_gen,
+ tfm_gens,
+ image_format,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ is_train: for training or inference
+ augmentations: a list of augmentations or deterministic transforms to apply
+ crop_gen: crop augmentation
+ tfm_gens: data augmentation
+ image_format: an image format supported by :func:`detection_utils.read_image`.
+ """
+ self.crop_gen = crop_gen
+ self.tfm_gens = tfm_gens
+ logging.getLogger(__name__).info(
+ "[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format(
+ str(self.tfm_gens), str(self.crop_gen)
+ )
+ )
+
+ self.img_format = image_format
+ self.is_train = is_train
+
+ @classmethod
+ def from_config(cls, cfg, is_train=True):
+ # Build augmentation
+ if cfg.INPUT.CROP.ENABLED and is_train:
+ crop_gen = [
+ T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
+ T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
+ ]
+ else:
+ crop_gen = None
+
+ tfm_gens = build_transform_gen(cfg, is_train)
+
+ ret = {
+ "is_train": is_train,
+ "crop_gen": crop_gen,
+ "tfm_gens": tfm_gens,
+ "image_format": cfg.INPUT.FORMAT,
+ }
+ return ret
+
+ def __call__(self, dataset_dict):
+ """
+ Args:
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+ Returns:
+ dict: a format that builtin models in detectron2 accept
+ """
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
+ utils.check_image_size(dataset_dict, image)
+
+ if self.crop_gen is None:
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
+ else:
+ if np.random.rand() > 0.5:
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
+ else:
+ image, transforms = T.apply_transform_gens(
+ self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
+ )
+
+ image_shape = image.shape[:2] # h, w
+
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+ # Therefore it's important to use torch.Tensor.
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+
+ if not self.is_train:
+ # USER: Modify this if you want to keep them for some reason.
+ dataset_dict.pop("annotations", None)
+ return dataset_dict
+
+ if "pan_seg_file_name" in dataset_dict:
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
+ segments_info = dataset_dict["segments_info"]
+
+ # apply the same transformation to panoptic segmentation
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
+
+ from panopticapi.utils import rgb2id
+
+ pan_seg_gt = rgb2id(pan_seg_gt)
+
+ instances = Instances(image_shape)
+ classes = []
+ masks = []
+ for segment_info in segments_info:
+ class_id = segment_info["category_id"]
+ if not segment_info["iscrowd"]:
+ classes.append(class_id)
+ masks.append(pan_seg_gt == segment_info["id"])
+
+ classes = np.array(classes)
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
+ if len(masks) == 0:
+ # Some image does not have annotation (all ignored)
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
+ else:
+ masks = BitMasks(
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
+ )
+ instances.gt_masks = masks.tensor
+
+ dataset_dict["instances"] = instances
+
+ return dataset_dict
diff --git a/cat_seg/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py b/cat_seg/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddbc2bd77fb1b17540dd5272cfc6534ee2b6e2df
--- /dev/null
+++ b/cat_seg/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py
@@ -0,0 +1,165 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import logging
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+from detectron2.structures import BitMasks, Instances
+
+from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
+
+__all__ = ["MaskFormerPanopticDatasetMapper"]
+
+
+class MaskFormerPanopticDatasetMapper(MaskFormerSemanticDatasetMapper):
+ """
+ A callable which takes a dataset dict in Detectron2 Dataset format,
+ and map it into a format used by MaskFormer for panoptic segmentation.
+
+ The callable currently does the following:
+
+ 1. Read the image from "file_name"
+ 2. Applies geometric transforms to the image and annotation
+ 3. Find and applies suitable cropping to the image and annotation
+ 4. Prepare image and annotation to Tensors
+ """
+
+ @configurable
+ def __init__(
+ self,
+ is_train=True,
+ *,
+ augmentations,
+ image_format,
+ ignore_label,
+ size_divisibility,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ is_train: for training or inference
+ augmentations: a list of augmentations or deterministic transforms to apply
+ image_format: an image format supported by :func:`detection_utils.read_image`.
+ ignore_label: the label that is ignored to evaluation
+ size_divisibility: pad image size to be divisible by this value
+ """
+ super().__init__(
+ is_train,
+ augmentations=augmentations,
+ image_format=image_format,
+ ignore_label=ignore_label,
+ size_divisibility=size_divisibility,
+ )
+
+ def __call__(self, dataset_dict):
+ """
+ Args:
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+ Returns:
+ dict: a format that builtin models in detectron2 accept
+ """
+ assert self.is_train, "MaskFormerPanopticDatasetMapper should only be used for training!"
+
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
+ utils.check_image_size(dataset_dict, image)
+
+ # semantic segmentation
+ if "sem_seg_file_name" in dataset_dict:
+ # PyTorch transformation not implemented for uint16, so converting it to double first
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
+ else:
+ sem_seg_gt = None
+
+ # panoptic segmentation
+ if "pan_seg_file_name" in dataset_dict:
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
+ segments_info = dataset_dict["segments_info"]
+ else:
+ pan_seg_gt = None
+ segments_info = None
+
+ if pan_seg_gt is None:
+ raise ValueError(
+ "Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format(
+ dataset_dict["file_name"]
+ )
+ )
+
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
+ image = aug_input.image
+ if sem_seg_gt is not None:
+ sem_seg_gt = aug_input.sem_seg
+
+ # apply the same transformation to panoptic segmentation
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
+
+ from panopticapi.utils import rgb2id
+
+ pan_seg_gt = rgb2id(pan_seg_gt)
+
+ # Pad image and segmentation label here!
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+ if sem_seg_gt is not None:
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
+ pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long"))
+
+ if self.size_divisibility > 0:
+ image_size = (image.shape[-2], image.shape[-1])
+ padding_size = [
+ 0,
+ self.size_divisibility - image_size[1],
+ 0,
+ self.size_divisibility - image_size[0],
+ ]
+ image = F.pad(image, padding_size, value=128).contiguous()
+ if sem_seg_gt is not None:
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
+ pan_seg_gt = F.pad(
+ pan_seg_gt, padding_size, value=0
+ ).contiguous() # 0 is the VOID panoptic label
+
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
+
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+ # Therefore it's important to use torch.Tensor.
+ dataset_dict["image"] = image
+ if sem_seg_gt is not None:
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
+
+ if "annotations" in dataset_dict:
+ raise ValueError("Pemantic segmentation dataset should not have 'annotations'.")
+
+ # Prepare per-category binary masks
+ pan_seg_gt = pan_seg_gt.numpy()
+ instances = Instances(image_shape)
+ classes = []
+ masks = []
+ for segment_info in segments_info:
+ class_id = segment_info["category_id"]
+ if not segment_info["iscrowd"]:
+ classes.append(class_id)
+ masks.append(pan_seg_gt == segment_info["id"])
+
+ classes = np.array(classes)
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
+ if len(masks) == 0:
+ # Some image does not have annotation (all ignored)
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
+ else:
+ masks = BitMasks(
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
+ )
+ instances.gt_masks = masks.tensor
+
+ dataset_dict["instances"] = instances
+
+ return dataset_dict
diff --git a/cat_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py b/cat_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..41c82f2c76cb6d74020ae0a6a3ba045469755f01
--- /dev/null
+++ b/cat_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py
@@ -0,0 +1,186 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import logging
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.data import MetadataCatalog
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+from detectron2.projects.point_rend import ColorAugSSDTransform
+from detectron2.structures import BitMasks, Instances
+
+__all__ = ["MaskFormerSemanticDatasetMapper"]
+
+
+class MaskFormerSemanticDatasetMapper:
+ """
+ A callable which takes a dataset dict in Detectron2 Dataset format,
+ and map it into a format used by MaskFormer for semantic segmentation.
+
+ The callable currently does the following:
+
+ 1. Read the image from "file_name"
+ 2. Applies geometric transforms to the image and annotation
+ 3. Find and applies suitable cropping to the image and annotation
+ 4. Prepare image and annotation to Tensors
+ """
+
+ @configurable
+ def __init__(
+ self,
+ is_train=True,
+ *,
+ augmentations,
+ image_format,
+ ignore_label,
+ size_divisibility,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ is_train: for training or inference
+ augmentations: a list of augmentations or deterministic transforms to apply
+ image_format: an image format supported by :func:`detection_utils.read_image`.
+ ignore_label: the label that is ignored to evaluation
+ size_divisibility: pad image size to be divisible by this value
+ """
+ self.is_train = is_train
+ self.tfm_gens = augmentations
+ self.img_format = image_format
+ self.ignore_label = ignore_label
+ self.size_divisibility = size_divisibility
+
+ logger = logging.getLogger(__name__)
+ mode = "training" if is_train else "inference"
+ logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")
+
+ @classmethod
+ def from_config(cls, cfg, is_train=True):
+ # Build augmentation
+ augs = [
+ T.ResizeShortestEdge(
+ cfg.INPUT.MIN_SIZE_TRAIN,
+ cfg.INPUT.MAX_SIZE_TRAIN,
+ cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
+ )
+ ]
+ if cfg.INPUT.CROP.ENABLED:
+ augs.append(
+ T.RandomCrop_CategoryAreaConstraint(
+ cfg.INPUT.CROP.TYPE,
+ cfg.INPUT.CROP.SIZE,
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
+ cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+ )
+ )
+ if cfg.INPUT.COLOR_AUG_SSD:
+ augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
+ augs.append(T.RandomFlip())
+
+ # Assume always applies to the training set.
+ dataset_names = cfg.DATASETS.TRAIN
+ meta = MetadataCatalog.get(dataset_names[0])
+ ignore_label = meta.ignore_label
+
+ ret = {
+ "is_train": is_train,
+ "augmentations": augs,
+ "image_format": cfg.INPUT.FORMAT,
+ "ignore_label": ignore_label,
+ "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
+ }
+ return ret
+
+ def __call__(self, dataset_dict):
+ """
+ Args:
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+ Returns:
+ dict: a format that builtin models in detectron2 accept
+ """
+ assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!"
+
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
+ utils.check_image_size(dataset_dict, image)
+
+ if "sem_seg_file_name" in dataset_dict:
+ # PyTorch transformation not implemented for uint16, so converting it to double first
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
+ else:
+ sem_seg_gt = None
+
+ if sem_seg_gt is None:
+ raise ValueError(
+ "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
+ dataset_dict["file_name"]
+ )
+ )
+
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
+ image = aug_input.image
+ sem_seg_gt = aug_input.sem_seg
+
+ # Pad image and segmentation label here!
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+ if sem_seg_gt is not None:
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
+ # import ipdb; ipdb.set_trace()
+ if self.size_divisibility > 0:
+ image_size = (image.shape[-2], image.shape[-1])
+ # The ori_size is not the real original size, but size before padding
+ dataset_dict['ori_size'] = image_size
+ padding_size = [
+ 0,
+ self.size_divisibility - image_size[1], # w: (left, right)
+ 0,
+ self.size_divisibility - image_size[0], # h: 0,(top, bottom)
+ ]
+ image = F.pad(image, padding_size, value=128).contiguous()
+ if sem_seg_gt is not None:
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
+
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
+
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+ # Therefore it's important to use torch.Tensor.
+ dataset_dict["image"] = image
+ # print('#########################################################################################')
+ if sem_seg_gt is not None:
+ dataset_dict["sem_seg"] = sem_seg_gt.long()
+
+ if "annotations" in dataset_dict:
+ raise ValueError("Semantic segmentation dataset should not have 'annotations'.")
+
+ # Prepare per-category binary masks
+ if sem_seg_gt is not None:
+ sem_seg_gt = sem_seg_gt.numpy()
+ instances = Instances(image_shape)
+ classes = np.unique(sem_seg_gt)
+ # remove ignored region
+ classes = classes[classes != self.ignore_label]
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
+
+ masks = []
+ for class_id in classes:
+ masks.append(sem_seg_gt == class_id)
+
+ if len(masks) == 0:
+ # Some image does not have annotation (all ignored)
+ instances.gt_masks = torch.zeros((0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1]))
+ else:
+ masks = BitMasks(
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
+ )
+ instances.gt_masks = masks.tensor
+
+ dataset_dict["instances"] = instances
+
+ return dataset_dict
diff --git a/cat_seg/data/datasets/__init__.py b/cat_seg/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90d0d07e352ea3952b34176383d89d02456f76d1
--- /dev/null
+++ b/cat_seg/data/datasets/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from . import (
+ register_coco_stuff,
+ register_ade20k_150,
+ register_ade20k_847,
+ register_pascal_20,
+ register_pascal_59,
+)
diff --git a/cat_seg/data/datasets/__pycache__/__init__.cpython-38.pyc b/cat_seg/data/datasets/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdd78ed6c133f3af057f1d9d02a18e12e057f4d3
Binary files /dev/null and b/cat_seg/data/datasets/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cat_seg/data/datasets/__pycache__/register_ade20k_150.cpython-38.pyc b/cat_seg/data/datasets/__pycache__/register_ade20k_150.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f62b3d8f51ae1c0614e68be95e48ada6169eb28
Binary files /dev/null and b/cat_seg/data/datasets/__pycache__/register_ade20k_150.cpython-38.pyc differ
diff --git a/cat_seg/data/datasets/__pycache__/register_ade20k_847.cpython-38.pyc b/cat_seg/data/datasets/__pycache__/register_ade20k_847.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ec729447ec61997499ad27bb2d32df5e0bd315d
Binary files /dev/null and b/cat_seg/data/datasets/__pycache__/register_ade20k_847.cpython-38.pyc differ
diff --git a/cat_seg/data/datasets/__pycache__/register_ade_panoptic.cpython-38.pyc b/cat_seg/data/datasets/__pycache__/register_ade_panoptic.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d4b46ac638e75a9f960fb9c06719628540aa1c6
Binary files /dev/null and b/cat_seg/data/datasets/__pycache__/register_ade_panoptic.cpython-38.pyc differ
diff --git a/cat_seg/data/datasets/__pycache__/register_coco_panoptic.cpython-38.pyc b/cat_seg/data/datasets/__pycache__/register_coco_panoptic.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d90852e677285f6d24580b149b63d071ad521da
Binary files /dev/null and b/cat_seg/data/datasets/__pycache__/register_coco_panoptic.cpython-38.pyc differ
diff --git a/cat_seg/data/datasets/__pycache__/register_coco_stuff.cpython-38.pyc b/cat_seg/data/datasets/__pycache__/register_coco_stuff.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9dc6d9fea8d4a7385d823dd940eb246ee49ac9e3
Binary files /dev/null and b/cat_seg/data/datasets/__pycache__/register_coco_stuff.cpython-38.pyc differ
diff --git a/cat_seg/data/datasets/__pycache__/register_pascal_20.cpython-38.pyc b/cat_seg/data/datasets/__pycache__/register_pascal_20.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b55462af6e2e94b84d0cd1a2973208cc1c4754c4
Binary files /dev/null and b/cat_seg/data/datasets/__pycache__/register_pascal_20.cpython-38.pyc differ
diff --git a/cat_seg/data/datasets/__pycache__/register_pascal_59.cpython-38.pyc b/cat_seg/data/datasets/__pycache__/register_pascal_59.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e04b5e47af69f8d87dffdabff3c626aa5322bf1
Binary files /dev/null and b/cat_seg/data/datasets/__pycache__/register_pascal_59.cpython-38.pyc differ
diff --git a/cat_seg/data/datasets/__pycache__/register_pascal_context.cpython-38.pyc b/cat_seg/data/datasets/__pycache__/register_pascal_context.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..77b9b6a84ec40f1540398d6021b495a18815e769
Binary files /dev/null and b/cat_seg/data/datasets/__pycache__/register_pascal_context.cpython-38.pyc differ
diff --git a/cat_seg/data/datasets/register_ade20k_150.py b/cat_seg/data/datasets/register_ade20k_150.py
new file mode 100755
index 0000000000000000000000000000000000000000..fa3cb77077513df451620dac095cf625fc40e0e7
--- /dev/null
+++ b/cat_seg/data/datasets/register_ade20k_150.py
@@ -0,0 +1,28 @@
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+import copy
+
+def _get_ade20k_150_meta():
+ ade20k_150_classes = ["wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed ", "windowpane", "grass", "cabinet", "sidewalk", "person", "earth", "door", "table", "mountain", "plant", "curtain", "chair", "car", "water", "painting", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub", "railing", "cushion", "base", "box", "column", "signboard", "chest of drawers", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator", "grandstand", "path", "stairs", "runway", "case", "pool table", "pillow", "screen door", "stairway", "river", "bridge", "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill", "bench", "countertop", "stove", "palm", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel", "bus", "towel", "light", "truck", "tower", "chandelier", "awning", "streetlight", "booth", "television receiver", "airplane", "dirt track", "apparel", "pole", "land", "bannister", "escalator", "ottoman", "bottle", "buffet", "poster", "stage", "van", "ship", "fountain", "conveyer belt", "canopy", "washer", "plaything", "swimming pool", "stool", "barrel", "basket", "waterfall", "tent", "bag", "minibike", "cradle", "oven", "ball", "food", "step", "tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket", "sculpture", "hood", "sconce", "vase", "traffic light", "tray", "ashcan", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass", "clock", "flag"]
+
+ ret = {
+ "stuff_classes" : ade20k_150_classes,
+ }
+ return ret
+
+def register_ade20k_150(root):
+ root = os.path.join(root, "ADEChallengeData2016")
+ meta = _get_ade20k_150_meta()
+ for name, image_dirname, sem_seg_dirname in [
+ ("test", "images/validation", "annotations_detectron2/validation"),
+ ]:
+ image_dir = os.path.join(root, image_dirname)
+ gt_dir = os.path.join(root, sem_seg_dirname)
+ name = f"ade20k_150_{name}_sem_seg"
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
+ MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_ade20k_150(_root)
diff --git a/cat_seg/data/datasets/register_ade20k_847.py b/cat_seg/data/datasets/register_ade20k_847.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d781f3334a2efcd935d1f5cadf76af67080a8da
--- /dev/null
+++ b/cat_seg/data/datasets/register_ade20k_847.py
@@ -0,0 +1,50 @@
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+import copy
+# TODO: check this code
+import random
+
+ADE20K_SEM_SEG_FULL_CATEGORIES = [{'name': 'wall', 'id': 2978, 'trainId': 0, 'image_count': 13447, 'frequency': 'frequent', 'color': [197, 196, 23]},
+ {'name': 'building, edifice', 'id': 312, 'trainId': 1, 'image_count': 7243, 'frequency': 'frequent', 'color': [181, 87, 228]}, {'name': 'sky', 'id': 2420, 'trainId': 2, 'image_count': 11005, 'frequency': 'frequent', 'color': [45, 37, 176]}, {'name': 'tree', 'id': 2855, 'trainId': 3, 'image_count': 8428, 'frequency': 'frequent', 'color': [114, 237, 23]}, {'name': 'road, route', 'id': 2131, 'trainId': 4, 'image_count': 4403, 'frequency': 'frequent', 'color': [107, 11, 32]}, {'name': 'floor, flooring', 'id': 976, 'trainId': 5, 'image_count': 10757, 'frequency': 'frequent', 'color': [21, 88, 64]}, {'name': 'ceiling', 'id': 447, 'trainId': 6, 'image_count': 7576, 'frequency': 'frequent', 'color': [222, 174, 10]}, {'name': 'bed', 'id': 165, 'trainId': 7, 'image_count': 2041, 'frequency': 'frequent', 'color': [89, 173, 246]}, {'name': 'sidewalk, pavement', 'id': 2377, 'trainId': 8, 'image_count': 3367, 'frequency': 'frequent', 'color': [163, 140, 64]}, {'name': 'earth, ground', 'id': 838, 'trainId': 9, 'image_count': 2993, 'frequency': 'frequent', 'color': [92, 6, 252]}, {'name': 'cabinet', 'id': 350, 'trainId': 10, 'image_count': 3440, 'frequency': 'frequent', 'color': [11, 247, 152]}, {'name': 'person, individual, someone, somebody, mortal, soul', 'id': 1831, 'trainId': 11, 'image_count': 6426, 'frequency': 'frequent', 'color': [118, 24, 255]}, {'name': 'grass', 'id': 1125, 'trainId': 12, 'image_count': 2712, 'frequency': 'frequent', 'color': [253, 77, 23]}, {'name': 'windowpane, window', 'id': 3055, 'trainId': 13, 'image_count': 5332, 'frequency': 'frequent', 'color': [155, 245, 212]}, {'name': 'car, auto, automobile, machine, motorcar', 'id': 401, 'trainId': 14, 'image_count': 3525, 'frequency': 'frequent', 'color': [104, 80, 118]}, {'name': 'mountain, mount', 'id': 1610, 'trainId': 15, 'image_count': 2598, 'frequency': 'frequent', 'color': [249, 238, 178]}, {'name': 'plant, flora, plant life', 'id': 1910, 'trainId': 16, 'image_count': 5120, 'frequency': 'frequent', 'color': [236, 93, 193]}, {'name': 'table', 'id': 2684, 'trainId': 17, 'image_count': 5116, 'frequency': 'frequent', 'color': [241, 183, 84]}, {'name': 'chair', 'id': 471, 'trainId': 18, 'image_count': 3939, 'frequency': 'frequent', 'color': [168, 120, 192]}, {'name': 'curtain, drape, drapery, mantle, pall', 'id': 687, 'trainId': 19, 'image_count': 2513, 'frequency': 'frequent', 'color': [43, 21, 0]}, {'name': 'door', 'id': 774, 'trainId': 20, 'image_count': 3532, 'frequency': 'frequent', 'color': [12, 96, 230]}, {'name': 'sofa, couch, lounge', 'id': 2473, 'trainId': 21, 'image_count': 1572, 'frequency': 'frequent', 'color': [89, 56, 229]}, {'name': 'sea', 'id': 2264, 'trainId': 22, 'image_count': 1180, 'frequency': 'frequent', 'color': [218, 210, 26]}, {'name': 'painting, picture', 'id': 1735, 'trainId': 23, 'image_count': 3874, 'frequency': 'frequent', 'color': [100, 120, 254]}, {'name': 'water', 'id': 2994, 'trainId': 24, 'image_count': 792, 'frequency': 'frequent', 'color': [173, 224, 31]}, {'name': 'mirror', 'id': 1564, 'trainId': 25, 'image_count': 1376, 'frequency': 'frequent', 'color': [63, 77, 192]}, {'name': 'house', 'id': 1276, 'trainId': 26, 'image_count': 828, 'frequency': 'frequent', 'color': [81, 72, 245]}, {'name': 'rug, carpet, carpeting', 'id': 2178, 'trainId': 27, 'image_count': 1448, 'frequency': 'frequent', 'color': [20, 174, 65]}, {'name': 'shelf', 'id': 2329, 'trainId': 28, 'image_count': 1639, 'frequency': 'frequent', 'color': [60, 251, 135]}, {'name': 'armchair', 'id': 57, 'trainId': 29, 'image_count': 1449, 'frequency': 'frequent', 'color': [250, 251, 64]}, {'name': 'fence, fencing', 'id': 907, 'trainId': 30, 'image_count': 1750, 'frequency': 'frequent', 'color': [27, 165, 104]}, {'name': 'field', 'id': 913, 'trainId': 31, 'image_count': 618, 'frequency': 'frequent', 'color': [182, 218, 110]}, {'name': 'lamp', 'id': 1395, 'trainId': 32, 'image_count': 3672, 'frequency': 'frequent', 'color': [216, 50, 199]}, {'name': 'rock, stone', 'id': 2138, 'trainId': 33, 'image_count': 1139, 'frequency': 'frequent', 'color': [112, 62, 2]}, {'name': 'seat', 'id': 2272, 'trainId': 34, 'image_count': 551, 'frequency': 'frequent', 'color': [139, 224, 163]}, {'name': 'river', 'id': 2128, 'trainId': 35, 'image_count': 601, 'frequency': 'frequent', 'color': [28, 45, 29]}, {'name': 'desk', 'id': 724, 'trainId': 36, 'image_count': 745, 'frequency': 'frequent', 'color': [212, 57, 188]}, {'name': 'bathtub, bathing tub, bath, tub', 'id': 155, 'trainId': 37, 'image_count': 428, 'frequency': 'frequent', 'color': [60, 60, 247]}, {'name': 'railing, rail', 'id': 2053, 'trainId': 38, 'image_count': 1002, 'frequency': 'frequent', 'color': [4, 255, 242]}, {'name': 'signboard, sign', 'id': 2380, 'trainId': 39, 'image_count': 2959, 'frequency': 'frequent', 'color': [140, 237, 70]}, {'name': 'cushion', 'id': 689, 'trainId': 40, 'image_count': 1788, 'frequency': 'frequent', 'color': [30, 11, 119]}, {'name': 'path', 'id': 1788, 'trainId': 41, 'image_count': 754, 'frequency': 'frequent', 'color': [251, 131, 118]}, {'name': 'work surface', 'id': 3087, 'trainId': 42, 'image_count': 846, 'frequency': 'frequent', 'color': [161, 151, 56]}, {'name': 'stairs, steps', 'id': 2530, 'trainId': 43, 'image_count': 1024, 'frequency': 'frequent', 'color': [145, 185, 70]}, {'name': 'column, pillar', 'id': 581, 'trainId': 44, 'image_count': 947, 'frequency': 'frequent', 'color': [251, 139, 64]}, {'name': 'sink', 'id': 2388, 'trainId': 45, 'image_count': 1329, 'frequency': 'frequent', 'color': [217, 106, 66]}, {'name': 'wardrobe, closet, press', 'id': 2985, 'trainId': 46, 'image_count': 361, 'frequency': 'frequent', 'color': [193, 149, 91]}, {'name': 'snow', 'id': 2454, 'trainId': 47, 'image_count': 186, 'frequency': 'frequent', 'color': [216, 32, 217]}, {'name': 'refrigerator, icebox', 'id': 2096, 'trainId': 48, 'image_count': 521, 'frequency': 'frequent', 'color': [61, 63, 192]}, {'name': 'base, pedestal, stand', 'id': 137, 'trainId': 49, 'image_count': 466, 'frequency': 'frequent', 'color': [249, 32, 139]}, {'name': 'bridge, span', 'id': 294, 'trainId': 50, 'image_count': 341, 'frequency': 'frequent', 'color': [108, 115, 187]}, {'name': 'blind, screen', 'id': 212, 'trainId': 51, 'image_count': 445, 'frequency': 'frequent', 'color': [123, 17, 204]}, {'name': 'runway', 'id': 2185, 'trainId': 52, 'image_count': 90, 'frequency': 'common', 'color': [187, 212, 44]}, {'name': 'cliff, drop, drop-off', 'id': 524, 'trainId': 53, 'image_count': 234, 'frequency': 'frequent', 'color': [108, 131, 139]}, {'name': 'sand', 'id': 2212, 'trainId': 54, 'image_count': 477, 'frequency': 'frequent', 'color': [149, 207, 251]}, {'name': 'fireplace, hearth, open fireplace', 'id': 943, 'trainId': 55, 'image_count': 619, 'frequency': 'frequent', 'color': [161, 10, 212]}, {'name': 'pillow', 'id': 1869, 'trainId': 56, 'image_count': 1008, 'frequency': 'frequent', 'color': [63, 144, 70]}, {'name': 'screen door, screen', 'id': 2251, 'trainId': 57, 'image_count': 150, 'frequency': 'frequent', 'color': [6, 149, 64]}, {'name': 'toilet, can, commode, crapper, pot, potty, stool, throne', 'id': 2793, 'trainId': 58, 'image_count': 418, 'frequency': 'frequent', 'color': [142, 241, 79]}, {'name': 'skyscraper', 'id': 2423, 'trainId': 59, 'image_count': 349, 'frequency': 'frequent', 'color': [99, 161, 20]}, {'name': 'grandstand, covered stand', 'id': 1121, 'trainId': 60, 'image_count': 145, 'frequency': 'frequent', 'color': [122, 56, 70]}, {'name': 'box', 'id': 266, 'trainId': 61, 'image_count': 1686, 'frequency': 'frequent', 'color': [159, 86, 45]}, {'name': 'pool table, billiard table, snooker table', 'id': 1948, 'trainId': 62, 'image_count': 208, 'frequency': 'frequent', 'color': [150, 208, 179]}, {'name': 'palm, palm tree', 'id': 1744, 'trainId': 63, 'image_count': 462, 'frequency': 'frequent', 'color': [79, 122, 55]}, {'name': 'double door', 'id': 783, 'trainId': 64, 'image_count': 440, 'frequency': 'frequent', 'color': [239, 66, 120]}, {'name': 'coffee table, cocktail table', 'id': 571, 'trainId': 65, 'image_count': 1054, 'frequency': 'frequent', 'color': [232, 173, 114]}, {'name': 'counter', 'id': 627, 'trainId': 66, 'image_count': 406, 'frequency': 'frequent', 'color': [83, 79, 187]}, {'name': 'countertop', 'id': 629, 'trainId': 67, 'image_count': 352, 'frequency': 'frequent', 'color': [85, 221, 24]}, {'name': 'chest of drawers, chest, bureau, dresser', 'id': 491, 'trainId': 68, 'image_count': 575, 'frequency': 'frequent', 'color': [16, 101, 29]}, {'name': 'kitchen island', 'id': 1374, 'trainId': 69, 'image_count': 198, 'frequency': 'frequent', 'color': [192, 87, 21]}, {'name': 'boat', 'id': 223, 'trainId': 70, 'image_count': 431, 'frequency': 'frequent', 'color': [154, 3, 5]}, {'name': 'waterfall, falls', 'id': 3016, 'trainId': 71, 'image_count': 361, 'frequency': 'frequent', 'color': [136, 164, 163]}, {'name': 'stove, kitchen stove, range, kitchen range, cooking stove', 'id': 2598, 'trainId': 72, 'image_count': 624, 'frequency': 'frequent', 'color': [62, 241, 149]}, {'name': 'flower', 'id': 978, 'trainId': 73, 'image_count': 1575, 'frequency': 'frequent', 'color': [237, 254, 1]}, {'name': 'bookcase', 'id': 239, 'trainId': 74, 'image_count': 303, 'frequency': 'frequent', 'color': [179, 126, 130]}, {'name': 'controls', 'id': 608, 'trainId': 75, 'image_count': 27, 'frequency': 'common', 'color': [100, 161, 7]}, {'name': 'book', 'id': 236, 'trainId': 76, 'image_count': 1339, 'frequency': 'frequent', 'color': [103, 181, 187]}, {'name': 'stairway, staircase', 'id': 2531, 'trainId': 77, 'image_count': 637, 'frequency': 'frequent', 'color': [176, 6, 205]}, {'name': 'streetlight, street lamp', 'id': 2616, 'trainId': 78, 'image_count': 2261, 'frequency': 'frequent', 'color': [232, 153, 247]}, {'name': 'computer, computing machine, computing device, data processor, electronic computer, information processing system', 'id': 591, 'trainId': 79, 'image_count': 318, 'frequency': 'frequent', 'color': [137, 67, 164]}, {'name': 'bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle', 'id': 327, 'trainId': 80, 'image_count': 278, 'frequency': 'frequent', 'color': [31, 97, 12]}, {'name': 'swivel chair', 'id': 2679, 'trainId': 81, 'image_count': 327, 'frequency': 'frequent', 'color': [53, 154, 170]}, {'name': 'light, light source', 'id': 1451, 'trainId': 82, 'image_count': 2949, 'frequency': 'frequent', 'color': [207, 32, 254]}, {'name': 'bench', 'id': 181, 'trainId': 83, 'image_count': 781, 'frequency': 'frequent', 'color': [234, 218, 233]}, {'name': 'case, display case, showcase, vitrine', 'id': 420, 'trainId': 84, 'image_count': 199, 'frequency': 'frequent', 'color': [78, 103, 138]}, {'name': 'towel', 'id': 2821, 'trainId': 85, 'image_count': 528, 'frequency': 'frequent', 'color': [212, 206, 168]}, {'name': 'fountain', 'id': 1023, 'trainId': 86, 'image_count': 143, 'frequency': 'frequent', 'color': [217, 91, 165]}, {'name': 'embankment', 'id': 855, 'trainId': 87, 'image_count': 152, 'frequency': 'frequent', 'color': [182, 229, 156]}, {'name': 'television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box', 'id': 2733, 'trainId': 88, 'image_count': 783, 'frequency': 'frequent', 'color': [187, 150, 131]}, {'name': 'van', 'id': 2928, 'trainId': 89, 'image_count': 602, 'frequency': 'frequent', 'color': [40, 151, 184]}, {'name': 'hill', 'id': 1240, 'trainId': 90, 'image_count': 309, 'frequency': 'frequent', 'color': [202, 103, 204]}, {'name': 'awning, sunshade, sunblind', 'id': 77, 'trainId': 91, 'image_count': 593, 'frequency': 'frequent', 'color': [39, 177, 249]}, {'name': 'poster, posting, placard, notice, bill, card', 'id': 1969, 'trainId': 92, 'image_count': 412, 'frequency': 'frequent', 'color': [157, 199, 162]}, {'name': 'truck, motortruck', 'id': 2880, 'trainId': 93, 'image_count': 483, 'frequency': 'frequent', 'color': [123, 115, 234]}, {'name': 'airplane, aeroplane, plane', 'id': 14, 'trainId': 94, 'image_count': 186, 'frequency': 'frequent', 'color': [154, 91, 230]}, {'name': 'pole', 'id': 1936, 'trainId': 95, 'image_count': 1021, 'frequency': 'frequent', 'color': [49, 165, 158]}, {'name': 'tower', 'id': 2828, 'trainId': 96, 'image_count': 156, 'frequency': 'frequent', 'color': [59, 227, 89]}, {'name': 'court', 'id': 631, 'trainId': 97, 'image_count': 56, 'frequency': 'common', 'color': [110, 106, 119]}, {'name': 'ball', 'id': 103, 'trainId': 98, 'image_count': 212, 'frequency': 'frequent', 'color': [30, 39, 151]}, {'name': 'aircraft carrier, carrier, flattop, attack aircraft carrier', 'id': 3144, 'trainId': 99, 'image_count': 43, 'frequency': 'common', 'color': [170, 83, 193]}, {'name': 'buffet, counter, sideboard', 'id': 308, 'trainId': 100, 'image_count': 139, 'frequency': 'frequent', 'color': [63, 233, 86]}, {'name': 'hovel, hut, hutch, shack, shanty', 'id': 1282, 'trainId': 101, 'image_count': 82, 'frequency': 'common', 'color': [66, 181, 80]}, {'name': 'apparel, wearing apparel, dress, clothes', 'id': 38, 'trainId': 102, 'image_count': 199, 'frequency': 'frequent', 'color': [101, 39, 76]}, {'name': 'minibike, motorbike', 'id': 1563, 'trainId': 103, 'image_count': 307, 'frequency': 'frequent', 'color': [117, 98, 128]}, {'name': 'animal, animate being, beast, brute, creature, fauna', 'id': 29, 'trainId': 104, 'image_count': 138, 'frequency': 'frequent', 'color': [121, 49, 133]}, {'name': 'chandelier, pendant, pendent', 'id': 480, 'trainId': 105, 'image_count': 668, 'frequency': 'frequent', 'color': [15, 191, 184]}, {'name': 'step, stair', 'id': 2569, 'trainId': 106, 'image_count': 360, 'frequency': 'frequent', 'color': [247, 77, 217]}, {'name': 'booth, cubicle, stall, kiosk', 'id': 247, 'trainId': 107, 'image_count': 71, 'frequency': 'common', 'color': [217, 190, 226]}, {'name': 'bicycle, bike, wheel, cycle', 'id': 187, 'trainId': 108, 'image_count': 397, 'frequency': 'frequent', 'color': [250, 68, 41]}, {'name': 'doorframe, doorcase', 'id': 778, 'trainId': 109, 'image_count': 269, 'frequency': 'frequent', 'color': [82, 106, 174]}, {'name': 'sconce', 'id': 2243, 'trainId': 110, 'image_count': 1151, 'frequency': 'frequent', 'color': [117, 71, 202]}, {'name': 'pond', 'id': 1941, 'trainId': 111, 'image_count': 134, 'frequency': 'frequent', 'color': [111, 167, 244]}, {'name': 'trade name, brand name, brand, marque', 'id': 2833, 'trainId': 112, 'image_count': 448, 'frequency': 'frequent', 'color': [195, 64, 196]}, {'name': 'bannister, banister, balustrade, balusters, handrail', 'id': 120, 'trainId': 113, 'image_count': 512, 'frequency': 'frequent', 'color': [190, 41, 31]}, {'name': 'bag', 'id': 95, 'trainId': 114, 'image_count': 652, 'frequency': 'frequent', 'color': [218, 252, 187]}, {'name': 'traffic light, traffic signal, stoplight', 'id': 2836, 'trainId': 115, 'image_count': 565, 'frequency': 'frequent', 'color': [68, 228, 71]}, {'name': 'gazebo', 'id': 1087, 'trainId': 116, 'image_count': 18, 'frequency': 'common', 'color': [145, 119, 126]}, {'name': 'escalator, moving staircase, moving stairway', 'id': 868, 'trainId': 117, 'image_count': 46, 'frequency': 'common', 'color': [125, 150, 168]}, {'name': 'land, ground, soil', 'id': 1401, 'trainId': 118, 'image_count': 262, 'frequency': 'frequent', 'color': [167, 24, 48]}, {'name': 'board, plank', 'id': 220, 'trainId': 119, 'image_count': 283, 'frequency': 'frequent', 'color': [199, 73, 113]}, {'name': 'arcade machine', 'id': 47, 'trainId': 120, 'image_count': 72, 'frequency': 'common', 'color': [132, 114, 81]}, {'name': 'eiderdown, duvet, continental quilt', 'id': 843, 'trainId': 121, 'image_count': 34, 'frequency': 'common', 'color': [169, 161, 105]}, {'name': 'bar', 'id': 123, 'trainId': 122, 'image_count': 206, 'frequency': 'frequent', 'color': [86, 209, 125]}, {'name': 'stall, stand, sales booth', 'id': 2537, 'trainId': 123, 'image_count': 72, 'frequency': 'common', 'color': [108, 182, 0]}, {'name': 'playground', 'id': 1927, 'trainId': 124, 'image_count': 31, 'frequency': 'common', 'color': [34, 31, 155]}, {'name': 'ship', 'id': 2337, 'trainId': 125, 'image_count': 60, 'frequency': 'common', 'color': [84, 113, 69]}, {'name': 'ottoman, pouf, pouffe, puff, hassock', 'id': 1702, 'trainId': 126, 'image_count': 388, 'frequency': 'frequent', 'color': [169, 207, 246]}, {'name': 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin', 'id': 64, 'trainId': 127, 'image_count': 768, 'frequency': 'frequent', 'color': [219, 79, 106]}, {'name': 'bottle', 'id': 249, 'trainId': 128, 'image_count': 1179, 'frequency': 'frequent', 'color': [180, 167, 243]}, {'name': 'cradle', 'id': 642, 'trainId': 129, 'image_count': 83, 'frequency': 'common', 'color': [168, 128, 192]}, {'name': 'pot, flowerpot', 'id': 1981, 'trainId': 130, 'image_count': 1525, 'frequency': 'frequent', 'color': [222, 91, 164]}, {'name': 'conveyer belt, conveyor belt, conveyer, conveyor, transporter', 'id': 609, 'trainId': 131, 'image_count': 60, 'frequency': 'common', 'color': [89, 234, 127]}, {'name': 'train, railroad train', 'id': 2840, 'trainId': 132, 'image_count': 81, 'frequency': 'common', 'color': [252, 22, 43]}, {'name': 'stool', 'id': 2586, 'trainId': 133, 'image_count': 555, 'frequency': 'frequent', 'color': [163, 10, 17]}, {'name': 'lake', 'id': 1393, 'trainId': 134, 'image_count': 66, 'frequency': 'common', 'color': [149, 32, 109]}, {'name': 'tank, storage tank', 'id': 2704, 'trainId': 135, 'image_count': 62, 'frequency': 'common', 'color': [98, 117, 164]}, {'name': 'ice, water ice', 'id': 1304, 'trainId': 136, 'image_count': 49, 'frequency': 'common', 'color': [57, 18, 23]}, {'name': 'basket, handbasket', 'id': 146, 'trainId': 137, 'image_count': 741, 'frequency': 'frequent', 'color': [39, 145, 197]}, {'name': 'manhole', 'id': 1494, 'trainId': 138, 'image_count': 259, 'frequency': 'frequent', 'color': [8, 52, 37]}, {'name': 'tent, collapsible shelter', 'id': 2739, 'trainId': 139, 'image_count': 85, 'frequency': 'common', 'color': [112, 103, 78]}, {'name': 'canopy', 'id': 389, 'trainId': 140, 'image_count': 67, 'frequency': 'common', 'color': [135, 162, 237]}, {'name': 'microwave, microwave oven', 'id': 1551, 'trainId': 141, 'image_count': 399, 'frequency': 'frequent', 'color': [181, 174, 228]}, {'name': 'barrel, cask', 'id': 131, 'trainId': 142, 'image_count': 58, 'frequency': 'common', 'color': [152, 178, 238]}, {'name': 'dirt track', 'id': 738, 'trainId': 143, 'image_count': 87, 'frequency': 'common', 'color': [247, 84, 123]}, {'name': 'beam', 'id': 161, 'trainId': 144, 'image_count': 165, 'frequency': 'frequent', 'color': [159, 65, 70]}, {'name': 'dishwasher, dish washer, dishwashing machine', 'id': 747, 'trainId': 145, 'image_count': 310, 'frequency': 'frequent', 'color': [97, 92, 40]}, {'name': 'plate', 'id': 1919, 'trainId': 146, 'image_count': 720, 'frequency': 'frequent', 'color': [105, 149, 44]}, {'name': 'screen, crt screen', 'id': 3109, 'trainId': 147, 'image_count': 263, 'frequency': 'frequent', 'color': [24, 188, 61]}, {'name': 'ruins', 'id': 2179, 'trainId': 148, 'image_count': 40, 'frequency': 'common', 'color': [84, 140, 165]}, {'name': 'washer, automatic washer, washing machine', 'id': 2989, 'trainId': 149, 'image_count': 88, 'frequency': 'common', 'color': [213, 84, 60]}, {'name': 'blanket, cover', 'id': 206, 'trainId': 150, 'image_count': 250, 'frequency': 'frequent', 'color': [63, 217, 207]}, {'name': 'plaything, toy', 'id': 1930, 'trainId': 151, 'image_count': 363, 'frequency': 'frequent', 'color': [218, 112, 74]}, {'name': 'food, solid food', 'id': 1002, 'trainId': 152, 'image_count': 143, 'frequency': 'frequent', 'color': [174, 0, 77]}, {'name': 'screen, silver screen, projection screen', 'id': 2254, 'trainId': 153, 'image_count': 109, 'frequency': 'frequent', 'color': [15, 68, 208]}, {'name': 'oven', 'id': 1708, 'trainId': 154, 'image_count': 262, 'frequency': 'frequent', 'color': [118, 59, 12]}, {'name': 'stage', 'id': 2526, 'trainId': 155, 'image_count': 113, 'frequency': 'frequent', 'color': [16, 40, 167]}, {'name': 'beacon, lighthouse, beacon light, pharos', 'id': 160, 'trainId': 156, 'image_count': 81, 'frequency': 'common', 'color': [185, 42, 197]}, {'name': 'umbrella', 'id': 2901, 'trainId': 157, 'image_count': 299, 'frequency': 'frequent', 'color': [72, 148, 13]}, {'name': 'sculpture', 'id': 2262, 'trainId': 158, 'image_count': 316, 'frequency': 'frequent', 'color': [196, 19, 235]}, {'name': 'aqueduct', 'id': 44, 'trainId': 159, 'image_count': 25, 'frequency': 'common', 'color': [134, 178, 137]}, {'name': 'container', 'id': 597, 'trainId': 160, 'image_count': 193, 'frequency': 'frequent', 'color': [217, 187, 40]}, {'name': 'scaffolding, staging', 'id': 2235, 'trainId': 161, 'image_count': 70, 'frequency': 'common', 'color': [204, 31, 146]}, {'name': 'hood, exhaust hood', 'id': 1260, 'trainId': 162, 'image_count': 329, 'frequency': 'frequent', 'color': [141, 67, 99]}, {'name': 'curb, curbing, kerb', 'id': 682, 'trainId': 163, 'image_count': 242, 'frequency': 'frequent', 'color': [219, 239, 118]}, {'name': 'roller coaster', 'id': 2151, 'trainId': 164, 'image_count': 15, 'frequency': 'common', 'color': [151, 236, 126]}, {'name': 'horse, equus caballus', 'id': 3107, 'trainId': 165, 'image_count': 82, 'frequency': 'common', 'color': [174, 8, 141]}, {'name': 'catwalk', 'id': 432, 'trainId': 166, 'image_count': 10, 'frequency': 'rare', 'color': [223, 2, 183]}, {'name': 'glass, drinking glass', 'id': 1098, 'trainId': 167, 'image_count': 650, 'frequency': 'frequent', 'color': [196, 143, 14]}, {'name': 'vase', 'id': 2932, 'trainId': 168, 'image_count': 1559, 'frequency': 'frequent', 'color': [48, 172, 76]}, {'name': 'central reservation', 'id': 461, 'trainId': 169, 'image_count': 216, 'frequency': 'frequent', 'color': [163, 153, 15]}, {'name': 'carousel', 'id': 410, 'trainId': 170, 'image_count': 22, 'frequency': 'common', 'color': [179, 15, 102]}, {'name': 'radiator', 'id': 2046, 'trainId': 171, 'image_count': 205, 'frequency': 'frequent', 'color': [244, 95, 38]}, {'name': 'closet', 'id': 533, 'trainId': 172, 'image_count': 53, 'frequency': 'common', 'color': [205, 83, 12]}, {'name': 'machine', 'id': 1481, 'trainId': 173, 'image_count': 71, 'frequency': 'common', 'color': [204, 7, 182]}, {'name': 'pier, wharf, wharfage, dock', 'id': 1858, 'trainId': 174, 'image_count': 108, 'frequency': 'frequent', 'color': [88, 120, 120]}, {'name': 'fan', 'id': 894, 'trainId': 175, 'image_count': 485, 'frequency': 'frequent', 'color': [182, 190, 55]}, {'name': 'inflatable bounce game', 'id': 1322, 'trainId': 176, 'image_count': 10, 'frequency': 'rare', 'color': [107, 52, 111]}, {'name': 'pitch', 'id': 1891, 'trainId': 177, 'image_count': 36, 'frequency': 'common', 'color': [109, 157, 80]}, {'name': 'paper', 'id': 1756, 'trainId': 178, 'image_count': 445, 'frequency': 'frequent', 'color': [122, 164, 139]}, {'name': 'arcade, colonnade', 'id': 49, 'trainId': 179, 'image_count': 55, 'frequency': 'common', 'color': [20, 26, 240]}, {'name': 'hot tub', 'id': 1272, 'trainId': 180, 'image_count': 61, 'frequency': 'common', 'color': [20, 164, 133]}, {'name': 'helicopter', 'id': 1229, 'trainId': 181, 'image_count': 32, 'frequency': 'common', 'color': [86, 254, 105]}, {'name': 'tray', 'id': 2850, 'trainId': 182, 'image_count': 514, 'frequency': 'frequent', 'color': [175, 114, 210]}, {'name': 'partition, divider', 'id': 1784, 'trainId': 183, 'image_count': 93, 'frequency': 'common', 'color': [237, 156, 183]}, {'name': 'vineyard', 'id': 2962, 'trainId': 184, 'image_count': 22, 'frequency': 'common', 'color': [132, 235, 45]}, {'name': 'bowl', 'id': 259, 'trainId': 185, 'image_count': 745, 'frequency': 'frequent', 'color': [106, 155, 174]}, {'name': 'bullring', 'id': 319, 'trainId': 186, 'image_count': 13, 'frequency': 'common', 'color': [64, 29, 180]}, {'name': 'flag', 'id': 954, 'trainId': 187, 'image_count': 500, 'frequency': 'frequent', 'color': [44, 107, 49]}, {'name': 'pot', 'id': 1974, 'trainId': 188, 'image_count': 626, 'frequency': 'frequent', 'color': [125, 130, 81]}, {'name': 'footbridge, overcrossing, pedestrian bridge', 'id': 1013, 'trainId': 189, 'image_count': 72, 'frequency': 'common', 'color': [13, 26, 15]}, {'name': 'shower', 'id': 2356, 'trainId': 190, 'image_count': 131, 'frequency': 'frequent', 'color': [234, 77, 118]}, {'name': 'bag, traveling bag, travelling bag, grip, suitcase', 'id': 97, 'trainId': 191, 'image_count': 141, 'frequency': 'frequent', 'color': [33, 127, 217]}, {'name': 'bulletin board, notice board', 'id': 318, 'trainId': 192, 'image_count': 232, 'frequency': 'frequent', 'color': [207, 197, 181]}, {'name': 'confessional booth', 'id': 592, 'trainId': 193, 'image_count': 9, 'frequency': 'rare', 'color': [185, 152, 157]}, {'name': 'trunk, tree trunk, bole', 'id': 2885, 'trainId': 194, 'image_count': 129, 'frequency': 'frequent', 'color': [55, 117, 68]}, {'name': 'forest', 'id': 1017, 'trainId': 195, 'image_count': 12, 'frequency': 'common', 'color': [8, 241, 184]}, {'name': 'elevator door', 'id': 851, 'trainId': 196, 'image_count': 23, 'frequency': 'common', 'color': [83, 88, 193]}, {'name': 'laptop, laptop computer', 'id': 1407, 'trainId': 197, 'image_count': 187, 'frequency': 'frequent', 'color': [236, 40, 196]}, {'name': 'instrument panel', 'id': 1332, 'trainId': 198, 'image_count': 20, 'frequency': 'common', 'color': [156, 111, 230]}, {'name': 'bucket, pail', 'id': 303, 'trainId': 199, 'image_count': 378, 'frequency': 'frequent', 'color': [5, 196, 192]}, {'name': 'tapestry, tapis', 'id': 2714, 'trainId': 200, 'image_count': 70, 'frequency': 'common', 'color': [63, 164, 205]}, {'name': 'platform', 'id': 1924, 'trainId': 201, 'image_count': 87, 'frequency': 'common', 'color': [167, 42, 80]}, {'name': 'jacket', 'id': 1346, 'trainId': 202, 'image_count': 126, 'frequency': 'frequent', 'color': [211, 30, 168]}, {'name': 'gate', 'id': 1081, 'trainId': 203, 'image_count': 209, 'frequency': 'frequent', 'color': [108, 205, 218]}, {'name': 'monitor, monitoring device', 'id': 1583, 'trainId': 204, 'image_count': 234, 'frequency': 'frequent', 'color': [148, 175, 229]}, {'name': 'telephone booth, phone booth, call box, telephone box, telephone kiosk', 'id': 2727, 'trainId': 205, 'image_count': 88, 'frequency': 'common', 'color': [150, 234, 208]}, {'name': 'spotlight, spot', 'id': 2509, 'trainId': 206, 'image_count': 842, 'frequency': 'frequent', 'color': [252, 207, 221]}, {'name': 'ring', 'id': 2123, 'trainId': 207, 'image_count': 22, 'frequency': 'common', 'color': [197, 86, 54]}, {'name': 'control panel', 'id': 602, 'trainId': 208, 'image_count': 7, 'frequency': 'rare', 'color': [110, 11, 149]}, {'name': 'blackboard, chalkboard', 'id': 202, 'trainId': 209, 'image_count': 61, 'frequency': 'common', 'color': [156, 127, 24]}, {'name': 'air conditioner, air conditioning', 'id': 10, 'trainId': 210, 'image_count': 351, 'frequency': 'frequent', 'color': [160, 43, 232]}, {'name': 'chest', 'id': 490, 'trainId': 211, 'image_count': 120, 'frequency': 'frequent', 'color': [61, 150, 25]}, {'name': 'clock', 'id': 530, 'trainId': 212, 'image_count': 916, 'frequency': 'frequent', 'color': [88, 2, 234]}, {'name': 'sand dune', 'id': 2213, 'trainId': 213, 'image_count': 15, 'frequency': 'common', 'color': [58, 36, 252]}, {'name': 'pipe, pipage, piping', 'id': 1884, 'trainId': 214, 'image_count': 160, 'frequency': 'frequent', 'color': [124, 141, 67]}, {'name': 'vault', 'id': 2934, 'trainId': 215, 'image_count': 44, 'frequency': 'common', 'color': [236, 221, 117]}, {'name': 'table football', 'id': 2687, 'trainId': 216, 'image_count': 42, 'frequency': 'common', 'color': [218, 141, 130]}, {'name': 'cannon', 'id': 387, 'trainId': 217, 'image_count': 25, 'frequency': 'common', 'color': [0, 192, 246]}, {'name': 'swimming pool, swimming bath, natatorium', 'id': 2668, 'trainId': 218, 'image_count': 75, 'frequency': 'common', 'color': [210, 122, 130]}, {'name': 'fluorescent, fluorescent fixture', 'id': 982, 'trainId': 219, 'image_count': 376, 'frequency': 'frequent', 'color': [177, 241, 217]}, {'name': 'statue', 'id': 2547, 'trainId': 220, 'image_count': 140, 'frequency': 'frequent', 'color': [103, 240, 107]}, {'name': 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system', 'id': 1474, 'trainId': 221, 'image_count': 224, 'frequency': 'frequent', 'color': [2, 191, 37]}, {'name': 'exhibitor', 'id': 877, 'trainId': 222, 'image_count': 58, 'frequency': 'common', 'color': [113, 43, 97]}, {'name': 'ladder', 'id': 1391, 'trainId': 223, 'image_count': 139, 'frequency': 'frequent', 'color': [60, 133, 3]}, {'name': 'carport', 'id': 414, 'trainId': 224, 'image_count': 23, 'frequency': 'common', 'color': [225, 70, 87]}, {'name': 'dam', 'id': 698, 'trainId': 225, 'image_count': 24, 'frequency': 'common', 'color': [171, 183, 18]}, {'name': 'pulpit', 'id': 2019, 'trainId': 226, 'image_count': 27, 'frequency': 'common', 'color': [151, 168, 59]}, {'name': 'skylight, fanlight', 'id': 2422, 'trainId': 227, 'image_count': 81, 'frequency': 'common', 'color': [187, 75, 51]}, {'name': 'water tower', 'id': 3010, 'trainId': 228, 'image_count': 29, 'frequency': 'common', 'color': [21, 168, 51]}, {'name': 'grill, grille, grillwork', 'id': 1139, 'trainId': 229, 'image_count': 215, 'frequency': 'frequent', 'color': [191, 153, 92]}, {'name': 'display board', 'id': 753, 'trainId': 230, 'image_count': 31, 'frequency': 'common', 'color': [186, 90, 81]}, {'name': 'pane, pane of glass, window glass', 'id': 1747, 'trainId': 231, 'image_count': 151, 'frequency': 'frequent', 'color': [139, 184, 165]}, {'name': 'rubbish, trash, scrap', 'id': 2175, 'trainId': 232, 'image_count': 35, 'frequency': 'common', 'color': [78, 91, 185]}, {'name': 'ice rink', 'id': 1301, 'trainId': 233, 'image_count': 28, 'frequency': 'common', 'color': [22, 204, 179]}, {'name': 'fruit', 'id': 1033, 'trainId': 234, 'image_count': 226, 'frequency': 'frequent', 'color': [57, 226, 16]}, {'name': 'patio', 'id': 1789, 'trainId': 235, 'image_count': 15, 'frequency': 'common', 'color': [201, 196, 116]}, {'name': 'vending machine', 'id': 2939, 'trainId': 236, 'image_count': 79, 'frequency': 'common', 'color': [215, 14, 177]}, {'name': 'telephone, phone, telephone set', 'id': 2730, 'trainId': 237, 'image_count': 530, 'frequency': 'frequent', 'color': [236, 107, 64]}, {'name': 'net', 'id': 1652, 'trainId': 238, 'image_count': 75, 'frequency': 'common', 'color': [175, 202, 87]}, {'name': 'backpack, back pack, knapsack, packsack, rucksack, haversack', 'id': 90, 'trainId': 239, 'image_count': 228, 'frequency': 'frequent', 'color': [44, 160, 175]}, {'name': 'jar', 'id': 1349, 'trainId': 240, 'image_count': 333, 'frequency': 'frequent', 'color': [30, 253, 211]}, {'name': 'track', 'id': 2830, 'trainId': 241, 'image_count': 26, 'frequency': 'common', 'color': [68, 233, 54]}, {'name': 'magazine', 'id': 1485, 'trainId': 242, 'image_count': 251, 'frequency': 'frequent', 'color': [144, 151, 30]}, {'name': 'shutter', 'id': 2370, 'trainId': 243, 'image_count': 71, 'frequency': 'common', 'color': [223, 140, 21]}, {'name': 'roof', 'id': 2155, 'trainId': 244, 'image_count': 34, 'frequency': 'common', 'color': [245, 80, 190]}, {'name': 'banner, streamer', 'id': 118, 'trainId': 245, 'image_count': 100, 'frequency': 'common', 'color': [239, 196, 204]}, {'name': 'landfill', 'id': 1402, 'trainId': 246, 'image_count': 12, 'frequency': 'common', 'color': [231, 35, 242]}, {'name': 'post', 'id': 1957, 'trainId': 247, 'image_count': 164, 'frequency': 'frequent', 'color': [77, 242, 167]}, {'name': 'altarpiece, reredos', 'id': 3130, 'trainId': 248, 'image_count': 19, 'frequency': 'common', 'color': [247, 57, 233]}, {'name': 'hat, chapeau, lid', 'id': 1197, 'trainId': 249, 'image_count': 100, 'frequency': 'common', 'color': [211, 118, 1]}, {'name': 'arch, archway', 'id': 52, 'trainId': 250, 'image_count': 46, 'frequency': 'common', 'color': [18, 22, 164]}, {'name': 'table game', 'id': 2688, 'trainId': 251, 'image_count': 20, 'frequency': 'common', 'color': [73, 172, 236]}, {'name': 'bag, handbag, pocketbook, purse', 'id': 96, 'trainId': 252, 'image_count': 233, 'frequency': 'frequent', 'color': [169, 216, 158]}, {'name': 'document, written document, papers', 'id': 762, 'trainId': 253, 'image_count': 148, 'frequency': 'frequent', 'color': [160, 2, 251]}, {'name': 'dome', 'id': 772, 'trainId': 254, 'image_count': 27, 'frequency': 'common', 'color': [106, 69, 63]}, {'name': 'pier', 'id': 1857, 'trainId': 255, 'image_count': 37, 'frequency': 'common', 'color': [52, 144, 79]}, {'name': 'shanties', 'id': 2315, 'trainId': 256, 'image_count': 14, 'frequency': 'common', 'color': [212, 79, 29]}, {'name': 'forecourt', 'id': 1016, 'trainId': 257, 'image_count': 49, 'frequency': 'common', 'color': [85, 28, 150]}, {'name': 'crane', 'id': 643, 'trainId': 258, 'image_count': 118, 'frequency': 'frequent', 'color': [65, 134, 162]}, {'name': 'dog, domestic dog, canis familiaris', 'id': 3105, 'trainId': 259, 'image_count': 129, 'frequency': 'frequent', 'color': [160, 92, 184]}, {'name': 'piano, pianoforte, forte-piano', 'id': 1849, 'trainId': 260, 'image_count': 57, 'frequency': 'common', 'color': [96, 63, 55]}, {'name': 'drawing', 'id': 791, 'trainId': 261, 'image_count': 56, 'frequency': 'common', 'color': [175, 89, 162]}, {'name': 'cabin', 'id': 349, 'trainId': 262, 'image_count': 17, 'frequency': 'common', 'color': [43, 215, 31]}, {'name': 'ad, advertisement, advertizement, advertising, advertizing, advert', 'id': 6, 'trainId': 263, 'image_count': 33, 'frequency': 'common', 'color': [142, 132, 117]}, {'name': 'amphitheater, amphitheatre, coliseum', 'id': 3114, 'trainId': 264, 'image_count': 2, 'frequency': 'rare', 'color': [119, 215, 72]}, {'name': 'monument', 'id': 1587, 'trainId': 265, 'image_count': 31, 'frequency': 'common', 'color': [2, 65, 141]}, {'name': 'henhouse', 'id': 1233, 'trainId': 266, 'image_count': 13, 'frequency': 'common', 'color': [80, 119, 33]}, {'name': 'cockpit', 'id': 559, 'trainId': 267, 'image_count': 2, 'frequency': 'rare', 'color': [55, 198, 24]}, {'name': 'heater, warmer', 'id': 1223, 'trainId': 268, 'image_count': 143, 'frequency': 'frequent', 'color': [109, 234, 38]}, {'name': 'windmill, aerogenerator, wind generator', 'id': 3049, 'trainId': 269, 'image_count': 61, 'frequency': 'common', 'color': [104, 102, 209]}, {'name': 'pool', 'id': 1943, 'trainId': 270, 'image_count': 27, 'frequency': 'common', 'color': [76, 211, 236]}, {'name': 'elevator, lift', 'id': 853, 'trainId': 271, 'image_count': 51, 'frequency': 'common', 'color': [217, 253, 196]}, {'name': 'decoration, ornament, ornamentation', 'id': 709, 'trainId': 272, 'image_count': 88, 'frequency': 'common', 'color': [200, 7, 109]}, {'name': 'labyrinth', 'id': 1390, 'trainId': 273, 'image_count': 10, 'frequency': 'rare', 'color': [150, 74, 95]}, {'name': 'text, textual matter', 'id': 2748, 'trainId': 274, 'image_count': 241, 'frequency': 'frequent', 'color': [149, 155, 153]}, {'name': 'printer', 'id': 2007, 'trainId': 275, 'image_count': 95, 'frequency': 'common', 'color': [170, 177, 5]}, {'name': 'mezzanine, first balcony', 'id': 1546, 'trainId': 276, 'image_count': 37, 'frequency': 'common', 'color': [189, 222, 188]}, {'name': 'mattress', 'id': 1513, 'trainId': 277, 'image_count': 26, 'frequency': 'common', 'color': [175, 178, 80]}, {'name': 'straw', 'id': 2600, 'trainId': 278, 'image_count': 24, 'frequency': 'common', 'color': [220, 86, 190]}, {'name': 'stalls', 'id': 2538, 'trainId': 279, 'image_count': 11, 'frequency': 'common', 'color': [244, 71, 219]}, {'name': 'patio, terrace', 'id': 1790, 'trainId': 280, 'image_count': 12, 'frequency': 'common', 'color': [159, 204, 177]}, {'name': 'billboard, hoarding', 'id': 194, 'trainId': 281, 'image_count': 58, 'frequency': 'common', 'color': [54, 116, 78]}, {'name': 'bus stop', 'id': 326, 'trainId': 282, 'image_count': 49, 'frequency': 'common', 'color': [129, 46, 26]}, {'name': 'trouser, pant', 'id': 2877, 'trainId': 283, 'image_count': 61, 'frequency': 'common', 'color': [179, 210, 139]}, {'name': 'console table, console', 'id': 594, 'trainId': 284, 'image_count': 122, 'frequency': 'frequent', 'color': [121, 134, 92]}, {'name': 'rack', 'id': 2036, 'trainId': 285, 'image_count': 114, 'frequency': 'frequent', 'color': [82, 147, 184]}, {'name': 'notebook', 'id': 1662, 'trainId': 286, 'image_count': 200, 'frequency': 'frequent', 'color': [108, 177, 193]}, {'name': 'shrine', 'id': 2366, 'trainId': 287, 'image_count': 3, 'frequency': 'rare', 'color': [8, 81, 163]}, {'name': 'pantry', 'id': 1754, 'trainId': 288, 'image_count': 18, 'frequency': 'common', 'color': [94, 49, 17]}, {'name': 'cart', 'id': 418, 'trainId': 289, 'image_count': 59, 'frequency': 'common', 'color': [184, 127, 248]}, {'name': 'steam shovel', 'id': 2553, 'trainId': 290, 'image_count': 13, 'frequency': 'common', 'color': [229, 128, 239]}, {'name': 'porch', 'id': 1951, 'trainId': 291, 'image_count': 19, 'frequency': 'common', 'color': [85, 229, 45]}, {'name': 'postbox, mailbox, letter box', 'id': 1963, 'trainId': 292, 'image_count': 97, 'frequency': 'common', 'color': [183, 217, 200]}, {'name': 'figurine, statuette', 'id': 918, 'trainId': 293, 'image_count': 380, 'frequency': 'frequent', 'color': [147, 228, 184]}, {'name': 'recycling bin', 'id': 2086, 'trainId': 294, 'image_count': 12, 'frequency': 'common', 'color': [194, 236, 92]}, {'name': 'folding screen', 'id': 997, 'trainId': 295, 'image_count': 29, 'frequency': 'common', 'color': [55, 59, 182]}, {'name': 'telescope', 'id': 2731, 'trainId': 296, 'image_count': 14, 'frequency': 'common', 'color': [186, 159, 179]}, {'name': 'deck chair, beach chair', 'id': 704, 'trainId': 297, 'image_count': 106, 'frequency': 'frequent', 'color': [136, 22, 26]}, {'name': 'kennel', 'id': 1365, 'trainId': 298, 'image_count': 4, 'frequency': 'rare', 'color': [11, 9, 155]}, {'name': 'coffee maker', 'id': 569, 'trainId': 299, 'image_count': 244, 'frequency': 'frequent', 'color': [12, 208, 106]}, {'name': "altar, communion table, lord's table", 'id': 3108, 'trainId': 300, 'image_count': 63, 'frequency': 'common', 'color': [185, 93, 48]}, {'name': 'fish', 'id': 948, 'trainId': 301, 'image_count': 45, 'frequency': 'common', 'color': [16, 47, 229]}, {'name': 'easel', 'id': 839, 'trainId': 302, 'image_count': 78, 'frequency': 'common', 'color': [198, 168, 150]}, {'name': 'artificial golf green', 'id': 63, 'trainId': 303, 'image_count': 8, 'frequency': 'rare', 'color': [115, 245, 187]}, {'name': 'iceberg', 'id': 1305, 'trainId': 304, 'image_count': 27, 'frequency': 'common', 'color': [183, 193, 227]}, {'name': 'candlestick, candle holder', 'id': 378, 'trainId': 305, 'image_count': 475, 'frequency': 'frequent', 'color': [81, 143, 140]}, {'name': 'shower stall, shower bath', 'id': 2362, 'trainId': 306, 'image_count': 60, 'frequency': 'common', 'color': [26, 82, 227]}, {'name': 'television stand', 'id': 2734, 'trainId': 307, 'image_count': 34, 'frequency': 'common', 'color': [86, 234, 191]}, {'name': 'wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle', 'id': 2982, 'trainId': 308, 'image_count': 1336, 'frequency': 'frequent', 'color': [1, 169, 248]}, {'name': 'skeleton', 'id': 2398, 'trainId': 309, 'image_count': 11, 'frequency': 'common', 'color': [184, 246, 229]}, {'name': 'grand piano, grand', 'id': 1119, 'trainId': 310, 'image_count': 54, 'frequency': 'common', 'color': [221, 96, 245]}, {'name': 'candy, confect', 'id': 382, 'trainId': 311, 'image_count': 8, 'frequency': 'rare', 'color': [145, 109, 20]}, {'name': 'grille door', 'id': 1141, 'trainId': 312, 'image_count': 11, 'frequency': 'common', 'color': [85, 155, 64]}, {'name': 'pedestal, plinth, footstall', 'id': 1805, 'trainId': 313, 'image_count': 65, 'frequency': 'common', 'color': [232, 127, 202]}, {'name': 'jersey, t-shirt, tee shirt', 'id': 3102, 'trainId': 314, 'image_count': 58, 'frequency': 'common', 'color': [113, 215, 101]}, {'name': 'shoe', 'id': 2341, 'trainId': 315, 'image_count': 138, 'frequency': 'frequent', 'color': [161, 56, 87]}, {'name': 'gravestone, headstone, tombstone', 'id': 1131, 'trainId': 316, 'image_count': 27, 'frequency': 'common', 'color': [186, 175, 163]}, {'name': 'shanty', 'id': 2316, 'trainId': 317, 'image_count': 7, 'frequency': 'rare', 'color': [191, 219, 240]}, {'name': 'structure', 'id': 2626, 'trainId': 318, 'image_count': 14, 'frequency': 'common', 'color': [169, 99, 255]}, {'name': 'rocking chair, rocker', 'id': 3104, 'trainId': 319, 'image_count': 56, 'frequency': 'common', 'color': [252, 169, 32]}, {'name': 'bird', 'id': 198, 'trainId': 320, 'image_count': 107, 'frequency': 'frequent', 'color': [180, 18, 58]}, {'name': 'place mat', 'id': 1896, 'trainId': 321, 'image_count': 86, 'frequency': 'common', 'color': [120, 2, 95]}, {'name': 'tomb', 'id': 2800, 'trainId': 322, 'image_count': 29, 'frequency': 'common', 'color': [61, 40, 178]}, {'name': 'big top', 'id': 190, 'trainId': 323, 'image_count': 11, 'frequency': 'common', 'color': [230, 61, 167]}, {'name': 'gas pump, gasoline pump, petrol pump, island dispenser', 'id': 3131, 'trainId': 324, 'image_count': 31, 'frequency': 'common', 'color': [195, 219, 34]}, {'name': 'lockers', 'id': 1463, 'trainId': 325, 'image_count': 19, 'frequency': 'common', 'color': [67, 49, 142]}, {'name': 'cage', 'id': 357, 'trainId': 326, 'image_count': 29, 'frequency': 'common', 'color': [111, 142, 74]}, {'name': 'finger', 'id': 929, 'trainId': 327, 'image_count': 26, 'frequency': 'common', 'color': [112, 69, 16]}, {'name': 'bleachers', 'id': 209, 'trainId': 328, 'image_count': 4, 'frequency': 'rare', 'color': [186, 254, 171]}, {'name': 'ferris wheel', 'id': 912, 'trainId': 329, 'image_count': 10, 'frequency': 'rare', 'color': [170, 172, 210]}, {'name': 'hairdresser chair', 'id': 1164, 'trainId': 330, 'image_count': 6, 'frequency': 'rare', 'color': [15, 219, 52]}, {'name': 'mat', 'id': 1509, 'trainId': 331, 'image_count': 19, 'frequency': 'common', 'color': [153, 3, 172]}, {'name': 'stands', 'id': 2539, 'trainId': 332, 'image_count': 20, 'frequency': 'common', 'color': [145, 83, 144]}, {'name': 'aquarium, fish tank, marine museum', 'id': 3116, 'trainId': 333, 'image_count': 17, 'frequency': 'common', 'color': [222, 189, 175]}, {'name': 'streetcar, tram, tramcar, trolley, trolley car', 'id': 2615, 'trainId': 334, 'image_count': 68, 'frequency': 'common', 'color': [4, 225, 95]}, {'name': 'napkin, table napkin, serviette', 'id': 1644, 'trainId': 335, 'image_count': 191, 'frequency': 'frequent', 'color': [223, 240, 162]}, {'name': 'dummy', 'id': 818, 'trainId': 336, 'image_count': 43, 'frequency': 'common', 'color': [94, 9, 167]}, {'name': 'booklet, brochure, folder, leaflet, pamphlet', 'id': 242, 'trainId': 337, 'image_count': 140, 'frequency': 'frequent', 'color': [47, 211, 236]}, {'name': 'sand trap', 'id': 2217, 'trainId': 338, 'image_count': 18, 'frequency': 'common', 'color': [40, 166, 180]}, {'name': 'shop, store', 'id': 2347, 'trainId': 339, 'image_count': 20, 'frequency': 'common', 'color': [71, 50, 252]}, {'name': 'table cloth', 'id': 2686, 'trainId': 340, 'image_count': 10, 'frequency': 'rare', 'color': [160, 16, 17]}, {'name': 'service station', 'id': 2300, 'trainId': 341, 'image_count': 15, 'frequency': 'common', 'color': [198, 56, 20]}, {'name': 'coffin', 'id': 572, 'trainId': 342, 'image_count': 6, 'frequency': 'rare', 'color': [121, 88, 173]}, {'name': 'drawer', 'id': 789, 'trainId': 343, 'image_count': 88, 'frequency': 'common', 'color': [205, 133, 151]}, {'name': 'cages', 'id': 358, 'trainId': 344, 'image_count': 8, 'frequency': 'rare', 'color': [16, 94, 163]}, {'name': 'slot machine, coin machine', 'id': 2443, 'trainId': 345, 'image_count': 36, 'frequency': 'common', 'color': [242, 164, 197]}, {'name': 'balcony', 'id': 101, 'trainId': 346, 'image_count': 19, 'frequency': 'common', 'color': [34, 121, 179]}, {'name': 'volleyball court', 'id': 2969, 'trainId': 347, 'image_count': 3, 'frequency': 'rare', 'color': [30, 211, 92]}, {'name': 'table tennis', 'id': 2692, 'trainId': 348, 'image_count': 29, 'frequency': 'common', 'color': [173, 77, 189]}, {'name': 'control table', 'id': 606, 'trainId': 349, 'image_count': 5, 'frequency': 'rare', 'color': [77, 72, 146]}, {'name': 'shirt', 'id': 2339, 'trainId': 350, 'image_count': 71, 'frequency': 'common', 'color': [157, 83, 15]}, {'name': 'merchandise, ware, product', 'id': 1533, 'trainId': 351, 'image_count': 33, 'frequency': 'common', 'color': [33, 161, 100]}, {'name': 'railway', 'id': 2060, 'trainId': 352, 'image_count': 30, 'frequency': 'common', 'color': [58, 158, 174]}, {'name': 'parterre', 'id': 1782, 'trainId': 353, 'image_count': 17, 'frequency': 'common', 'color': [90, 30, 161]}, {'name': 'chimney', 'id': 495, 'trainId': 354, 'image_count': 48, 'frequency': 'common', 'color': [120, 85, 103]}, {'name': 'can, tin, tin can', 'id': 371, 'trainId': 355, 'image_count': 223, 'frequency': 'frequent', 'color': [246, 123, 74]}, {'name': 'tanks', 'id': 2707, 'trainId': 356, 'image_count': 5, 'frequency': 'rare', 'color': [228, 160, 154]}, {'name': 'fabric, cloth, material, textile', 'id': 889, 'trainId': 357, 'image_count': 47, 'frequency': 'common', 'color': [120, 5, 113]}, {'name': 'alga, algae', 'id': 3156, 'trainId': 358, 'image_count': 6, 'frequency': 'rare', 'color': [58, 182, 1]}, {'name': 'system', 'id': 2683, 'trainId': 359, 'image_count': 137, 'frequency': 'frequent', 'color': [55, 6, 70]}, {'name': 'map', 'id': 1499, 'trainId': 360, 'image_count': 37, 'frequency': 'common', 'color': [156, 0, 84]}, {'name': 'greenhouse', 'id': 1135, 'trainId': 361, 'image_count': 18, 'frequency': 'common', 'color': [8, 146, 167]}, {'name': 'mug', 'id': 1619, 'trainId': 362, 'image_count': 210, 'frequency': 'frequent', 'color': [244, 116, 80]}, {'name': 'barbecue', 'id': 125, 'trainId': 363, 'image_count': 19, 'frequency': 'common', 'color': [57, 20, 103]}, {'name': 'trailer', 'id': 2838, 'trainId': 364, 'image_count': 31, 'frequency': 'common', 'color': [121, 27, 22]}, {'name': 'toilet tissue, toilet paper, bathroom tissue', 'id': 2792, 'trainId': 365, 'image_count': 150, 'frequency': 'frequent', 'color': [165, 83, 26]}, {'name': 'organ', 'id': 1695, 'trainId': 366, 'image_count': 17, 'frequency': 'common', 'color': [128, 180, 107]}, {'name': 'dishrag, dishcloth', 'id': 746, 'trainId': 367, 'image_count': 128, 'frequency': 'frequent', 'color': [40, 149, 7]}, {'name': 'island', 'id': 1343, 'trainId': 368, 'image_count': 39, 'frequency': 'common', 'color': [42, 57, 147]}, {'name': 'keyboard', 'id': 1370, 'trainId': 369, 'image_count': 90, 'frequency': 'common', 'color': [93, 189, 185]}, {'name': 'trench', 'id': 2858, 'trainId': 370, 'image_count': 16, 'frequency': 'common', 'color': [40, 29, 22]}, {'name': 'basket, basketball hoop, hoop', 'id': 145, 'trainId': 371, 'image_count': 39, 'frequency': 'common', 'color': [146, 85, 2]}, {'name': 'steering wheel, wheel', 'id': 2565, 'trainId': 372, 'image_count': 21, 'frequency': 'common', 'color': [103, 197, 242]}, {'name': 'pitcher, ewer', 'id': 1892, 'trainId': 373, 'image_count': 202, 'frequency': 'frequent', 'color': [142, 192, 168]}, {'name': 'goal', 'id': 1103, 'trainId': 374, 'image_count': 20, 'frequency': 'common', 'color': [121, 45, 61]}, {'name': 'bread, breadstuff, staff of life', 'id': 286, 'trainId': 375, 'image_count': 56, 'frequency': 'common', 'color': [188, 127, 45]}, {'name': 'beds', 'id': 170, 'trainId': 376, 'image_count': 1, 'frequency': 'rare', 'color': [45, 212, 25]}, {'name': 'wood', 'id': 3073, 'trainId': 377, 'image_count': 26, 'frequency': 'common', 'color': [41, 211, 240]}, {'name': 'file cabinet', 'id': 922, 'trainId': 378, 'image_count': 16, 'frequency': 'common', 'color': [201, 160, 85]}, {'name': 'newspaper, paper', 'id': 1655, 'trainId': 379, 'image_count': 48, 'frequency': 'common', 'color': [102, 55, 91]}, {'name': 'motorboat', 'id': 1602, 'trainId': 380, 'image_count': 1, 'frequency': 'rare', 'color': [206, 30, 80]}, {'name': 'rope', 'id': 2160, 'trainId': 381, 'image_count': 40, 'frequency': 'common', 'color': [195, 241, 38]}, {'name': 'guitar', 'id': 1151, 'trainId': 382, 'image_count': 46, 'frequency': 'common', 'color': [163, 11, 16]}, {'name': 'rubble', 'id': 2176, 'trainId': 383, 'image_count': 10, 'frequency': 'rare', 'color': [203, 161, 87]}, {'name': 'scarf', 'id': 2239, 'trainId': 384, 'image_count': 18, 'frequency': 'common', 'color': [85, 16, 71]}, {'name': 'barrels', 'id': 132, 'trainId': 385, 'image_count': 17, 'frequency': 'common', 'color': [99, 158, 194]}, {'name': 'cap', 'id': 394, 'trainId': 386, 'image_count': 44, 'frequency': 'common', 'color': [142, 197, 4]}, {'name': 'leaves', 'id': 1424, 'trainId': 387, 'image_count': 19, 'frequency': 'common', 'color': [96, 202, 4]}, {'name': 'control tower', 'id': 607, 'trainId': 388, 'image_count': 23, 'frequency': 'common', 'color': [5, 130, 182]}, {'name': 'dashboard', 'id': 700, 'trainId': 389, 'image_count': 13, 'frequency': 'common', 'color': [40, 191, 28]}, {'name': 'bandstand', 'id': 116, 'trainId': 390, 'image_count': 16, 'frequency': 'common', 'color': [159, 182, 62]}, {'name': 'lectern', 'id': 1425, 'trainId': 391, 'image_count': 47, 'frequency': 'common', 'color': [4, 61, 28]}, {'name': 'switch, electric switch, electrical switch', 'id': 2676, 'trainId': 392, 'image_count': 676, 'frequency': 'frequent', 'color': [213, 208, 212]}, {'name': 'baseboard, mopboard, skirting board', 'id': 141, 'trainId': 393, 'image_count': 38, 'frequency': 'common', 'color': [247, 208, 239]}, {'name': 'shower room', 'id': 2360, 'trainId': 394, 'image_count': 8, 'frequency': 'rare', 'color': [99, 99, 127]}, {'name': 'smoke', 'id': 2449, 'trainId': 395, 'image_count': 24, 'frequency': 'common', 'color': [212, 137, 222]}, {'name': 'faucet, spigot', 'id': 897, 'trainId': 396, 'image_count': 343, 'frequency': 'frequent', 'color': [224, 229, 154]}, {'name': 'bulldozer', 'id': 317, 'trainId': 397, 'image_count': 15, 'frequency': 'common', 'color': [16, 116, 241]}, {'name': 'saucepan', 'id': 2228, 'trainId': 398, 'image_count': 64, 'frequency': 'common', 'color': [188, 235, 182]}, {'name': 'shops', 'id': 2351, 'trainId': 399, 'image_count': 11, 'frequency': 'common', 'color': [20, 213, 105]}, {'name': 'meter', 'id': 1543, 'trainId': 400, 'image_count': 10, 'frequency': 'rare', 'color': [169, 131, 181]}, {'name': 'crevasse', 'id': 656, 'trainId': 401, 'image_count': 5, 'frequency': 'rare', 'color': [183, 149, 170]}, {'name': 'gear', 'id': 1088, 'trainId': 402, 'image_count': 3, 'frequency': 'rare', 'color': [45, 203, 214]}, {'name': 'candelabrum, candelabra', 'id': 373, 'trainId': 403, 'image_count': 84, 'frequency': 'common', 'color': [231, 62, 168]}, {'name': 'sofa bed', 'id': 2472, 'trainId': 404, 'image_count': 8, 'frequency': 'rare', 'color': [44, 49, 156]}, {'name': 'tunnel', 'id': 2892, 'trainId': 405, 'image_count': 31, 'frequency': 'common', 'color': [26, 220, 213]}, {'name': 'pallet', 'id': 1740, 'trainId': 406, 'image_count': 32, 'frequency': 'common', 'color': [163, 224, 44]}, {'name': 'wire, conducting wire', 'id': 3067, 'trainId': 407, 'image_count': 57, 'frequency': 'common', 'color': [246, 54, 198]}, {'name': 'kettle, boiler', 'id': 1367, 'trainId': 408, 'image_count': 153, 'frequency': 'frequent', 'color': [131, 16, 103]}, {'name': 'bidet', 'id': 188, 'trainId': 409, 'image_count': 33, 'frequency': 'common', 'color': [82, 169, 253]}, {'name': 'baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher', 'id': 79, 'trainId': 410, 'image_count': 56, 'frequency': 'common', 'color': [113, 226, 1]}, {'name': 'music stand', 'id': 1633, 'trainId': 411, 'image_count': 10, 'frequency': 'rare', 'color': [36, 39, 126]}, {'name': 'pipe, tube', 'id': 1885, 'trainId': 412, 'image_count': 46, 'frequency': 'common', 'color': [103, 113, 200]}, {'name': 'cup', 'id': 677, 'trainId': 413, 'image_count': 200, 'frequency': 'frequent', 'color': [72, 98, 86]}, {'name': 'parking meter', 'id': 1779, 'trainId': 414, 'image_count': 153, 'frequency': 'frequent', 'color': [163, 185, 8]}, {'name': 'ice hockey rink', 'id': 1297, 'trainId': 415, 'image_count': 5, 'frequency': 'rare', 'color': [97, 2, 254]}, {'name': 'shelter', 'id': 2334, 'trainId': 416, 'image_count': 22, 'frequency': 'common', 'color': [244, 47, 192]}, {'name': 'weeds', 'id': 3027, 'trainId': 417, 'image_count': 6, 'frequency': 'rare', 'color': [46, 81, 19]}, {'name': 'temple', 'id': 2735, 'trainId': 418, 'image_count': 11, 'frequency': 'common', 'color': [100, 83, 16]}, {'name': 'patty, cake', 'id': 1791, 'trainId': 419, 'image_count': 67, 'frequency': 'common', 'color': [54, 73, 213]}, {'name': 'ski slope', 'id': 2405, 'trainId': 420, 'image_count': 4, 'frequency': 'rare', 'color': [88, 88, 239]}, {'name': 'panel', 'id': 1748, 'trainId': 421, 'image_count': 6, 'frequency': 'rare', 'color': [148, 205, 76]}, {'name': 'wallet', 'id': 2983, 'trainId': 422, 'image_count': 13, 'frequency': 'common', 'color': [24, 6, 50]}, {'name': 'wheel', 'id': 3035, 'trainId': 423, 'image_count': 59, 'frequency': 'common', 'color': [131, 63, 231]}, {'name': 'towel rack, towel horse', 'id': 2824, 'trainId': 424, 'image_count': 91, 'frequency': 'common', 'color': [148, 178, 77]}, {'name': 'roundabout', 'id': 2168, 'trainId': 425, 'image_count': 22, 'frequency': 'common', 'color': [188, 98, 246]}, {'name': 'canister, cannister, tin', 'id': 385, 'trainId': 426, 'image_count': 94, 'frequency': 'common', 'color': [42, 7, 111]}, {'name': 'rod', 'id': 2148, 'trainId': 427, 'image_count': 129, 'frequency': 'frequent', 'color': [176, 64, 131]}, {'name': 'soap dispenser', 'id': 2465, 'trainId': 428, 'image_count': 157, 'frequency': 'frequent', 'color': [223, 45, 213]}, {'name': 'bell', 'id': 175, 'trainId': 429, 'image_count': 39, 'frequency': 'common', 'color': [167, 193, 99]}, {'name': 'canvas', 'id': 390, 'trainId': 430, 'image_count': 9, 'frequency': 'rare', 'color': [119, 83, 69]}, {'name': 'box office, ticket office, ticket booth', 'id': 268, 'trainId': 431, 'image_count': 10, 'frequency': 'rare', 'color': [237, 49, 62]}, {'name': 'teacup', 'id': 2722, 'trainId': 432, 'image_count': 74, 'frequency': 'common', 'color': [125, 134, 90]}, {'name': 'trellis', 'id': 2857, 'trainId': 433, 'image_count': 14, 'frequency': 'common', 'color': [228, 72, 109]}, {'name': 'workbench', 'id': 3088, 'trainId': 434, 'image_count': 23, 'frequency': 'common', 'color': [24, 122, 20]}, {'name': 'valley, vale', 'id': 2926, 'trainId': 435, 'image_count': 32, 'frequency': 'common', 'color': [200, 238, 248]}, {'name': 'toaster', 'id': 2782, 'trainId': 436, 'image_count': 102, 'frequency': 'frequent', 'color': [195, 37, 196]}, {'name': 'knife', 'id': 1378, 'trainId': 437, 'image_count': 204, 'frequency': 'frequent', 'color': [229, 43, 34]}, {'name': 'podium', 'id': 1934, 'trainId': 438, 'image_count': 25, 'frequency': 'common', 'color': [171, 29, 179]}, {'name': 'ramp', 'id': 2072, 'trainId': 439, 'image_count': 24, 'frequency': 'common', 'color': [246, 185, 120]}, {'name': 'tumble dryer', 'id': 2889, 'trainId': 440, 'image_count': 22, 'frequency': 'common', 'color': [189, 255, 93]}, {'name': 'fireplug, fire hydrant, plug', 'id': 944, 'trainId': 441, 'image_count': 139, 'frequency': 'frequent', 'color': [89, 9, 211]}, {'name': 'gym shoe, sneaker, tennis shoe', 'id': 1158, 'trainId': 442, 'image_count': 42, 'frequency': 'common', 'color': [163, 107, 33]}, {'name': 'lab bench', 'id': 1383, 'trainId': 443, 'image_count': 4, 'frequency': 'rare', 'color': [156, 92, 225]}, {'name': 'equipment', 'id': 867, 'trainId': 444, 'image_count': 32, 'frequency': 'common', 'color': [46, 81, 244]}, {'name': 'rocky formation', 'id': 2145, 'trainId': 445, 'image_count': 7, 'frequency': 'rare', 'color': [41, 152, 39]}, {'name': 'plastic', 'id': 1915, 'trainId': 446, 'image_count': 8, 'frequency': 'rare', 'color': [144, 91, 194]}, {'name': 'calendar', 'id': 361, 'trainId': 447, 'image_count': 39, 'frequency': 'common', 'color': [0, 248, 191]}, {'name': 'caravan', 'id': 402, 'trainId': 448, 'image_count': 15, 'frequency': 'common', 'color': [77, 44, 143]}, {'name': 'check-in-desk', 'id': 482, 'trainId': 449, 'image_count': 11, 'frequency': 'common', 'color': [78, 70, 92]}, {'name': 'ticket counter', 'id': 2761, 'trainId': 450, 'image_count': 11, 'frequency': 'common', 'color': [177, 121, 166]}, {'name': 'brush', 'id': 300, 'trainId': 451, 'image_count': 39, 'frequency': 'common', 'color': [163, 1, 92]}, {'name': 'mill', 'id': 1554, 'trainId': 452, 'image_count': 18, 'frequency': 'common', 'color': [14, 151, 255]}, {'name': 'covered bridge', 'id': 636, 'trainId': 453, 'image_count': 9, 'frequency': 'rare', 'color': [184, 126, 141]}, {'name': 'bowling alley', 'id': 260, 'trainId': 454, 'image_count': 9, 'frequency': 'rare', 'color': [23, 240, 96]}, {'name': 'hanger', 'id': 1186, 'trainId': 455, 'image_count': 64, 'frequency': 'common', 'color': [110, 28, 145]}, {'name': 'excavator', 'id': 871, 'trainId': 456, 'image_count': 17, 'frequency': 'common', 'color': [142, 199, 139]}, {'name': 'trestle', 'id': 2859, 'trainId': 457, 'image_count': 14, 'frequency': 'common', 'color': [156, 130, 51]}, {'name': 'revolving door', 'id': 2103, 'trainId': 458, 'image_count': 6, 'frequency': 'rare', 'color': [236, 132, 230]}, {'name': 'blast furnace', 'id': 208, 'trainId': 459, 'image_count': 5, 'frequency': 'rare', 'color': [126, 197, 152]}, {'name': 'scale, weighing machine', 'id': 2236, 'trainId': 460, 'image_count': 43, 'frequency': 'common', 'color': [41, 175, 46]}, {'name': 'projector', 'id': 2012, 'trainId': 461, 'image_count': 63, 'frequency': 'common', 'color': [152, 4, 150]}, {'name': 'soap', 'id': 2462, 'trainId': 462, 'image_count': 93, 'frequency': 'common', 'color': [47, 53, 5]}, {'name': 'locker', 'id': 1462, 'trainId': 463, 'image_count': 7, 'frequency': 'rare', 'color': [243, 42, 27]}, {'name': 'tractor', 'id': 2832, 'trainId': 464, 'image_count': 22, 'frequency': 'common', 'color': [200, 174, 80]}, {'name': 'stretcher', 'id': 2617, 'trainId': 465, 'image_count': 30, 'frequency': 'common', 'color': [231, 139, 65]}, {'name': 'frame', 'id': 1024, 'trainId': 466, 'image_count': 18, 'frequency': 'common', 'color': [238, 61, 167]}, {'name': 'grating', 'id': 1129, 'trainId': 467, 'image_count': 18, 'frequency': 'common', 'color': [239, 161, 88]}, {'name': 'alembic', 'id': 18, 'trainId': 468, 'image_count': 1, 'frequency': 'rare', 'color': [44, 107, 6]}, {'name': 'candle, taper, wax light', 'id': 376, 'trainId': 469, 'image_count': 171, 'frequency': 'frequent', 'color': [130, 2, 15]}, {'name': 'barrier', 'id': 134, 'trainId': 470, 'image_count': 29, 'frequency': 'common', 'color': [110, 40, 6]}, {'name': 'cardboard', 'id': 407, 'trainId': 471, 'image_count': 5, 'frequency': 'rare', 'color': [242, 4, 202]}, {'name': 'cave', 'id': 434, 'trainId': 472, 'image_count': 10, 'frequency': 'rare', 'color': [195, 177, 147]}, {'name': 'puddle', 'id': 2017, 'trainId': 473, 'image_count': 16, 'frequency': 'common', 'color': [108, 139, 19]}, {'name': 'tarp', 'id': 2717, 'trainId': 474, 'image_count': 3, 'frequency': 'rare', 'color': [15, 63, 100]}, {'name': 'price tag', 'id': 2005, 'trainId': 475, 'image_count': 23, 'frequency': 'common', 'color': [14, 201, 113]}, {'name': 'watchtower', 'id': 2993, 'trainId': 476, 'image_count': 13, 'frequency': 'common', 'color': [87, 179, 143]}, {'name': 'meters', 'id': 1545, 'trainId': 477, 'image_count': 1, 'frequency': 'rare', 'color': [156, 114, 0]}, {'name': 'light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb', 'id': 1445, 'trainId': 478, 'image_count': 54, 'frequency': 'common', 'color': [255, 139, 79]}, {'name': 'tracks', 'id': 2831, 'trainId': 479, 'image_count': 15, 'frequency': 'common', 'color': [201, 202, 95]}, {'name': 'hair dryer', 'id': 1161, 'trainId': 480, 'image_count': 22, 'frequency': 'common', 'color': [143, 252, 38]}, {'name': 'skirt', 'id': 2411, 'trainId': 481, 'image_count': 18, 'frequency': 'common', 'color': [13, 122, 175]}, {'name': 'viaduct', 'id': 2949, 'trainId': 482, 'image_count': 10, 'frequency': 'rare', 'color': [46, 5, 82]}, {'name': 'paper towel', 'id': 1769, 'trainId': 483, 'image_count': 94, 'frequency': 'common', 'color': [130, 64, 10]}, {'name': 'coat', 'id': 552, 'trainId': 484, 'image_count': 27, 'frequency': 'common', 'color': [179, 213, 158]}, {'name': 'sheet', 'id': 2327, 'trainId': 485, 'image_count': 8, 'frequency': 'rare', 'color': [194, 134, 111]}, {'name': 'fire extinguisher, extinguisher, asphyxiator', 'id': 939, 'trainId': 486, 'image_count': 65, 'frequency': 'common', 'color': [217, 51, 206]}, {'name': 'water wheel', 'id': 3013, 'trainId': 487, 'image_count': 13, 'frequency': 'common', 'color': [104, 54, 62]}, {'name': 'pottery, clayware', 'id': 1986, 'trainId': 488, 'image_count': 33, 'frequency': 'common', 'color': [1, 241, 108]}, {'name': 'magazine rack', 'id': 1486, 'trainId': 489, 'image_count': 33, 'frequency': 'common', 'color': [198, 157, 204]}, {'name': 'teapot', 'id': 2723, 'trainId': 490, 'image_count': 128, 'frequency': 'frequent', 'color': [132, 208, 95]}, {'name': 'microphone, mike', 'id': 1549, 'trainId': 491, 'image_count': 80, 'frequency': 'common', 'color': [71, 205, 70]}, {'name': 'support', 'id': 2649, 'trainId': 492, 'image_count': 24, 'frequency': 'common', 'color': [142, 176, 90]}, {'name': 'forklift', 'id': 1020, 'trainId': 493, 'image_count': 17, 'frequency': 'common', 'color': [220, 2, 198]}, {'name': 'canyon', 'id': 392, 'trainId': 494, 'image_count': 4, 'frequency': 'rare', 'color': [48, 142, 164]}, {'name': 'cash register, register', 'id': 422, 'trainId': 495, 'image_count': 47, 'frequency': 'common', 'color': [29, 10, 20]}, {'name': 'leaf, leafage, foliage', 'id': 1419, 'trainId': 496, 'image_count': 26, 'frequency': 'common', 'color': [28, 217, 73]}, {'name': 'remote control, remote', 'id': 2099, 'trainId': 497, 'image_count': 172, 'frequency': 'frequent', 'color': [124, 238, 45]}, {'name': 'soap dish', 'id': 2464, 'trainId': 498, 'image_count': 79, 'frequency': 'common', 'color': [224, 15, 253]}, {'name': 'windshield, windscreen', 'id': 3058, 'trainId': 499, 'image_count': 20, 'frequency': 'common', 'color': [92, 128, 39]}, {'name': 'cat', 'id': 430, 'trainId': 500, 'image_count': 18, 'frequency': 'common', 'color': [96, 251, 57]}, {'name': 'cue, cue stick, pool cue, pool stick', 'id': 675, 'trainId': 501, 'image_count': 82, 'frequency': 'common', 'color': [1, 102, 13]}, {'name': 'vent, venthole, vent-hole, blowhole', 'id': 2941, 'trainId': 502, 'image_count': 84, 'frequency': 'common', 'color': [59, 96, 185]}, {'name': 'videos', 'id': 2955, 'trainId': 503, 'image_count': 17, 'frequency': 'common', 'color': [155, 244, 34]}, {'name': 'shovel', 'id': 2355, 'trainId': 504, 'image_count': 25, 'frequency': 'common', 'color': [203, 11, 83]}, {'name': 'eaves', 'id': 840, 'trainId': 505, 'image_count': 8, 'frequency': 'rare', 'color': [248, 147, 136]}, {'name': 'antenna, aerial, transmitting aerial', 'id': 32, 'trainId': 506, 'image_count': 138, 'frequency': 'frequent', 'color': [210, 193, 62]}, {'name': 'shipyard', 'id': 2338, 'trainId': 507, 'image_count': 9, 'frequency': 'rare', 'color': [130, 225, 144]}, {'name': 'hen, biddy', 'id': 1232, 'trainId': 508, 'image_count': 14, 'frequency': 'common', 'color': [231, 12, 154]}, {'name': 'traffic cone', 'id': 2834, 'trainId': 509, 'image_count': 73, 'frequency': 'common', 'color': [131, 210, 149]}, {'name': 'washing machines', 'id': 2991, 'trainId': 510, 'image_count': 13, 'frequency': 'common', 'color': [141, 189, 207]}, {'name': 'truck crane', 'id': 2879, 'trainId': 511, 'image_count': 2, 'frequency': 'rare', 'color': [83, 30, 145]}, {'name': 'cds', 'id': 444, 'trainId': 512, 'image_count': 10, 'frequency': 'rare', 'color': [146, 139, 161]}, {'name': 'niche', 'id': 1657, 'trainId': 513, 'image_count': 11, 'frequency': 'common', 'color': [183, 178, 191]}, {'name': 'scoreboard', 'id': 2246, 'trainId': 514, 'image_count': 25, 'frequency': 'common', 'color': [68, 74, 242]}, {'name': 'briefcase', 'id': 296, 'trainId': 515, 'image_count': 57, 'frequency': 'common', 'color': [139, 254, 118]}, {'name': 'boot', 'id': 245, 'trainId': 516, 'image_count': 42, 'frequency': 'common', 'color': [41, 116, 238]}, {'name': 'sweater, jumper', 'id': 2661, 'trainId': 517, 'image_count': 32, 'frequency': 'common', 'color': [86, 66, 142]}, {'name': 'hay', 'id': 1202, 'trainId': 518, 'image_count': 13, 'frequency': 'common', 'color': [78, 18, 194]}, {'name': 'pack', 'id': 1714, 'trainId': 519, 'image_count': 37, 'frequency': 'common', 'color': [121, 151, 88]}, {'name': 'bottle rack', 'id': 251, 'trainId': 520, 'image_count': 23, 'frequency': 'common', 'color': [228, 195, 157]}, {'name': 'glacier', 'id': 1095, 'trainId': 521, 'image_count': 5, 'frequency': 'rare', 'color': [186, 145, 119]}, {'name': 'pergola', 'id': 1828, 'trainId': 522, 'image_count': 5, 'frequency': 'rare', 'color': [170, 142, 141]}, {'name': 'building materials', 'id': 311, 'trainId': 523, 'image_count': 3, 'frequency': 'rare', 'color': [107, 224, 124]}, {'name': 'television camera', 'id': 2732, 'trainId': 524, 'image_count': 23, 'frequency': 'common', 'color': [241, 11, 83]}, {'name': 'first floor', 'id': 947, 'trainId': 525, 'image_count': 5, 'frequency': 'rare', 'color': [20, 216, 156]}, {'name': 'rifle', 'id': 2115, 'trainId': 526, 'image_count': 19, 'frequency': 'common', 'color': [122, 113, 7]}, {'name': 'tennis table', 'id': 2738, 'trainId': 527, 'image_count': 2, 'frequency': 'rare', 'color': [54, 174, 156]}, {'name': 'stadium', 'id': 2525, 'trainId': 528, 'image_count': 1, 'frequency': 'rare', 'color': [109, 64, 140]}, {'name': 'safety belt', 'id': 2194, 'trainId': 529, 'image_count': 9, 'frequency': 'rare', 'color': [208, 225, 214]}, {'name': 'cover', 'id': 634, 'trainId': 530, 'image_count': 8, 'frequency': 'rare', 'color': [232, 99, 85]}, {'name': 'dish rack', 'id': 740, 'trainId': 531, 'image_count': 38, 'frequency': 'common', 'color': [10, 28, 44]}, {'name': 'synthesizer', 'id': 2682, 'trainId': 532, 'image_count': 18, 'frequency': 'common', 'color': [39, 24, 252]}, {'name': 'pumpkin', 'id': 2020, 'trainId': 533, 'image_count': 10, 'frequency': 'rare', 'color': [171, 157, 65]}, {'name': 'gutter', 'id': 1156, 'trainId': 534, 'image_count': 11, 'frequency': 'common', 'color': [22, 216, 162]}, {'name': 'fruit stand', 'id': 1036, 'trainId': 535, 'image_count': 4, 'frequency': 'rare', 'color': [109, 98, 8]}, {'name': 'ice floe, floe', 'id': 1295, 'trainId': 536, 'image_count': 6, 'frequency': 'rare', 'color': [151, 74, 223]}, {'name': 'handle, grip, handgrip, hold', 'id': 1181, 'trainId': 537, 'image_count': 31, 'frequency': 'common', 'color': [87, 132, 14]}, {'name': 'wheelchair', 'id': 3037, 'trainId': 538, 'image_count': 19, 'frequency': 'common', 'color': [214, 146, 70]}, {'name': 'mousepad, mouse mat', 'id': 1614, 'trainId': 539, 'image_count': 51, 'frequency': 'common', 'color': [41, 235, 96]}, {'name': 'diploma', 'id': 736, 'trainId': 540, 'image_count': 15, 'frequency': 'common', 'color': [69, 216, 19]}, {'name': 'fairground ride', 'id': 893, 'trainId': 541, 'image_count': 2, 'frequency': 'rare', 'color': [53, 105, 119]}, {'name': 'radio', 'id': 2047, 'trainId': 542, 'image_count': 34, 'frequency': 'common', 'color': [36, 61, 149]}, {'name': 'hotplate', 'id': 1274, 'trainId': 543, 'image_count': 43, 'frequency': 'common', 'color': [43, 82, 188]}, {'name': 'junk', 'id': 1361, 'trainId': 544, 'image_count': 3, 'frequency': 'rare', 'color': [98, 45, 54]}, {'name': 'wheelbarrow', 'id': 3036, 'trainId': 545, 'image_count': 20, 'frequency': 'common', 'color': [242, 92, 88]}, {'name': 'stream', 'id': 2606, 'trainId': 546, 'image_count': 1, 'frequency': 'rare', 'color': [189, 135, 0]}, {'name': 'toll plaza', 'id': 2797, 'trainId': 547, 'image_count': 5, 'frequency': 'rare', 'color': [7, 118, 123]}, {'name': 'punching bag', 'id': 2022, 'trainId': 548, 'image_count': 6, 'frequency': 'rare', 'color': [117, 220, 152]}, {'name': 'trough', 'id': 2876, 'trainId': 549, 'image_count': 3, 'frequency': 'rare', 'color': [202, 26, 119]}, {'name': 'throne', 'id': 2758, 'trainId': 550, 'image_count': 8, 'frequency': 'rare', 'color': [105, 232, 22]}, {'name': 'chair desk', 'id': 472, 'trainId': 551, 'image_count': 2, 'frequency': 'rare', 'color': [184, 150, 34]}, {'name': 'weighbridge', 'id': 3028, 'trainId': 552, 'image_count': 18, 'frequency': 'common', 'color': [167, 124, 74]}, {'name': 'extractor fan', 'id': 882, 'trainId': 553, 'image_count': 14, 'frequency': 'common', 'color': [73, 224, 77]}, {'name': 'hanging clothes', 'id': 1189, 'trainId': 554, 'image_count': 17, 'frequency': 'common', 'color': [74, 23, 28]}, {'name': 'dish, dish aerial, dish antenna, saucer', 'id': 743, 'trainId': 555, 'image_count': 55, 'frequency': 'common', 'color': [83, 220, 153]}, {'name': 'alarm clock, alarm', 'id': 3122, 'trainId': 556, 'image_count': 21, 'frequency': 'common', 'color': [1, 69, 156]}, {'name': 'ski lift', 'id': 2401, 'trainId': 557, 'image_count': 9, 'frequency': 'rare', 'color': [146, 235, 227]}, {'name': 'chain', 'id': 468, 'trainId': 558, 'image_count': 23, 'frequency': 'common', 'color': [137, 38, 121]}, {'name': 'garage', 'id': 1061, 'trainId': 559, 'image_count': 4, 'frequency': 'rare', 'color': [212, 191, 9]}, {'name': 'mechanical shovel', 'id': 1523, 'trainId': 560, 'image_count': 10, 'frequency': 'rare', 'color': [94, 96, 228]}, {'name': 'wine rack', 'id': 3059, 'trainId': 561, 'image_count': 5, 'frequency': 'rare', 'color': [99, 158, 107]}, {'name': 'tramway', 'id': 2843, 'trainId': 562, 'image_count': 2, 'frequency': 'rare', 'color': [189, 35, 76]}, {'name': 'treadmill', 'id': 2853, 'trainId': 563, 'image_count': 7, 'frequency': 'rare', 'color': [98, 103, 105]}, {'name': 'menu', 'id': 1529, 'trainId': 564, 'image_count': 16, 'frequency': 'common', 'color': [251, 193, 35]}, {'name': 'block', 'id': 214, 'trainId': 565, 'image_count': 8, 'frequency': 'rare', 'color': [213, 139, 148]}, {'name': 'well', 'id': 3032, 'trainId': 566, 'image_count': 7, 'frequency': 'rare', 'color': [20, 66, 103]}, {'name': 'witness stand', 'id': 3071, 'trainId': 567, 'image_count': 9, 'frequency': 'rare', 'color': [138, 2, 163]}, {'name': 'branch', 'id': 277, 'trainId': 568, 'image_count': 18, 'frequency': 'common', 'color': [101, 253, 86]}, {'name': 'duck', 'id': 813, 'trainId': 569, 'image_count': 24, 'frequency': 'common', 'color': [42, 175, 150]}, {'name': 'casserole', 'id': 426, 'trainId': 570, 'image_count': 51, 'frequency': 'common', 'color': [45, 82, 106]}, {'name': 'frying pan', 'id': 1039, 'trainId': 571, 'image_count': 39, 'frequency': 'common', 'color': [255, 12, 114]}, {'name': 'desk organizer', 'id': 727, 'trainId': 572, 'image_count': 37, 'frequency': 'common', 'color': [224, 214, 89]}, {'name': 'mast', 'id': 1508, 'trainId': 573, 'image_count': 38, 'frequency': 'common', 'color': [93, 74, 229]}, {'name': 'spectacles, specs, eyeglasses, glasses', 'id': 2490, 'trainId': 574, 'image_count': 76, 'frequency': 'common', 'color': [174, 125, 127]}, {'name': 'service elevator', 'id': 2299, 'trainId': 575, 'image_count': 2, 'frequency': 'rare', 'color': [217, 146, 25]}, {'name': 'dollhouse', 'id': 768, 'trainId': 576, 'image_count': 2, 'frequency': 'rare', 'color': [88, 212, 203]}, {'name': 'hammock', 'id': 1172, 'trainId': 577, 'image_count': 2, 'frequency': 'rare', 'color': [14, 242, 13]}, {'name': 'clothes hanging', 'id': 537, 'trainId': 578, 'image_count': 18, 'frequency': 'common', 'color': [252, 75, 60]}, {'name': 'photocopier', 'id': 1847, 'trainId': 579, 'image_count': 2, 'frequency': 'rare', 'color': [134, 109, 238]}, {'name': 'notepad', 'id': 1664, 'trainId': 580, 'image_count': 6, 'frequency': 'rare', 'color': [109, 112, 222]}, {'name': 'golf cart', 'id': 1110, 'trainId': 581, 'image_count': 8, 'frequency': 'rare', 'color': [220, 29, 100]}, {'name': 'footpath', 'id': 1014, 'trainId': 582, 'image_count': 7, 'frequency': 'rare', 'color': [128, 155, 186]}, {'name': 'cross', 'id': 662, 'trainId': 583, 'image_count': 74, 'frequency': 'common', 'color': [216, 191, 50]}, {'name': 'baptismal font', 'id': 121, 'trainId': 584, 'image_count': 9, 'frequency': 'rare', 'color': [102, 195, 175]}, {'name': 'boiler', 'id': 227, 'trainId': 585, 'image_count': 4, 'frequency': 'rare', 'color': [118, 34, 230]}, {'name': 'skip', 'id': 2410, 'trainId': 586, 'image_count': 6, 'frequency': 'rare', 'color': [5, 78, 174]}, {'name': 'rotisserie', 'id': 2165, 'trainId': 587, 'image_count': 4, 'frequency': 'rare', 'color': [61, 96, 183]}, {'name': 'tables', 'id': 2696, 'trainId': 588, 'image_count': 9, 'frequency': 'rare', 'color': [183, 207, 246]}, {'name': 'water mill', 'id': 3005, 'trainId': 589, 'image_count': 5, 'frequency': 'rare', 'color': [4, 178, 108]}, {'name': 'helmet', 'id': 1231, 'trainId': 590, 'image_count': 23, 'frequency': 'common', 'color': [245, 105, 99]}, {'name': 'cover curtain', 'id': 635, 'trainId': 591, 'image_count': 7, 'frequency': 'rare', 'color': [95, 195, 4]}, {'name': 'brick', 'id': 292, 'trainId': 592, 'image_count': 10, 'frequency': 'rare', 'color': [237, 115, 27]}, {'name': 'table runner', 'id': 2690, 'trainId': 593, 'image_count': 13, 'frequency': 'common', 'color': [101, 22, 196]}, {'name': 'ashtray', 'id': 65, 'trainId': 594, 'image_count': 83, 'frequency': 'common', 'color': [137, 85, 162]}, {'name': 'street box', 'id': 2607, 'trainId': 595, 'image_count': 6, 'frequency': 'rare', 'color': [206, 159, 52]}, {'name': 'stick', 'id': 2574, 'trainId': 596, 'image_count': 68, 'frequency': 'common', 'color': [32, 113, 65]}, {'name': 'hangers', 'id': 1188, 'trainId': 597, 'image_count': 27, 'frequency': 'common', 'color': [139, 47, 43]}, {'name': 'cells', 'id': 456, 'trainId': 598, 'image_count': 5, 'frequency': 'rare', 'color': [143, 145, 121]}, {'name': 'urinal', 'id': 2913, 'trainId': 599, 'image_count': 4, 'frequency': 'rare', 'color': [157, 205, 223]}, {'name': 'centerpiece', 'id': 459, 'trainId': 600, 'image_count': 11, 'frequency': 'common', 'color': [238, 8, 243]}, {'name': 'portable fridge', 'id': 1955, 'trainId': 601, 'image_count': 7, 'frequency': 'rare', 'color': [232, 189, 85]}, {'name': 'dvds', 'id': 827, 'trainId': 602, 'image_count': 12, 'frequency': 'common', 'color': [68, 207, 17]}, {'name': 'golf club', 'id': 1111, 'trainId': 603, 'image_count': 16, 'frequency': 'common', 'color': [222, 60, 33]}, {'name': 'skirting board', 'id': 2412, 'trainId': 604, 'image_count': 6, 'frequency': 'rare', 'color': [233, 21, 224]}, {'name': 'water cooler', 'id': 2997, 'trainId': 605, 'image_count': 7, 'frequency': 'rare', 'color': [91, 121, 49]}, {'name': 'clipboard', 'id': 528, 'trainId': 606, 'image_count': 5, 'frequency': 'rare', 'color': [43, 149, 89]}, {'name': 'camera, photographic camera', 'id': 366, 'trainId': 607, 'image_count': 53, 'frequency': 'common', 'color': [102, 161, 207]}, {'name': 'pigeonhole', 'id': 1863, 'trainId': 608, 'image_count': 9, 'frequency': 'rare', 'color': [254, 169, 23]}, {'name': 'chips', 'id': 500, 'trainId': 609, 'image_count': 8, 'frequency': 'rare', 'color': [70, 166, 192]}, {'name': 'food processor', 'id': 1001, 'trainId': 610, 'image_count': 36, 'frequency': 'common', 'color': [229, 135, 244]}, {'name': 'post box', 'id': 1958, 'trainId': 611, 'image_count': 1, 'frequency': 'rare', 'color': [128, 80, 239]}, {'name': 'lid', 'id': 1441, 'trainId': 612, 'image_count': 5, 'frequency': 'rare', 'color': [183, 215, 107]}, {'name': 'drum', 'id': 809, 'trainId': 613, 'image_count': 6, 'frequency': 'rare', 'color': [119, 52, 34]}, {'name': 'blender', 'id': 210, 'trainId': 614, 'image_count': 30, 'frequency': 'common', 'color': [185, 229, 222]}, {'name': 'cave entrance', 'id': 435, 'trainId': 615, 'image_count': 6, 'frequency': 'rare', 'color': [164, 165, 83]}, {'name': 'dental chair', 'id': 718, 'trainId': 616, 'image_count': 13, 'frequency': 'common', 'color': [19, 206, 233]}, {'name': 'obelisk', 'id': 1674, 'trainId': 617, 'image_count': 8, 'frequency': 'rare', 'color': [9, 14, 4]}, {'name': 'canoe', 'id': 388, 'trainId': 618, 'image_count': 3, 'frequency': 'rare', 'color': [96, 230, 47]}, {'name': 'mobile', 'id': 1572, 'trainId': 619, 'image_count': 17, 'frequency': 'common', 'color': [172, 87, 93]}, {'name': 'monitors', 'id': 1584, 'trainId': 620, 'image_count': 6, 'frequency': 'rare', 'color': [94, 118, 241]}, {'name': 'pool ball', 'id': 1944, 'trainId': 621, 'image_count': 67, 'frequency': 'common', 'color': [213, 101, 124]}, {'name': 'cue rack', 'id': 674, 'trainId': 622, 'image_count': 7, 'frequency': 'rare', 'color': [235, 233, 2]}, {'name': 'baggage carts', 'id': 99, 'trainId': 623, 'image_count': 3, 'frequency': 'rare', 'color': [0, 248, 204]}, {'name': 'shore', 'id': 2352, 'trainId': 624, 'image_count': 9, 'frequency': 'rare', 'color': [198, 59, 98]}, {'name': 'fork', 'id': 1019, 'trainId': 625, 'image_count': 81, 'frequency': 'common', 'color': [221, 12, 231]}, {'name': 'paper filer', 'id': 1763, 'trainId': 626, 'image_count': 13, 'frequency': 'common', 'color': [77, 60, 231]}, {'name': 'bicycle rack', 'id': 185, 'trainId': 627, 'image_count': 7, 'frequency': 'rare', 'color': [241, 169, 191]}, {'name': 'coat rack', 'id': 554, 'trainId': 628, 'image_count': 12, 'frequency': 'common', 'color': [74, 104, 82]}, {'name': 'garland', 'id': 1066, 'trainId': 629, 'image_count': 7, 'frequency': 'rare', 'color': [162, 191, 133]}, {'name': 'sports bag', 'id': 2508, 'trainId': 630, 'image_count': 2, 'frequency': 'rare', 'color': [127, 194, 71]}, {'name': 'fish tank', 'id': 951, 'trainId': 631, 'image_count': 4, 'frequency': 'rare', 'color': [133, 107, 215]}, {'name': 'towel dispenser', 'id': 2822, 'trainId': 632, 'image_count': 3, 'frequency': 'rare', 'color': [148, 149, 76]}, {'name': 'carriage', 'id': 415, 'trainId': 633, 'image_count': 5, 'frequency': 'rare', 'color': [254, 255, 132]}, {'name': 'brochure', 'id': 297, 'trainId': 634, 'image_count': 17, 'frequency': 'common', 'color': [137, 189, 164]}, {'name': 'plaque', 'id': 1914, 'trainId': 635, 'image_count': 28, 'frequency': 'common', 'color': [160, 48, 176]}, {'name': 'stringer', 'id': 2619, 'trainId': 636, 'image_count': 5, 'frequency': 'rare', 'color': [102, 62, 132]}, {'name': 'iron', 'id': 1338, 'trainId': 637, 'image_count': 11, 'frequency': 'common', 'color': [203, 218, 35]}, {'name': 'spoon', 'id': 2505, 'trainId': 638, 'image_count': 72, 'frequency': 'common', 'color': [6, 66, 139]}, {'name': 'flag pole', 'id': 955, 'trainId': 639, 'image_count': 6, 'frequency': 'rare', 'color': [99, 206, 20]}, {'name': 'toilet brush', 'id': 2786, 'trainId': 640, 'image_count': 27, 'frequency': 'common', 'color': [84, 237, 222]}, {'name': 'book stand', 'id': 238, 'trainId': 641, 'image_count': 8, 'frequency': 'rare', 'color': [105, 27, 52]}, {'name': 'water faucet, water tap, tap, hydrant', 'id': 3000, 'trainId': 642, 'image_count': 26, 'frequency': 'common', 'color': [137, 147, 101]}, {'name': 'ticket office', 'id': 2763, 'trainId': 643, 'image_count': 6, 'frequency': 'rare', 'color': [144, 169, 198]}, {'name': 'broom', 'id': 299, 'trainId': 644, 'image_count': 24, 'frequency': 'common', 'color': [10, 235, 222]}, {'name': 'dvd', 'id': 822, 'trainId': 645, 'image_count': 26, 'frequency': 'common', 'color': [99, 80, 149]}, {'name': 'ice bucket', 'id': 1288, 'trainId': 646, 'image_count': 17, 'frequency': 'common', 'color': [11, 12, 234]}, {'name': 'carapace, shell, cuticle, shield', 'id': 3101, 'trainId': 647, 'image_count': 36, 'frequency': 'common', 'color': [20, 218, 184]}, {'name': 'tureen', 'id': 2894, 'trainId': 648, 'image_count': 28, 'frequency': 'common', 'color': [133, 231, 95]}, {'name': 'folders', 'id': 992, 'trainId': 649, 'image_count': 22, 'frequency': 'common', 'color': [218, 208, 49]}, {'name': 'chess', 'id': 489, 'trainId': 650, 'image_count': 21, 'frequency': 'common', 'color': [118, 157, 214]}, {'name': 'root', 'id': 2157, 'trainId': 651, 'image_count': 1, 'frequency': 'rare', 'color': [153, 73, 24]}, {'name': 'sewing machine', 'id': 2309, 'trainId': 652, 'image_count': 10, 'frequency': 'rare', 'color': [96, 209, 176]}, {'name': 'model', 'id': 1576, 'trainId': 653, 'image_count': 3, 'frequency': 'rare', 'color': [205, 69, 149]}, {'name': 'pen', 'id': 1810, 'trainId': 654, 'image_count': 120, 'frequency': 'frequent', 'color': [66, 184, 43]}, {'name': 'violin', 'id': 2964, 'trainId': 655, 'image_count': 6, 'frequency': 'rare', 'color': [178, 16, 19]}, {'name': 'sweatshirt', 'id': 2662, 'trainId': 656, 'image_count': 5, 'frequency': 'rare', 'color': [20, 13, 108]}, {'name': 'recycling materials', 'id': 2087, 'trainId': 657, 'image_count': 1, 'frequency': 'rare', 'color': [254, 150, 24]}, {'name': 'mitten', 'id': 1569, 'trainId': 658, 'image_count': 14, 'frequency': 'common', 'color': [173, 127, 222]}, {'name': 'chopping board, cutting board', 'id': 503, 'trainId': 659, 'image_count': 30, 'frequency': 'common', 'color': [218, 231, 219]}, {'name': 'mask', 'id': 1505, 'trainId': 660, 'image_count': 18, 'frequency': 'common', 'color': [205, 74, 166]}, {'name': 'log', 'id': 1468, 'trainId': 661, 'image_count': 9, 'frequency': 'rare', 'color': [221, 163, 53]}, {'name': 'mouse, computer mouse', 'id': 1613, 'trainId': 662, 'image_count': 54, 'frequency': 'common', 'color': [50, 246, 247]}, {'name': 'grill', 'id': 1138, 'trainId': 663, 'image_count': 7, 'frequency': 'rare', 'color': [10, 190, 40]}, {'name': 'hole', 'id': 1256, 'trainId': 664, 'image_count': 10, 'frequency': 'rare', 'color': [13, 230, 254]}, {'name': 'target', 'id': 2715, 'trainId': 665, 'image_count': 7, 'frequency': 'rare', 'color': [60, 149, 223]}, {'name': 'trash bag', 'id': 2846, 'trainId': 666, 'image_count': 10, 'frequency': 'rare', 'color': [48, 49, 242]}, {'name': 'chalk', 'id': 477, 'trainId': 667, 'image_count': 7, 'frequency': 'rare', 'color': [108, 255, 3]}, {'name': 'sticks', 'id': 2576, 'trainId': 668, 'image_count': 7, 'frequency': 'rare', 'color': [26, 33, 133]}, {'name': 'balloon', 'id': 108, 'trainId': 669, 'image_count': 13, 'frequency': 'common', 'color': [170, 24, 178]}, {'name': 'score', 'id': 2245, 'trainId': 670, 'image_count': 6, 'frequency': 'rare', 'color': [108, 70, 24]}, {'name': 'hair spray', 'id': 1162, 'trainId': 671, 'image_count': 5, 'frequency': 'rare', 'color': [133, 40, 162]}, {'name': 'roll', 'id': 2149, 'trainId': 672, 'image_count': 14, 'frequency': 'common', 'color': [32, 232, 95]}, {'name': 'runner', 'id': 2183, 'trainId': 673, 'image_count': 1, 'frequency': 'rare', 'color': [161, 153, 70]}, {'name': 'engine', 'id': 858, 'trainId': 674, 'image_count': 3, 'frequency': 'rare', 'color': [206, 111, 178]}, {'name': 'inflatable glove', 'id': 1324, 'trainId': 675, 'image_count': 4, 'frequency': 'rare', 'color': [0, 95, 77]}, {'name': 'games', 'id': 1055, 'trainId': 676, 'image_count': 2, 'frequency': 'rare', 'color': [204, 67, 120]}, {'name': 'pallets', 'id': 1741, 'trainId': 677, 'image_count': 13, 'frequency': 'common', 'color': [34, 3, 237]}, {'name': 'baskets', 'id': 149, 'trainId': 678, 'image_count': 11, 'frequency': 'common', 'color': [166, 87, 38]}, {'name': 'coop', 'id': 615, 'trainId': 679, 'image_count': 2, 'frequency': 'rare', 'color': [150, 120, 211]}, {'name': 'dvd player', 'id': 825, 'trainId': 680, 'image_count': 28, 'frequency': 'common', 'color': [67, 178, 121]}, {'name': 'rocking horse', 'id': 2143, 'trainId': 681, 'image_count': 2, 'frequency': 'rare', 'color': [213, 153, 246]}, {'name': 'buckets', 'id': 304, 'trainId': 682, 'image_count': 4, 'frequency': 'rare', 'color': [13, 177, 162]}, {'name': 'bread rolls', 'id': 283, 'trainId': 683, 'image_count': 11, 'frequency': 'common', 'color': [98, 158, 14]}, {'name': 'shawl', 'id': 2322, 'trainId': 684, 'image_count': 3, 'frequency': 'rare', 'color': [59, 92, 130]}, {'name': 'watering can', 'id': 3017, 'trainId': 685, 'image_count': 10, 'frequency': 'rare', 'color': [104, 163, 166]}, {'name': 'spotlights', 'id': 2510, 'trainId': 686, 'image_count': 23, 'frequency': 'common', 'color': [154, 19, 72]}, {'name': 'post-it', 'id': 1960, 'trainId': 687, 'image_count': 8, 'frequency': 'rare', 'color': [254, 46, 239]}, {'name': 'bowls', 'id': 265, 'trainId': 688, 'image_count': 20, 'frequency': 'common', 'color': [161, 172, 119]}, {'name': 'security camera', 'id': 2282, 'trainId': 689, 'image_count': 25, 'frequency': 'common', 'color': [21, 41, 24]}, {'name': 'runner cloth', 'id': 2184, 'trainId': 690, 'image_count': 10, 'frequency': 'rare', 'color': [94, 66, 184]}, {'name': 'lock', 'id': 1461, 'trainId': 691, 'image_count': 10, 'frequency': 'rare', 'color': [245, 50, 65]}, {'name': 'alarm, warning device, alarm system', 'id': 3113, 'trainId': 692, 'image_count': 22, 'frequency': 'common', 'color': [136, 73, 205]}, {'name': 'side', 'id': 2372, 'trainId': 693, 'image_count': 2, 'frequency': 'rare', 'color': [145, 50, 87]}, {'name': 'roulette', 'id': 2166, 'trainId': 694, 'image_count': 1, 'frequency': 'rare', 'color': [230, 43, 45]}, {'name': 'bone', 'id': 232, 'trainId': 695, 'image_count': 3, 'frequency': 'rare', 'color': [67, 16, 175]}, {'name': 'cutlery', 'id': 693, 'trainId': 696, 'image_count': 16, 'frequency': 'common', 'color': [225, 11, 180]}, {'name': 'pool balls', 'id': 1945, 'trainId': 697, 'image_count': 27, 'frequency': 'common', 'color': [40, 52, 128]}, {'name': 'wheels', 'id': 3039, 'trainId': 698, 'image_count': 2, 'frequency': 'rare', 'color': [236, 220, 37]}, {'name': 'spice rack', 'id': 2494, 'trainId': 699, 'image_count': 19, 'frequency': 'common', 'color': [154, 63, 104]}, {'name': 'plant pots', 'id': 1908, 'trainId': 700, 'image_count': 16, 'frequency': 'common', 'color': [240, 228, 195]}, {'name': 'towel ring', 'id': 2827, 'trainId': 701, 'image_count': 27, 'frequency': 'common', 'color': [63, 136, 205]}, {'name': 'bread box', 'id': 280, 'trainId': 702, 'image_count': 9, 'frequency': 'rare', 'color': [217, 25, 191]}, {'name': 'video', 'id': 2950, 'trainId': 703, 'image_count': 8, 'frequency': 'rare', 'color': [23, 120, 156]}, {'name': 'funfair', 'id': 1044, 'trainId': 704, 'image_count': 2, 'frequency': 'rare', 'color': [31, 33, 154]}, {'name': 'breads', 'id': 288, 'trainId': 705, 'image_count': 13, 'frequency': 'common', 'color': [208, 216, 55]}, {'name': 'tripod', 'id': 2863, 'trainId': 706, 'image_count': 15, 'frequency': 'common', 'color': [12, 203, 60]}, {'name': 'ironing board', 'id': 1342, 'trainId': 707, 'image_count': 7, 'frequency': 'rare', 'color': [117, 75, 177]}, {'name': 'skimmer', 'id': 2409, 'trainId': 708, 'image_count': 7, 'frequency': 'rare', 'color': [208, 199, 211]}, {'name': 'hollow', 'id': 1258, 'trainId': 709, 'image_count': 5, 'frequency': 'rare', 'color': [55, 116, 2]}, {'name': 'scratching post', 'id': 2249, 'trainId': 710, 'image_count': 1, 'frequency': 'rare', 'color': [175, 121, 58]}, {'name': 'tricycle', 'id': 2862, 'trainId': 711, 'image_count': 3, 'frequency': 'rare', 'color': [122, 17, 121]}, {'name': 'file box', 'id': 920, 'trainId': 712, 'image_count': 7, 'frequency': 'rare', 'color': [101, 83, 242]}, {'name': 'mountain pass', 'id': 1607, 'trainId': 713, 'image_count': 14, 'frequency': 'common', 'color': [102, 101, 192]}, {'name': 'tombstones', 'id': 2802, 'trainId': 714, 'image_count': 3, 'frequency': 'rare', 'color': [47, 68, 254]}, {'name': 'cooker', 'id': 610, 'trainId': 715, 'image_count': 13, 'frequency': 'common', 'color': [160, 247, 98]}, {'name': 'card game, cards', 'id': 3129, 'trainId': 716, 'image_count': 17, 'frequency': 'common', 'color': [11, 185, 18]}, {'name': 'golf bag', 'id': 1108, 'trainId': 717, 'image_count': 7, 'frequency': 'rare', 'color': [107, 231, 178]}, {'name': 'towel paper', 'id': 2823, 'trainId': 718, 'image_count': 3, 'frequency': 'rare', 'color': [181, 12, 182]}, {'name': 'chaise lounge', 'id': 476, 'trainId': 719, 'image_count': 5, 'frequency': 'rare', 'color': [29, 49, 148]}, {'name': 'sun', 'id': 2641, 'trainId': 720, 'image_count': 66, 'frequency': 'common', 'color': [84, 117, 178]}, {'name': 'toilet paper holder', 'id': 2788, 'trainId': 721, 'image_count': 17, 'frequency': 'common', 'color': [19, 136, 153]}, {'name': 'rake', 'id': 2070, 'trainId': 722, 'image_count': 9, 'frequency': 'rare', 'color': [202, 236, 141]}, {'name': 'key', 'id': 1368, 'trainId': 723, 'image_count': 18, 'frequency': 'common', 'color': [146, 200, 22]}, {'name': 'umbrella stand', 'id': 2903, 'trainId': 724, 'image_count': 3, 'frequency': 'rare', 'color': [78, 236, 70]}, {'name': 'dartboard', 'id': 699, 'trainId': 725, 'image_count': 10, 'frequency': 'rare', 'color': [118, 62, 15]}, {'name': 'transformer', 'id': 2844, 'trainId': 726, 'image_count': 9, 'frequency': 'rare', 'color': [176, 229, 145]}, {'name': 'fireplace utensils', 'id': 942, 'trainId': 727, 'image_count': 7, 'frequency': 'rare', 'color': [66, 25, 185]}, {'name': 'sweatshirts', 'id': 2663, 'trainId': 728, 'image_count': 5, 'frequency': 'rare', 'color': [67, 45, 154]}, {'name': 'cellular telephone, cellular phone, cellphone, cell, mobile phone', 'id': 457, 'trainId': 729, 'image_count': 34, 'frequency': 'common', 'color': [141, 142, 122]}, {'name': 'tallboy', 'id': 2701, 'trainId': 730, 'image_count': 4, 'frequency': 'rare', 'color': [221, 204, 221]}, {'name': 'stapler', 'id': 2540, 'trainId': 731, 'image_count': 18, 'frequency': 'common', 'color': [221, 60, 213]}, {'name': 'sauna', 'id': 2231, 'trainId': 732, 'image_count': 4, 'frequency': 'rare', 'color': [201, 116, 213]}, {'name': 'test tube', 'id': 2746, 'trainId': 733, 'image_count': 7, 'frequency': 'rare', 'color': [22, 90, 28]}, {'name': 'palette', 'id': 1738, 'trainId': 734, 'image_count': 9, 'frequency': 'rare', 'color': [23, 38, 135]}, {'name': 'shopping carts', 'id': 2350, 'trainId': 735, 'image_count': 2, 'frequency': 'rare', 'color': [164, 245, 153]}, {'name': 'tools', 'id': 2808, 'trainId': 736, 'image_count': 9, 'frequency': 'rare', 'color': [242, 198, 9]}, {'name': 'push button, push, button', 'id': 2025, 'trainId': 737, 'image_count': 13, 'frequency': 'common', 'color': [172, 167, 183]}, {'name': 'star', 'id': 2541, 'trainId': 738, 'image_count': 11, 'frequency': 'common', 'color': [161, 94, 184]}, {'name': 'roof rack', 'id': 2156, 'trainId': 739, 'image_count': 3, 'frequency': 'rare', 'color': [245, 216, 249]}, {'name': 'barbed wire', 'id': 126, 'trainId': 740, 'image_count': 4, 'frequency': 'rare', 'color': [27, 173, 169]}, {'name': 'spray', 'id': 2512, 'trainId': 741, 'image_count': 10, 'frequency': 'rare', 'color': [112, 23, 108]}, {'name': 'ear', 'id': 831, 'trainId': 742, 'image_count': 4, 'frequency': 'rare', 'color': [112, 201, 96]}, {'name': 'sponge', 'id': 2503, 'trainId': 743, 'image_count': 14, 'frequency': 'common', 'color': [207, 28, 12]}, {'name': 'racket', 'id': 2039, 'trainId': 744, 'image_count': 14, 'frequency': 'common', 'color': [27, 114, 70]}, {'name': 'tins', 'id': 2774, 'trainId': 745, 'image_count': 19, 'frequency': 'common', 'color': [31, 75, 202]}, {'name': 'eyeglasses', 'id': 886, 'trainId': 746, 'image_count': 13, 'frequency': 'common', 'color': [154, 11, 59]}, {'name': 'file', 'id': 919, 'trainId': 747, 'image_count': 4, 'frequency': 'rare', 'color': [156, 228, 168]}, {'name': 'scarfs', 'id': 2240, 'trainId': 748, 'image_count': 1, 'frequency': 'rare', 'color': [90, 158, 113]}, {'name': 'sugar bowl', 'id': 2636, 'trainId': 749, 'image_count': 22, 'frequency': 'common', 'color': [88, 158, 155]}, {'name': 'flip flop', 'id': 963, 'trainId': 750, 'image_count': 14, 'frequency': 'common', 'color': [63, 227, 168]}, {'name': 'headstones', 'id': 1218, 'trainId': 751, 'image_count': 3, 'frequency': 'rare', 'color': [108, 94, 36]}, {'name': 'laptop bag', 'id': 1406, 'trainId': 752, 'image_count': 1, 'frequency': 'rare', 'color': [16, 126, 84]}, {'name': 'leash', 'id': 1420, 'trainId': 753, 'image_count': 3, 'frequency': 'rare', 'color': [79, 124, 150]}, {'name': 'climbing frame', 'id': 526, 'trainId': 754, 'image_count': 2, 'frequency': 'rare', 'color': [136, 231, 121]}, {'name': 'suit hanger', 'id': 2639, 'trainId': 755, 'image_count': 3, 'frequency': 'rare', 'color': [67, 224, 165]}, {'name': 'floor spotlight', 'id': 975, 'trainId': 756, 'image_count': 12, 'frequency': 'common', 'color': [180, 101, 43]}, {'name': 'plate rack', 'id': 1921, 'trainId': 757, 'image_count': 5, 'frequency': 'rare', 'color': [116, 200, 246]}, {'name': 'sewer', 'id': 2305, 'trainId': 758, 'image_count': 4, 'frequency': 'rare', 'color': [29, 4, 142]}, {'name': 'hard drive', 'id': 1193, 'trainId': 759, 'image_count': 5, 'frequency': 'rare', 'color': [250, 206, 169]}, {'name': 'sprinkler', 'id': 2517, 'trainId': 760, 'image_count': 27, 'frequency': 'common', 'color': [154, 32, 242]}, {'name': 'tools box', 'id': 2809, 'trainId': 761, 'image_count': 2, 'frequency': 'rare', 'color': [70, 145, 11]}, {'name': 'necklace', 'id': 1647, 'trainId': 762, 'image_count': 8, 'frequency': 'rare', 'color': [176, 63, 250]}, {'name': 'bulbs', 'id': 314, 'trainId': 763, 'image_count': 1, 'frequency': 'rare', 'color': [71, 24, 163]}, {'name': 'steel industry', 'id': 2560, 'trainId': 764, 'image_count': 1, 'frequency': 'rare', 'color': [99, 12, 51]}, {'name': 'club', 'id': 545, 'trainId': 765, 'image_count': 13, 'frequency': 'common', 'color': [136, 5, 23]}, {'name': 'jack', 'id': 1345, 'trainId': 766, 'image_count': 3, 'frequency': 'rare', 'color': [232, 177, 21]}, {'name': 'door bars', 'id': 775, 'trainId': 767, 'image_count': 1, 'frequency': 'rare', 'color': [60, 86, 248]}, {'name': 'control panel, instrument panel, control board, board, panel', 'id': 603, 'trainId': 768, 'image_count': 1, 'frequency': 'rare', 'color': [33, 219, 240]}, {'name': 'hairbrush', 'id': 1163, 'trainId': 769, 'image_count': 9, 'frequency': 'rare', 'color': [99, 123, 122]}, {'name': 'napkin holder', 'id': 1641, 'trainId': 770, 'image_count': 1, 'frequency': 'rare', 'color': [137, 232, 243]}, {'name': 'office', 'id': 1678, 'trainId': 771, 'image_count': 3, 'frequency': 'rare', 'color': [157, 114, 205]}, {'name': 'smoke detector', 'id': 2450, 'trainId': 772, 'image_count': 5, 'frequency': 'rare', 'color': [20, 23, 207]}, {'name': 'utensils', 'id': 2915, 'trainId': 773, 'image_count': 18, 'frequency': 'common', 'color': [246, 216, 211]}, {'name': 'apron', 'id': 42, 'trainId': 774, 'image_count': 2, 'frequency': 'rare', 'color': [165, 228, 168]}, {'name': 'scissors', 'id': 2242, 'trainId': 775, 'image_count': 19, 'frequency': 'common', 'color': [1, 1, 200]}, {'name': 'terminal', 'id': 2741, 'trainId': 776, 'image_count': 5, 'frequency': 'rare', 'color': [81, 177, 176]}, {'name': 'grinder', 'id': 1143, 'trainId': 777, 'image_count': 4, 'frequency': 'rare', 'color': [222, 24, 209]}, {'name': 'entry phone', 'id': 862, 'trainId': 778, 'image_count': 10, 'frequency': 'rare', 'color': [255, 200, 57]}, {'name': 'newspaper stand', 'id': 1654, 'trainId': 779, 'image_count': 3, 'frequency': 'rare', 'color': [231, 237, 192]}, {'name': 'pepper shaker', 'id': 1826, 'trainId': 780, 'image_count': 12, 'frequency': 'common', 'color': [143, 186, 244]}, {'name': 'onions', 'id': 1689, 'trainId': 781, 'image_count': 3, 'frequency': 'rare', 'color': [20, 147, 238]}, {'name': 'central processing unit, cpu, c p u , central processor, processor, mainframe', 'id': 3124, 'trainId': 782, 'image_count': 6, 'frequency': 'rare', 'color': [147, 68, 14]}, {'name': 'tape', 'id': 2710, 'trainId': 783, 'image_count': 3, 'frequency': 'rare', 'color': [135, 92, 250]}, {'name': 'bat', 'id': 152, 'trainId': 784, 'image_count': 13, 'frequency': 'common', 'color': [118, 36, 85]}, {'name': 'coaster', 'id': 549, 'trainId': 785, 'image_count': 6, 'frequency': 'rare', 'color': [242, 226, 183]}, {'name': 'calculator', 'id': 360, 'trainId': 786, 'image_count': 10, 'frequency': 'rare', 'color': [229, 18, 2]}, {'name': 'potatoes', 'id': 1982, 'trainId': 787, 'image_count': 2, 'frequency': 'rare', 'color': [164, 142, 101]}, {'name': 'luggage rack', 'id': 1478, 'trainId': 788, 'image_count': 1, 'frequency': 'rare', 'color': [124, 103, 51]}, {'name': 'salt', 'id': 2203, 'trainId': 789, 'image_count': 11, 'frequency': 'common', 'color': [182, 47, 150]}, {'name': 'street number', 'id': 2612, 'trainId': 790, 'image_count': 2, 'frequency': 'rare', 'color': [36, 185, 50]}, {'name': 'viewpoint', 'id': 2956, 'trainId': 791, 'image_count': 1, 'frequency': 'rare', 'color': [209, 12, 252]}, {'name': 'sword', 'id': 2681, 'trainId': 792, 'image_count': 1, 'frequency': 'rare', 'color': [201, 134, 192]}, {'name': 'cd', 'id': 437, 'trainId': 793, 'image_count': 4, 'frequency': 'rare', 'color': [195, 22, 245]}, {'name': 'rowing machine', 'id': 2171, 'trainId': 794, 'image_count': 1, 'frequency': 'rare', 'color': [183, 74, 160]}, {'name': 'plug', 'id': 1933, 'trainId': 795, 'image_count': 21, 'frequency': 'common', 'color': [156, 133, 240]}, {'name': 'andiron, firedog, dog, dog-iron', 'id': 3110, 'trainId': 796, 'image_count': 4, 'frequency': 'rare', 'color': [145, 182, 45]}, {'name': 'pepper', 'id': 1824, 'trainId': 797, 'image_count': 14, 'frequency': 'common', 'color': [146, 11, 136]}, {'name': 'tongs', 'id': 2803, 'trainId': 798, 'image_count': 10, 'frequency': 'rare', 'color': [95, 132, 201]}, {'name': 'bonfire', 'id': 234, 'trainId': 799, 'image_count': 3, 'frequency': 'rare', 'color': [67, 69, 83]}, {'name': 'dog dish', 'id': 764, 'trainId': 800, 'image_count': 3, 'frequency': 'rare', 'color': [17, 89, 174]}, {'name': 'belt', 'id': 177, 'trainId': 801, 'image_count': 5, 'frequency': 'rare', 'color': [192, 32, 250]}, {'name': 'dumbbells', 'id': 817, 'trainId': 802, 'image_count': 3, 'frequency': 'rare', 'color': [104, 76, 57]}, {'name': 'videocassette recorder, vcr', 'id': 3145, 'trainId': 803, 'image_count': 11, 'frequency': 'common', 'color': [226, 103, 85]}, {'name': 'hook', 'id': 1262, 'trainId': 804, 'image_count': 8, 'frequency': 'rare', 'color': [224, 251, 168]}, {'name': 'envelopes', 'id': 864, 'trainId': 805, 'image_count': 2, 'frequency': 'rare', 'color': [82, 22, 214]}, {'name': 'shower faucet', 'id': 2359, 'trainId': 806, 'image_count': 8, 'frequency': 'rare', 'color': [16, 227, 203]}, {'name': 'watch', 'id': 2992, 'trainId': 807, 'image_count': 9, 'frequency': 'rare', 'color': [7, 160, 219]}, {'name': 'padlock', 'id': 1725, 'trainId': 808, 'image_count': 5, 'frequency': 'rare', 'color': [175, 200, 113]}, {'name': 'swimming pool ladder', 'id': 2667, 'trainId': 809, 'image_count': 10, 'frequency': 'rare', 'color': [102, 100, 20]}, {'name': 'spanners', 'id': 2484, 'trainId': 810, 'image_count': 1, 'frequency': 'rare', 'color': [22, 76, 105]}, {'name': 'gravy boat', 'id': 1133, 'trainId': 811, 'image_count': 2, 'frequency': 'rare', 'color': [207, 26, 113]}, {'name': 'notice board', 'id': 1667, 'trainId': 812, 'image_count': 3, 'frequency': 'rare', 'color': [104, 9, 183]}, {'name': 'trash bags', 'id': 2847, 'trainId': 813, 'image_count': 1, 'frequency': 'rare', 'color': [254, 153, 30]}, {'name': 'fire alarm', 'id': 932, 'trainId': 814, 'image_count': 13, 'frequency': 'common', 'color': [226, 177, 32]}, {'name': 'ladle', 'id': 1392, 'trainId': 815, 'image_count': 6, 'frequency': 'rare', 'color': [249, 228, 217]}, {'name': 'stethoscope', 'id': 2573, 'trainId': 816, 'image_count': 3, 'frequency': 'rare', 'color': [144, 68, 142]}, {'name': 'rocket', 'id': 2140, 'trainId': 817, 'image_count': 1, 'frequency': 'rare', 'color': [56, 207, 239]}, {'name': 'funnel', 'id': 1046, 'trainId': 818, 'image_count': 3, 'frequency': 'rare', 'color': [79, 237, 16]}, {'name': 'bowling pins', 'id': 264, 'trainId': 819, 'image_count': 4, 'frequency': 'rare', 'color': [155, 236, 251]}, {'name': 'valve', 'id': 2927, 'trainId': 820, 'image_count': 3, 'frequency': 'rare', 'color': [214, 89, 254]}, {'name': 'thermometer', 'id': 2752, 'trainId': 821, 'image_count': 6, 'frequency': 'rare', 'color': [120, 106, 144]}, {'name': 'cups', 'id': 679, 'trainId': 822, 'image_count': 7, 'frequency': 'rare', 'color': [15, 178, 222]}, {'name': 'spice jar', 'id': 2493, 'trainId': 823, 'image_count': 1, 'frequency': 'rare', 'color': [253, 38, 177]}, {'name': 'night light', 'id': 1658, 'trainId': 824, 'image_count': 2, 'frequency': 'rare', 'color': [183, 62, 42]}, {'name': 'soaps', 'id': 2466, 'trainId': 825, 'image_count': 1, 'frequency': 'rare', 'color': [72, 104, 32]}, {'name': 'games table', 'id': 1057, 'trainId': 826, 'image_count': 2, 'frequency': 'rare', 'color': [60, 57, 241]}, {'name': 'slotted spoon', 'id': 2444, 'trainId': 827, 'image_count': 2, 'frequency': 'rare', 'color': [126, 81, 78]}, {'name': 'reel', 'id': 2093, 'trainId': 828, 'image_count': 1, 'frequency': 'rare', 'color': [103, 193, 193]}, {'name': 'scourer', 'id': 2248, 'trainId': 829, 'image_count': 1, 'frequency': 'rare', 'color': [95, 198, 7]}, {'name': 'sleeping robe', 'id': 2432, 'trainId': 830, 'image_count': 3, 'frequency': 'rare', 'color': [228, 70, 147]}, {'name': 'desk mat', 'id': 726, 'trainId': 831, 'image_count': 8, 'frequency': 'rare', 'color': [141, 184, 124]}, {'name': 'dumbbell', 'id': 816, 'trainId': 832, 'image_count': 2, 'frequency': 'rare', 'color': [80, 125, 114]}, {'name': 'hammer', 'id': 1171, 'trainId': 833, 'image_count': 5, 'frequency': 'rare', 'color': [48, 43, 159]}, {'name': 'tie', 'id': 2766, 'trainId': 834, 'image_count': 2, 'frequency': 'rare', 'color': [109, 95, 231]}, {'name': 'typewriter', 'id': 2900, 'trainId': 835, 'image_count': 3, 'frequency': 'rare', 'color': [34, 190, 129]}, {'name': 'shaker', 'id': 2313, 'trainId': 836, 'image_count': 3, 'frequency': 'rare', 'color': [127, 11, 166]}, {'name': 'cheese dish', 'id': 488, 'trainId': 837, 'image_count': 1, 'frequency': 'rare', 'color': [80, 208, 125]}, {'name': 'sea star', 'id': 2265, 'trainId': 838, 'image_count': 1, 'frequency': 'rare', 'color': [12, 94, 164]}, {'name': 'racquet', 'id': 2043, 'trainId': 839, 'image_count': 1, 'frequency': 'rare', 'color': [59, 206, 248]}, {'name': 'butane gas cylinder', 'id': 332, 'trainId': 840, 'image_count': 2, 'frequency': 'rare', 'color': [204, 233, 82]}, {'name': 'paper weight', 'id': 1771, 'trainId': 841, 'image_count': 2, 'frequency': 'rare', 'color': [209, 63, 100]}, {'name': 'shaving brush', 'id': 2320, 'trainId': 842, 'image_count': 2, 'frequency': 'rare', 'color': [255, 220, 185]}, {'name': 'sunglasses', 'id': 2646, 'trainId': 843, 'image_count': 2, 'frequency': 'rare', 'color': [235, 232, 98]}, {'name': 'gear shift', 'id': 1089, 'trainId': 844, 'image_count': 1, 'frequency': 'rare', 'color': [254, 7, 239]}, {'name': 'towel rail', 'id': 2826, 'trainId': 845, 'image_count': 1, 'frequency': 'rare', 'color': [52, 152, 14]}, {'name': 'adding machine, totalizer, totaliser', 'id': 3148, 'trainId': 846, 'image_count': 1, 'frequency': 'rare', 'color': [69, 80, 228]}]
+
+def _get_ade20k_full_val_all_meta_freq():
+ stuff_ids = [k["id"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
+ assert len(stuff_ids) == 847, len(stuff_ids)
+
+ stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
+ stuff_classes = [k["name"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
+ stuff_colors = [k["color"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
+ ret = {
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
+ "stuff_classes": stuff_classes,
+ "stuff_colors": stuff_colors,
+ }
+ return ret
+
+
+def register_all_ade20k_full_val_all_freq(root):
+ root = os.path.join(root, "ADE20K_2021_17_01")
+ meta = _get_ade20k_full_val_all_meta_freq()
+
+ name, dirname = "val_all", "val_all"
+ image_dir = os.path.join(root, "images_detectron2", "validation")
+ gt_dir = os.path.join(root, "annotations_detectron2", "validation")
+ name = f"ade20k_full_sem_seg_freq_{name}"
+ DatasetCatalog.register(
+ name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="tif", image_ext="jpg")
+ )
+
+ MetadataCatalog.get(name).set(
+ stuff_classes=meta["stuff_classes"][:],
+ image_root=image_dir,
+ sem_seg_root=gt_dir,
+ evaluator_type="sem_seg",
+ ignore_label=65535, # NOTE: gt is saved in 16-bit TIFF images
+ stuff_colors=meta["stuff_colors"],
+ )
+
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_all_ade20k_full_val_all_freq(_root)
diff --git a/cat_seg/data/datasets/register_coco_stuff.py b/cat_seg/data/datasets/register_coco_stuff.py
new file mode 100755
index 0000000000000000000000000000000000000000..35c823dee37b1657dc61d1f5beab8c0ecaa98855
--- /dev/null
+++ b/cat_seg/data/datasets/register_coco_stuff.py
@@ -0,0 +1,216 @@
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+
+COCO_CATEGORIES = [
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
+ {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
+ {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
+ {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
+ {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
+ {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
+ {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
+ {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
+ {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
+ {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
+ {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
+ {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
+ {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
+ {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
+ {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
+ {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
+ {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
+ {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
+ {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
+ {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
+ {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
+ {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
+ {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
+ {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
+ {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
+ {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
+ {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
+ {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
+ {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
+ {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
+ {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
+ {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
+ {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
+ {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
+ {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
+ {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
+ {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
+ {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
+ {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
+ {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
+ {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
+ {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
+ {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
+ {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
+ {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
+ {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
+ {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
+ {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
+ {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
+ {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
+ {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
+ {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
+ {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
+ {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
+ {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
+ {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
+ {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
+ {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
+ {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
+ {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
+ {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
+ {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
+ {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
+ {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
+ {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
+ {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
+ {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
+ {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
+ {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
+ {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
+ {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
+ {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
+ {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
+ {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
+ {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
+ {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
+ {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
+ {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
+ {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
+ {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
+ {"id": 92, "name": "banner", "supercategory": "textile"},
+ {"id": 93, "name": "blanket", "supercategory": "textile"},
+ {"id": 94, "name": "branch", "supercategory": "plant"},
+ {"id": 95, "name": "bridge", "supercategory": "building"},
+ {"id": 96, "name": "building-other", "supercategory": "building"},
+ {"id": 97, "name": "bush", "supercategory": "plant"},
+ {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
+ {"id": 99, "name": "cage", "supercategory": "structural"},
+ {"id": 100, "name": "cardboard", "supercategory": "raw-material"},
+ {"id": 101, "name": "carpet", "supercategory": "floor"},
+ {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
+ {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
+ {"id": 104, "name": "cloth", "supercategory": "textile"},
+ {"id": 105, "name": "clothes", "supercategory": "textile"},
+ {"id": 106, "name": "clouds", "supercategory": "sky"},
+ {"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
+ {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
+ {"id": 109, "name": "curtain", "supercategory": "textile"},
+ {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
+ {"id": 111, "name": "dirt", "supercategory": "ground"},
+ {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
+ {"id": 113, "name": "fence", "supercategory": "structural"},
+ {"id": 114, "name": "floor-marble", "supercategory": "floor"},
+ {"id": 115, "name": "floor-other", "supercategory": "floor"},
+ {"id": 116, "name": "floor-stone", "supercategory": "floor"},
+ {"id": 117, "name": "floor-tile", "supercategory": "floor"},
+ {"id": 118, "name": "floor-wood", "supercategory": "floor"},
+ {"id": 119, "name": "flower", "supercategory": "plant"},
+ {"id": 120, "name": "fog", "supercategory": "water"},
+ {"id": 121, "name": "food-other", "supercategory": "food-stuff"},
+ {"id": 122, "name": "fruit", "supercategory": "food-stuff"},
+ {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
+ {"id": 124, "name": "grass", "supercategory": "plant"},
+ {"id": 125, "name": "gravel", "supercategory": "ground"},
+ {"id": 126, "name": "ground-other", "supercategory": "ground"},
+ {"id": 127, "name": "hill", "supercategory": "solid"},
+ {"id": 128, "name": "house", "supercategory": "building"},
+ {"id": 129, "name": "leaves", "supercategory": "plant"},
+ {"id": 130, "name": "light", "supercategory": "furniture-stuff"},
+ {"id": 131, "name": "mat", "supercategory": "textile"},
+ {"id": 132, "name": "metal", "supercategory": "raw-material"},
+ {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
+ {"id": 134, "name": "moss", "supercategory": "plant"},
+ {"id": 135, "name": "mountain", "supercategory": "solid"},
+ {"id": 136, "name": "mud", "supercategory": "ground"},
+ {"id": 137, "name": "napkin", "supercategory": "textile"},
+ {"id": 138, "name": "net", "supercategory": "structural"},
+ {"id": 139, "name": "paper", "supercategory": "raw-material"},
+ {"id": 140, "name": "pavement", "supercategory": "ground"},
+ {"id": 141, "name": "pillow", "supercategory": "textile"},
+ {"id": 142, "name": "plant-other", "supercategory": "plant"},
+ {"id": 143, "name": "plastic", "supercategory": "raw-material"},
+ {"id": 144, "name": "platform", "supercategory": "ground"},
+ {"id": 145, "name": "playingfield", "supercategory": "ground"},
+ {"id": 146, "name": "railing", "supercategory": "structural"},
+ {"id": 147, "name": "railroad", "supercategory": "ground"},
+ {"id": 148, "name": "river", "supercategory": "water"},
+ {"id": 149, "name": "road", "supercategory": "ground"},
+ {"id": 150, "name": "rock", "supercategory": "solid"},
+ {"id": 151, "name": "roof", "supercategory": "building"},
+ {"id": 152, "name": "rug", "supercategory": "textile"},
+ {"id": 153, "name": "salad", "supercategory": "food-stuff"},
+ {"id": 154, "name": "sand", "supercategory": "ground"},
+ {"id": 155, "name": "sea", "supercategory": "water"},
+ {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
+ {"id": 157, "name": "sky-other", "supercategory": "sky"},
+ {"id": 158, "name": "skyscraper", "supercategory": "building"},
+ {"id": 159, "name": "snow", "supercategory": "ground"},
+ {"id": 160, "name": "solid-other", "supercategory": "solid"},
+ {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
+ {"id": 162, "name": "stone", "supercategory": "solid"},
+ {"id": 163, "name": "straw", "supercategory": "plant"},
+ {"id": 164, "name": "structural-other", "supercategory": "structural"},
+ {"id": 165, "name": "table", "supercategory": "furniture-stuff"},
+ {"id": 166, "name": "tent", "supercategory": "building"},
+ {"id": 167, "name": "textile-other", "supercategory": "textile"},
+ {"id": 168, "name": "towel", "supercategory": "textile"},
+ {"id": 169, "name": "tree", "supercategory": "plant"},
+ {"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
+ {"id": 171, "name": "wall-brick", "supercategory": "wall"},
+ {"id": 172, "name": "wall-concrete", "supercategory": "wall"},
+ {"id": 173, "name": "wall-other", "supercategory": "wall"},
+ {"id": 174, "name": "wall-panel", "supercategory": "wall"},
+ {"id": 175, "name": "wall-stone", "supercategory": "wall"},
+ {"id": 176, "name": "wall-tile", "supercategory": "wall"},
+ {"id": 177, "name": "wall-wood", "supercategory": "wall"},
+ {"id": 178, "name": "water-other", "supercategory": "water"},
+ {"id": 179, "name": "waterdrops", "supercategory": "water"},
+ {"id": 180, "name": "window-blind", "supercategory": "window"},
+ {"id": 181, "name": "window-other", "supercategory": "window"},
+ {"id": 182, "name": "wood", "supercategory": "solid"},
+]
+
+
+def _get_coco_stuff_meta():
+ stuff_ids = [k["id"] for k in COCO_CATEGORIES]
+ assert len(stuff_ids) == 171, len(stuff_ids)
+
+ stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
+ stuff_classes = [k["name"] for k in COCO_CATEGORIES]
+
+ ret = {
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
+ "stuff_classes": stuff_classes,
+ }
+ return ret
+
+def register_all_coco_stuff_10k(root):
+ root = os.path.join(root, "coco-stuff")
+ meta = _get_coco_stuff_meta()
+ for name, image_dirname, sem_seg_dirname in [
+ ("train", "images/train2017", "annotations_detectron2/train2017"),
+ ("test", "images/val2017", "annotations_detectron2/val2017"),
+ ]:
+ image_dir = os.path.join(root, image_dirname)
+ gt_dir = os.path.join(root, sem_seg_dirname)
+ name = f"coco_2017_{name}_stuff_all_sem_seg"
+ DatasetCatalog.register(
+ name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
+ )
+ MetadataCatalog.get(name).set(
+ image_root=image_dir,
+ sem_seg_root=gt_dir,
+ evaluator_type="sem_seg",
+ ignore_label=255,
+ **meta,
+ )
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_all_coco_stuff_10k(_root)
diff --git a/cat_seg/data/datasets/register_pascal_20.py b/cat_seg/data/datasets/register_pascal_20.py
new file mode 100644
index 0000000000000000000000000000000000000000..690050c7d4d02de4837545516c32cf795b3846f0
--- /dev/null
+++ b/cat_seg/data/datasets/register_pascal_20.py
@@ -0,0 +1,53 @@
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+import copy
+
+def _get_pascal_voc_meta():
+ voc_classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
+ voc_colors = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+ ret = {
+ "stuff_classes" : voc_classes,
+ "stuff_colors" : voc_colors,
+ }
+ return ret
+
+def register_all_pascal_voc(root):
+ root = os.path.join(root, "VOCdevkit/VOC2012")
+ meta = _get_pascal_voc_meta()
+ for name, image_dirname, sem_seg_dirname in [
+ ("test", "JPEGImages", "annotations_detectron2"),
+ ("test_background", "JPEGImages", "annotations_detectron2_bg"),
+ ]:
+ image_dir = os.path.join(root, image_dirname)
+ gt_dir = os.path.join(root, sem_seg_dirname, 'val')
+ name = f"voc_2012_{name}_sem_seg"
+
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
+ if "background" in name:
+ MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg_background", ignore_label=255,
+ stuff_classes=meta["stuff_classes"] + ["background"], stuff_colors=meta["stuff_colors"])
+ else:
+ MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
+
+def register_all_pascal_voc_background(root):
+ root = os.path.join(root, "VOCdevkit/VOC2012")
+ meta = _get_pascal_voc_meta()
+ meta["stuff_classes"] = meta["stuff_classes"] + ["background"]
+ for name, image_dirname, sem_seg_dirname in [
+ ("test_background", "image", "label_openseg_background20"),
+ ]:
+ image_dir = os.path.join(root, image_dirname, 'validation')
+ gt_dir = os.path.join(root, sem_seg_dirname, 'validation')
+ name = f"voc_2012_{name}_sem_seg"
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
+ MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg_background", ignore_label=255, **meta,)
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_all_pascal_voc(_root)
+#register_all_pascal_voc_background(_root)
\ No newline at end of file
diff --git a/cat_seg/data/datasets/register_pascal_59.py b/cat_seg/data/datasets/register_pascal_59.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff49702fc898ecf38420985d143c70f71169b91a
--- /dev/null
+++ b/cat_seg/data/datasets/register_pascal_59.py
@@ -0,0 +1,81 @@
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+import copy
+
+
+stuff_colors = [[0, 192, 64], [0, 192, 64], [0, 64, 96], [128, 192, 192],
+ [0, 64, 64], [0, 192, 224], [0, 192, 192], [128, 192, 64],
+ [0, 192, 96], [128, 192, 64], [128, 32, 192], [0, 0, 224],
+ [64, 128, 32], [0, 160, 0], [0, 0, 0], [192, 128, 160],
+ [0, 32, 0], [0, 128, 128], [64, 128, 160], [128, 160, 0],
+ [0, 128, 0], [192, 128, 32], [128, 96, 128], [0, 0, 128],
+ [64, 0, 32], [0, 224, 128], [128, 0, 0], [192, 0, 160],
+ [0, 96, 128], [128, 128, 128], [64, 0, 160], [128, 224, 128],
+ [128, 128, 64], [192, 0, 32], [128, 96, 0], [128, 0, 192],
+ [0, 128, 32], [64, 224, 0], [0, 0, 64], [128, 128, 160],
+ [0, 0, 64], [0, 160, 192], [128, 0, 96], [128, 0, 192],
+ [0, 32, 192], [128, 128, 224], [0, 0, 192], [128, 160, 192],
+ [128, 128, 0], [128, 0, 32], [128, 32, 0], [128, 0, 128],
+ [64, 96, 0], [0, 128, 192], [0, 128, 160], [192, 224, 0],
+ [0, 128, 64], [128, 128, 32], [192, 32, 128], [0, 64, 192],
+ [0, 0, 32], [64, 160, 128], [128, 64, 64], [128, 0, 160],
+ [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
+ [0, 0, 230], [119, 11, 32],
+ [64, 128, 64], [128, 192, 32], [192, 32, 192], [64, 64, 192],
+ [0, 64, 32], [64, 160, 192], [192, 64, 64], [128, 64, 160],
+ [64, 32, 192], [192, 192, 192], [0, 64, 160], [192, 160, 192],
+ [192, 192, 0], [128, 64, 96], [192, 32, 64], [192, 64, 128],
+ [64, 192, 96], [64, 160, 64], [64, 64, 0]]
+
+def _get_pascal_context_59_meta():
+ #context_classes = ["aeroplane", "bag", "bed", "bedclothes", "bench", "bicycle", "bird", "boat", "book", "bottle", "building", "bus", "cabinet", "car", "cat", "ceiling", "chair", "cloth", "computer", "cow", "cup", "curtain", "dog", "door", "fence", "floor", "flower", "food", "grass", "ground", "horse", "keyboard", "light", "motorbike", "mountain", "mouse", "person", "plate", "platform", "pottedplant", "road", "rock", "sheep", "shelves", "sidewalk", "sign", "sky", "snow", "sofa", "diningtable", "track", "train", "tree", "truck", "tvmonitor", "wall", "water", "window", "wood"]#, "background"]
+ context_classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor", "bag", "bed", "bench", "book", "building", "cabinet", "ceiling", "cloth", "computer", "cup", "door", "fence", "floor", "flower", "food", "grass", "ground", "keyboard", "light", "mountain", "mouse", "curtain", "platform", "sign", "plate", "road", "rock", "shelves", "sidewalk", "sky", "snow", "bedclothes", "track", "tree", "truck", "wall", "water", "window", "wood"]
+ context_colors = [stuff_colors[i % len(stuff_colors)] for i in range(len(context_classes))]
+ ret = {
+ "stuff_colors" : context_colors,
+ "stuff_classes" : context_classes,
+ }
+ return ret
+
+def register_pascal_context_59(root):
+ root = os.path.join(root, "VOCdevkit", "VOC2010")
+ meta = _get_pascal_context_59_meta()
+ for name, image_dirname, sem_seg_dirname in [
+ ("test", "JPEGImages", "annotations_detectron2/pc59_val"),
+ ]:
+ image_dir = os.path.join(root, image_dirname)
+ gt_dir = os.path.join(root, sem_seg_dirname)
+ name = f"context_59_{name}_sem_seg"
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='png', image_ext='jpg'))
+ MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=255, **meta,)
+
+def _get_pascal_context_459_meta():
+ context_459_classes = ["accordion", "aeroplane", "airconditioner", "antenna", "artillery", "ashtray", "atrium", "babycarriage", "bag", "ball", "balloon", "bambooweaving", "barrel", "baseballbat", "basket", "basketballbackboard", "bathtub", "bed", "bedclothes", "beer", "bell", "bench", "bicycle", "binoculars", "bird", "birdcage", "birdfeeder", "birdnest", "blackboard", "board", "boat", "bone", "book", "bottle", "bottleopener", "bowl", "box", "bracelet", "brick", "bridge", "broom", "brush", "bucket", "building", "bus", "cabinet", "cabinetdoor", "cage", "cake", "calculator", "calendar", "camel", "camera", "cameralens", "can", "candle", "candleholder", "cap", "car", "card", "cart", "case", "casetterecorder", "cashregister", "cat", "cd", "cdplayer", "ceiling", "cellphone", "cello", "chain", "chair", "chessboard", "chicken", "chopstick", "clip", "clippers", "clock", "closet", "cloth", "clothestree", "coffee", "coffeemachine", "comb", "computer", "concrete", "cone", "container", "controlbooth", "controller", "cooker", "copyingmachine", "coral", "cork", "corkscrew", "counter", "court", "cow", "crabstick", "crane", "crate", "cross", "crutch", "cup", "curtain", "cushion", "cuttingboard", "dais", "disc", "disccase", "dishwasher", "dock", "dog", "dolphin", "door", "drainer", "dray", "drinkdispenser", "drinkingmachine", "drop", "drug", "drum", "drumkit", "duck", "dumbbell", "earphone", "earrings", "egg", "electricfan", "electriciron", "electricpot", "electricsaw", "electronickeyboard", "engine", "envelope", "equipment", "escalator", "exhibitionbooth", "extinguisher", "eyeglass", "fan", "faucet", "faxmachine", "fence", "ferriswheel", "fireextinguisher", "firehydrant", "fireplace", "fish", "fishtank", "fishbowl", "fishingnet", "fishingpole", "flag", "flagstaff", "flame", "flashlight", "floor", "flower", "fly", "foam", "food", "footbridge", "forceps", "fork", "forklift", "fountain", "fox", "frame", "fridge", "frog", "fruit", "funnel", "furnace", "gamecontroller", "gamemachine", "gascylinder", "gashood", "gasstove", "giftbox", "glass", "glassmarble", "globe", "glove", "goal", "grandstand", "grass", "gravestone", "ground", "guardrail", "guitar", "gun", "hammer", "handcart", "handle", "handrail", "hanger", "harddiskdrive", "hat", "hay", "headphone", "heater", "helicopter", "helmet", "holder", "hook", "horse", "horse-drawncarriage", "hot-airballoon", "hydrovalve", "ice", "inflatorpump", "ipod", "iron", "ironingboard", "jar", "kart", "kettle", "key", "keyboard", "kitchenrange", "kite", "knife", "knifeblock", "ladder", "laddertruck", "ladle", "laptop", "leaves", "lid", "lifebuoy", "light", "lightbulb", "lighter", "line", "lion", "lobster", "lock", "machine", "mailbox", "mannequin", "map", "mask", "mat", "matchbook", "mattress", "menu", "metal", "meterbox", "microphone", "microwave", "mirror", "missile", "model", "money", "monkey", "mop", "motorbike", "mountain", "mouse", "mousepad", "musicalinstrument", "napkin", "net", "newspaper", "oar", "ornament", "outlet", "oven", "oxygenbottle", "pack", "pan", "paper", "paperbox", "papercutter", "parachute", "parasol", "parterre", "patio", "pelage", "pen", "pencontainer", "pencil", "person", "photo", "piano", "picture", "pig", "pillar", "pillow", "pipe", "pitcher", "plant", "plastic", "plate", "platform", "player", "playground", "pliers", "plume", "poker", "pokerchip", "pole", "pooltable", "postcard", "poster", "pot", "pottedplant", "printer", "projector", "pumpkin", "rabbit", "racket", "radiator", "radio", "rail", "rake", "ramp", "rangehood", "receiver", "recorder", "recreationalmachines", "remotecontrol", "road", "robot", "rock", "rocket", "rockinghorse", "rope", "rug", "ruler", "runway", "saddle", "sand", "saw", "scale", "scanner", "scissors", "scoop", "screen", "screwdriver", "sculpture", "scythe", "sewer", "sewingmachine", "shed", "sheep", "shell", "shelves", "shoe", "shoppingcart", "shovel", "sidecar", "sidewalk", "sign", "signallight", "sink", "skateboard", "ski", "sky", "sled", "slippers", "smoke", "snail", "snake", "snow", "snowmobiles", "sofa", "spanner", "spatula", "speaker", "speedbump", "spicecontainer", "spoon", "sprayer", "squirrel", "stage", "stair", "stapler", "stick", "stickynote", "stone", "stool", "stove", "straw", "stretcher", "sun", "sunglass", "sunshade", "surveillancecamera", "swan", "sweeper", "swimring", "swimmingpool", "swing", "switch", "table", "tableware", "tank", "tap", "tape", "tarp", "telephone", "telephonebooth", "tent", "tire", "toaster", "toilet", "tong", "tool", "toothbrush", "towel", "toy", "toycar", "track", "train", "trampoline", "trashbin", "tray", "tree", "tricycle", "tripod", "trophy", "truck", "tube", "turtle", "tvmonitor", "tweezers", "typewriter", "umbrella", "unknown", "vacuumcleaner", "vendingmachine", "videocamera", "videogameconsole", "videoplayer", "videotape", "violin", "wakeboard", "wall", "wallet", "wardrobe", "washingmachine", "watch", "water", "waterdispenser", "waterpipe", "waterskateboard", "watermelon", "whale", "wharf", "wheel", "wheelchair", "window", "windowblinds", "wineglass", "wire", "wood", "wool"]
+ context_colors = [stuff_colors[i % len(stuff_colors)] for i in range(len(context_459_classes))]
+ ret = {
+ "stuff_colors" : context_colors,
+ "stuff_classes" : context_459_classes,
+ }
+ return ret
+
+def register_pascal_context_459(root):
+ root = os.path.join(root, "VOCdevkit", "VOC2010")
+ meta = _get_pascal_context_459_meta()
+ for name, image_dirname, sem_seg_dirname in [
+ ("test", "JPEGImages", "annotations_detectron2/pc459_val"),
+ ]:
+ image_dir = os.path.join(root, image_dirname)
+ gt_dir = os.path.join(root, sem_seg_dirname)
+ name = f"context_459_{name}_sem_seg"
+ DatasetCatalog.register(name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext='tif', image_ext='jpg'))
+ MetadataCatalog.get(name).set(image_root=image_dir, seg_seg_root=gt_dir, evaluator_type="sem_seg", ignore_label=459, **meta,)
+
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_pascal_context_59(_root)
+register_pascal_context_459(_root)
\ No newline at end of file
diff --git a/cat_seg/modeling/__init__.py b/cat_seg/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b788ab8d314e5401c06df8cfc405f0571801487f
--- /dev/null
+++ b/cat_seg/modeling/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .backbone.swin import D2SwinTransformer
+from .heads.cat_seg_head import CATSegHead
\ No newline at end of file
diff --git a/cat_seg/modeling/__pycache__/__init__.cpython-38.pyc b/cat_seg/modeling/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2077e5efdb6632eaa7af9b2abed0c0191d10c633
Binary files /dev/null and b/cat_seg/modeling/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cat_seg/modeling/__pycache__/criterion.cpython-38.pyc b/cat_seg/modeling/__pycache__/criterion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40267ba817e0438c574f467b7852d59867b19bc2
Binary files /dev/null and b/cat_seg/modeling/__pycache__/criterion.cpython-38.pyc differ
diff --git a/cat_seg/modeling/__pycache__/matcher.cpython-38.pyc b/cat_seg/modeling/__pycache__/matcher.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b69aaa8f09d1d7f02f68e08d42fcb9d41432a977
Binary files /dev/null and b/cat_seg/modeling/__pycache__/matcher.cpython-38.pyc differ
diff --git a/cat_seg/modeling/backbone/__init__.py b/cat_seg/modeling/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/cat_seg/modeling/backbone/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/cat_seg/modeling/backbone/__pycache__/__init__.cpython-38.pyc b/cat_seg/modeling/backbone/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd4a5195830916c2eb7267fa8961bad6359d5ec1
Binary files /dev/null and b/cat_seg/modeling/backbone/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cat_seg/modeling/backbone/__pycache__/image_encoder.cpython-38.pyc b/cat_seg/modeling/backbone/__pycache__/image_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1e25c9679b653b9265e58e6255b60d23016e541
Binary files /dev/null and b/cat_seg/modeling/backbone/__pycache__/image_encoder.cpython-38.pyc differ
diff --git a/cat_seg/modeling/backbone/__pycache__/swin.cpython-38.pyc b/cat_seg/modeling/backbone/__pycache__/swin.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b847c8fd753b09989c7bba17a4779efbdf8d9f6
Binary files /dev/null and b/cat_seg/modeling/backbone/__pycache__/swin.cpython-38.pyc differ
diff --git a/cat_seg/modeling/backbone/swin.py b/cat_seg/modeling/backbone/swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3150121f3254ea0cf07aec675b4a0b3b71ca743
--- /dev/null
+++ b/cat_seg/modeling/backbone/swin.py
@@ -0,0 +1,768 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu, Yutong Lin, Yixuan Wei
+# --------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+
+class Mlp(nn.Module):
+ """Multilayer perceptron."""
+
+ def __init__(
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ """Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(
+ self,
+ dim,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """Forward function.
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+ """Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
+ )
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix):
+ """Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ mask_matrix: Attention mask for cyclic shift.
+ """
+ B, L, C = x.shape
+ H, W = self.H, self.W
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(
+ shifted_x, self.window_size
+ ) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(
+ -1, self.window_size * self.window_size, C
+ ) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ """Patch Merging Layer
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x, H, W):
+ """Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+
+class BasicLayer(nn.Module):
+ """A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Depths of this stage.
+ num_heads (int): Number of attention head.
+ window_size (int): Local window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(
+ self,
+ dim,
+ depth,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList(
+ [
+ SwinTransformerBlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+
+ # calculate attention mask for SW-MSA
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
+ h_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ w_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(
+ img_mask, self.window_size
+ ) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0)
+ )
+
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask)
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+ """Image to Patch Embedding
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+class SwinTransformer(nn.Module):
+ """Swin Transformer backbone.
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+ used in absolute postion embedding. Default 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ num_heads (tuple[int]): Number of attention head of each stage.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): Dropout rate.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(
+ self,
+ pretrain_img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.2,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2), #3),
+ frozen_stages=-1,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None,
+ )
+
+ # absolute position embedding
+ if self.ape:
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [
+ pretrain_img_size[0] // patch_size[0],
+ pretrain_img_size[1] // patch_size[1],
+ ]
+
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
+ )
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ )
+ self.layers.append(layer)
+
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ # add a norm layer for each output
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f"norm{i_layer}"
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 1 and self.ape:
+ self.absolute_pos_embed.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
+ )
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = {}
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f"norm{i}")
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs["res{}".format(i + 2)] = out
+
+ return outs
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(SwinTransformer, self).train(mode)
+ self._freeze_stages()
+
+
+@BACKBONE_REGISTRY.register()
+class D2SwinTransformer(SwinTransformer, Backbone):
+ def __init__(self, cfg, input_shape):
+
+ pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
+ patch_size = cfg.MODEL.SWIN.PATCH_SIZE
+ in_chans = 3
+ embed_dim = cfg.MODEL.SWIN.EMBED_DIM
+ depths = cfg.MODEL.SWIN.DEPTHS
+ num_heads = cfg.MODEL.SWIN.NUM_HEADS
+ window_size = cfg.MODEL.SWIN.WINDOW_SIZE
+ mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
+ qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
+ qk_scale = cfg.MODEL.SWIN.QK_SCALE
+ drop_rate = cfg.MODEL.SWIN.DROP_RATE
+ attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
+ drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
+ norm_layer = nn.LayerNorm
+ ape = cfg.MODEL.SWIN.APE
+ patch_norm = cfg.MODEL.SWIN.PATCH_NORM
+
+ super().__init__(
+ pretrain_img_size,
+ patch_size,
+ in_chans,
+ embed_dim,
+ depths,
+ num_heads,
+ window_size,
+ mlp_ratio,
+ qkv_bias,
+ qk_scale,
+ drop_rate,
+ attn_drop_rate,
+ drop_path_rate,
+ norm_layer,
+ ape,
+ patch_norm,
+ )
+
+ self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
+
+ self._out_feature_strides = {
+ "res2": 4,
+ "res3": 8,
+ "res4": 16,
+ #"res5": 32,
+ }
+ self._out_feature_channels = {
+ "res2": self.num_features[0],
+ "res3": self.num_features[1],
+ "res4": self.num_features[2],
+ #"res5": self.num_features[3],
+ }
+
+ def forward(self, x):
+ """
+ Args:
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+ Returns:
+ dict[str->Tensor]: names and the corresponding features
+ """
+ assert (
+ x.dim() == 4
+ ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+ outputs = {}
+ y = super().forward(x)
+ for k in y.keys():
+ if k in self._out_features:
+ outputs[k] = y[k]
+ return outputs
+
+ def output_shape(self):
+ return {
+ name: ShapeSpec(
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+ )
+ for name in self._out_features
+ }
+
+ @property
+ def size_divisibility(self):
+ return 32
diff --git a/cat_seg/modeling/heads/__init__.py b/cat_seg/modeling/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/cat_seg/modeling/heads/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/cat_seg/modeling/heads/__pycache__/__init__.cpython-38.pyc b/cat_seg/modeling/heads/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67b3b59ef2fac8fa195d69842e97d66867e5d024
Binary files /dev/null and b/cat_seg/modeling/heads/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cat_seg/modeling/heads/__pycache__/cat_seg_head.cpython-38.pyc b/cat_seg/modeling/heads/__pycache__/cat_seg_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53f569d1d6aaad6b40148cdbf73d1e9aadb22644
Binary files /dev/null and b/cat_seg/modeling/heads/__pycache__/cat_seg_head.cpython-38.pyc differ
diff --git a/cat_seg/modeling/heads/__pycache__/cat_seg_panoptic_head.cpython-38.pyc b/cat_seg/modeling/heads/__pycache__/cat_seg_panoptic_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52ec35b40b35e72afdfc75b068ead0ee89e9b244
Binary files /dev/null and b/cat_seg/modeling/heads/__pycache__/cat_seg_panoptic_head.cpython-38.pyc differ
diff --git a/cat_seg/modeling/heads/__pycache__/pancat_head.cpython-38.pyc b/cat_seg/modeling/heads/__pycache__/pancat_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72a3007512a86297fd4cdbf6882b311407c23220
Binary files /dev/null and b/cat_seg/modeling/heads/__pycache__/pancat_head.cpython-38.pyc differ
diff --git a/cat_seg/modeling/heads/cat_seg_head.py b/cat_seg/modeling/heads/cat_seg_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dfb9b30be3727c3b2421fea28bde220536c799f
--- /dev/null
+++ b/cat_seg/modeling/heads/cat_seg_head.py
@@ -0,0 +1,72 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from copy import deepcopy
+from typing import Callable, Dict, List, Optional, Tuple, Union
+from einops import rearrange
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer.cat_seg_predictor import CATSegPredictor
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class CATSegHead(nn.Module):
+
+ @configurable
+ def __init__(
+ self,
+ input_shape: Dict[str, ShapeSpec],
+ *,
+ num_classes: int,
+ ignore_value: int = -1,
+ # extra parameters
+ feature_resolution: list,
+ transformer_predictor: nn.Module,
+ ):
+ """
+ NOTE: this interface is experimental.
+ Args:
+ input_shape: shapes (channels and stride) of the input features
+ num_classes: number of classes to predict
+ pixel_decoder: the pixel decoder module
+ loss_weight: loss weight
+ ignore_value: category id to be ignored during training.
+ transformer_predictor: the transformer decoder that makes prediction
+ transformer_in_feature: input feature name to the transformer_predictor
+ """
+ super().__init__()
+ input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+ self.in_features = [k for k, v in input_shape]
+ self.ignore_value = ignore_value
+ self.predictor = transformer_predictor
+ self.num_classes = num_classes
+ self.feature_resolution = feature_resolution
+
+ @classmethod
+ def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+ return {
+ "input_shape": {
+ k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+ },
+ "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+ "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
+ "feature_resolution": cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION,
+ "transformer_predictor": CATSegPredictor(
+ cfg,
+ ),
+ }
+
+ def forward(self, features, guidance_features):
+ """
+ Arguments:
+ img_feats: (B, C, HW)
+ affinity_features: (B, C, )
+ """
+ img_feat = rearrange(features[:, 1:, :], "b (h w) c->b c h w", h=self.feature_resolution[0], w=self.feature_resolution[1])
+ return self.predictor(img_feat, guidance_features)
\ No newline at end of file
diff --git a/cat_seg/modeling/transformer/__init__.py b/cat_seg/modeling/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/cat_seg/modeling/transformer/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/cat_seg/modeling/transformer/__pycache__/__init__.cpython-38.pyc b/cat_seg/modeling/transformer/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43a9d1f514fba85fcab93e98418507bf0a0ef8f9
Binary files /dev/null and b/cat_seg/modeling/transformer/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cat_seg/modeling/transformer/__pycache__/cat_seg_panoptic_predictor.cpython-38.pyc b/cat_seg/modeling/transformer/__pycache__/cat_seg_panoptic_predictor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f15b41fd61ffd23e87d70cd29449de55ef8cbba7
Binary files /dev/null and b/cat_seg/modeling/transformer/__pycache__/cat_seg_panoptic_predictor.cpython-38.pyc differ
diff --git a/cat_seg/modeling/transformer/__pycache__/cat_seg_predictor.cpython-38.pyc b/cat_seg/modeling/transformer/__pycache__/cat_seg_predictor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8f1d2aa6d1d5e3f3f90f94a0e09ce31bc2d84e3
Binary files /dev/null and b/cat_seg/modeling/transformer/__pycache__/cat_seg_predictor.cpython-38.pyc differ
diff --git a/cat_seg/modeling/transformer/__pycache__/model.cpython-38.pyc b/cat_seg/modeling/transformer/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d61483252425e261873d3f5ab5cce162288e3ff
Binary files /dev/null and b/cat_seg/modeling/transformer/__pycache__/model.cpython-38.pyc differ
diff --git a/cat_seg/modeling/transformer/__pycache__/pancat_predictor.cpython-38.pyc b/cat_seg/modeling/transformer/__pycache__/pancat_predictor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b8ea761725881fed4d30655fc2718eea11953b10
Binary files /dev/null and b/cat_seg/modeling/transformer/__pycache__/pancat_predictor.cpython-38.pyc differ
diff --git a/cat_seg/modeling/transformer/cat_seg_predictor.py b/cat_seg/modeling/transformer/cat_seg_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c0546d830d35a1d397eef6974a8ac88b3875812
--- /dev/null
+++ b/cat_seg/modeling/transformer/cat_seg_predictor.py
@@ -0,0 +1,175 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
+# Modified by Jian Ding from: https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
+import fvcore.nn.weight_init as weight_init
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d
+
+from .model import Aggregator
+from cat_seg.third_party import clip
+from cat_seg.third_party import imagenet_templates
+
+import numpy as np
+import open_clip
+class CATSegPredictor(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ *,
+ train_class_json: str,
+ test_class_json: str,
+ clip_pretrained: str,
+ prompt_ensemble_type: str,
+ text_guidance_dim: int,
+ text_guidance_proj_dim: int,
+ appearance_guidance_dim: int,
+ appearance_guidance_proj_dim: int,
+ prompt_depth: int,
+ prompt_length: int,
+ decoder_dims: list,
+ decoder_guidance_dims: list,
+ decoder_guidance_proj_dims: list,
+ num_heads: int,
+ num_layers: tuple,
+ hidden_dims: tuple,
+ pooling_sizes: tuple,
+ feature_resolution: tuple,
+ window_sizes: tuple,
+ attention_type: str,
+ ):
+ """
+ Args:
+
+ """
+ super().__init__()
+
+ import json
+ # use class_texts in train_forward, and test_class_texts in test_forward
+ #with open(train_class_json, 'r') as f_in:
+ # self.class_texts = json.load(f_in)
+ #with open(test_class_json, 'r') as f_in:
+ # self.test_class_texts = json.load(f_in)
+ #assert self.class_texts != None
+ #if self.test_class_texts == None:
+ # self.test_class_texts = self.class_texts
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.device = device
+ self.tokenizer = None
+ if clip_pretrained == "ViT-G" or clip_pretrained == "ViT-H":
+ # for OpenCLIP models
+ name, pretrain = ('ViT-H-14', 'laion2b_s32b_b79k') if clip_pretrained == 'ViT-H' else ('ViT-bigG-14', 'laion2b_s39b_b160k')
+ clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
+ name,
+ pretrained=pretrain,
+ device=device,
+ force_image_size=336,)
+
+ self.tokenizer = open_clip.get_tokenizer(name)
+ else:
+ # for OpenAI models
+ clip_model, clip_preprocess = clip.load(clip_pretrained, device=device, jit=False, prompt_depth=prompt_depth, prompt_length=prompt_length)
+
+ self.prompt_ensemble_type = prompt_ensemble_type
+
+ if self.prompt_ensemble_type == "imagenet_select":
+ prompt_templates = imagenet_templates.IMAGENET_TEMPLATES_SELECT
+ elif self.prompt_ensemble_type == "imagenet":
+ prompt_templates = imagenet_templates.IMAGENET_TEMPLATES
+ elif self.prompt_ensemble_type == "single":
+ prompt_templates = ['A photo of a {} in the scene',]
+ else:
+ raise NotImplementedError
+
+ #self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
+ #self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
+
+ self.clip_model = clip_model.float()
+ self.clip_preprocess = clip_preprocess
+
+ transformer = Aggregator(
+ text_guidance_dim=text_guidance_dim,
+ text_guidance_proj_dim=text_guidance_proj_dim,
+ appearance_guidance_dim=appearance_guidance_dim,
+ appearance_guidance_proj_dim=appearance_guidance_proj_dim,
+ decoder_dims=decoder_dims,
+ decoder_guidance_dims=decoder_guidance_dims,
+ decoder_guidance_proj_dims=decoder_guidance_proj_dims,
+ num_layers=num_layers,
+ nheads=num_heads,
+ hidden_dim=hidden_dims,
+ pooling_size=pooling_sizes,
+ feature_resolution=feature_resolution,
+ window_size=window_sizes,
+ attention_type=attention_type
+ )
+ self.transformer = transformer
+
+ @classmethod
+ def from_config(cls, cfg):#, in_channels, mask_classification):
+ ret = {}
+
+ ret["train_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON
+ ret["test_class_json"] = cfg.MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON
+ ret["clip_pretrained"] = cfg.MODEL.SEM_SEG_HEAD.CLIP_PRETRAINED
+ ret["prompt_ensemble_type"] = cfg.MODEL.PROMPT_ENSEMBLE_TYPE
+
+ # Aggregator parameters:
+ ret["text_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_DIM
+ ret["text_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.TEXT_AFFINITY_PROJ_DIM
+ ret["appearance_guidance_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_DIM
+ ret["appearance_guidance_proj_dim"] = cfg.MODEL.SEM_SEG_HEAD.APPEARANCE_AFFINITY_PROJ_DIM
+
+ ret["decoder_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_DIMS
+ ret["decoder_guidance_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_DIMS
+ ret["decoder_guidance_proj_dims"] = cfg.MODEL.SEM_SEG_HEAD.DECODER_AFFINITY_PROJ_DIMS
+
+ ret["prompt_depth"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_DEPTH
+ ret["prompt_length"] = cfg.MODEL.SEM_SEG_HEAD.PROMPT_LENGTH
+
+ ret["num_layers"] = cfg.MODEL.SEM_SEG_HEAD.NUM_LAYERS
+ ret["num_heads"] = cfg.MODEL.SEM_SEG_HEAD.NUM_HEADS
+ ret["hidden_dims"] = cfg.MODEL.SEM_SEG_HEAD.HIDDEN_DIMS
+ ret["pooling_sizes"] = cfg.MODEL.SEM_SEG_HEAD.POOLING_SIZES
+ ret["feature_resolution"] = cfg.MODEL.SEM_SEG_HEAD.FEATURE_RESOLUTION
+ ret["window_sizes"] = cfg.MODEL.SEM_SEG_HEAD.WINDOW_SIZES
+ ret["attention_type"] = cfg.MODEL.SEM_SEG_HEAD.ATTENTION_TYPE
+
+ return ret
+
+ def forward(self, x, vis_affinity):
+ vis = [vis_affinity[k] for k in vis_affinity.keys()][::-1]
+ text = self.text_features if self.training else self.text_features_test
+ text = text.repeat(x.shape[0], 1, 1, 1)
+ out = self.transformer(x, text, vis)
+ return out
+
+ @torch.no_grad()
+ def class_embeddings(self, classnames, templates, clip_model):
+ zeroshot_weights = []
+ for classname in classnames:
+ if ', ' in classname:
+ classname_splits = classname.split(', ')
+ texts = []
+ for template in templates:
+ for cls_split in classname_splits:
+ texts.append(template.format(cls_split))
+ else:
+ texts = [template.format(classname) for template in templates] # format with class
+ if self.tokenizer is not None:
+ texts = self.tokenizer(texts).to(self.device)
+ else:
+ texts = clip.tokenize(texts).to(self.device)
+ class_embeddings = clip_model.encode_text(texts)
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+ if len(templates) != class_embeddings.shape[0]:
+ class_embeddings = class_embeddings.reshape(len(templates), -1, class_embeddings.shape[-1]).mean(dim=1)
+ class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+ class_embedding = class_embeddings
+ zeroshot_weights.append(class_embedding)
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)
+ return zeroshot_weights
\ No newline at end of file
diff --git a/cat_seg/modeling/transformer/model.py b/cat_seg/modeling/transformer/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..01811affed02540a86bbdecdd097ff4c5fabb71a
--- /dev/null
+++ b/cat_seg/modeling/transformer/model.py
@@ -0,0 +1,650 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
+
+def window_partition(x, window_size: int):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size: int, H: int, W: int):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ head_dim (int): Number of channels per head (dim // num_heads if not set)
+ window_size (tuple[int]): The height and width of the window.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, appearance_guidance_dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = to_2tuple(window_size) # Wh, Ww
+ win_h, win_w = self.window_size
+ self.window_area = win_h * win_w
+ self.num_heads = num_heads
+ head_dim = head_dim or dim // num_heads
+ attn_dim = head_dim * num_heads
+ self.scale = head_dim ** -0.5
+
+ self.q = nn.Linear(dim + appearance_guidance_dim, attn_dim, bias=qkv_bias)
+ self.k = nn.Linear(dim + appearance_guidance_dim, attn_dim, bias=qkv_bias)
+ self.v = nn.Linear(dim, attn_dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(attn_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+
+ q = self.q(x).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
+ k = self.k(x).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
+ v = self.v(x[:, :, :self.dim]).reshape(B_, N, self.num_heads, -1).permute(0, 2, 1, 3)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if mask is not None:
+ num_win = mask.shape[0]
+ attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resulotion.
+ window_size (int): Window size.
+ num_heads (int): Number of attention heads.
+ head_dim (int): Enforce the number of channels per head
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(
+ self, dim, appearance_guidance_dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, appearance_guidance_dim=appearance_guidance_dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ cnt = 0
+ for h in (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None)):
+ for w in (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None)):
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+ mask_windows = window_partition(img_mask, self.window_size) # num_win, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def forward(self, x, appearance_guidance):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+ if appearance_guidance is not None:
+ appearance_guidance = appearance_guidance.view(B, H, W, -1)
+ x = torch.cat([x, appearance_guidance], dim=-1)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, x_windows.shape[-1]) # num_win*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class SwinTransformerBlockWrapper(nn.Module):
+ def __init__(self, dim, appearance_guidance_dim, input_resolution, nheads=4, window_size=5):
+ super().__init__()
+ self.block_1 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, num_heads=nheads, head_dim=None, window_size=window_size, shift_size=0)
+ self.block_2 = SwinTransformerBlock(dim, appearance_guidance_dim, input_resolution, num_heads=nheads, head_dim=None, window_size=window_size, shift_size=window_size // 2)
+ self.guidance_norm = nn.LayerNorm(appearance_guidance_dim) if appearance_guidance_dim > 0 else None
+
+ def forward(self, x, appearance_guidance):
+ """
+ Arguments:
+ x: B C T H W
+ appearance_guidance: B C H W
+ """
+ B, C, T, H, W = x.shape
+ x = rearrange(x, 'B C T H W -> (B T) (H W) C')
+ if appearance_guidance is not None:
+ appearance_guidance = self.guidance_norm(repeat(appearance_guidance, 'B C H W -> (B T) (H W) C', T=T))
+ x = self.block_1(x, appearance_guidance)
+ x = self.block_2(x, appearance_guidance)
+ x = rearrange(x, '(B T) (H W) C -> B C T H W', B=B, T=T, H=H, W=W)
+ return x
+
+
+def elu_feature_map(x):
+ return torch.nn.functional.elu(x) + 1
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, eps=1e-6):
+ super().__init__()
+ self.feature_map = elu_feature_map
+ self.eps = eps
+
+ def forward(self, queries, keys, values):
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+ Q = self.feature_map(queries)
+ K = self.feature_map(keys)
+
+ v_length = values.size(1)
+ values = values / v_length # prevent fp16 overflow
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
+
+ return queried_values.contiguous()
+
+
+class FullAttention(nn.Module):
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
+ super().__init__()
+ self.use_dropout = use_dropout
+ self.dropout = nn.Dropout(attention_dropout)
+
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
+ Args:
+ queries: [N, L, H, D]
+ keys: [N, S, H, D]
+ values: [N, S, H, D]
+ q_mask: [N, L]
+ kv_mask: [N, S]
+ Returns:
+ queried_values: (N, L, H, D)
+ """
+
+ # Compute the unnormalized attention and apply the masks
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
+ if kv_mask is not None:
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
+
+ # Compute the attention and the weighted average
+ softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
+ A = torch.softmax(softmax_temp * QK, dim=2)
+ if self.use_dropout:
+ A = self.dropout(A)
+
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
+
+ return queried_values.contiguous()
+
+
+class AttentionLayer(nn.Module):
+ def __init__(self, hidden_dim, guidance_dim, nheads=8, attention_type='linear'):
+ super().__init__()
+ self.nheads = nheads
+ self.q = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
+ self.k = nn.Linear(hidden_dim + guidance_dim, hidden_dim)
+ self.v = nn.Linear(hidden_dim, hidden_dim)
+
+ if attention_type == 'linear':
+ self.attention = LinearAttention()
+ elif attention_type == 'full':
+ self.attention = FullAttention()
+ else:
+ raise NotImplementedError
+
+ def forward(self, x, guidance):
+ """
+ Arguments:
+ x: B, L, C
+ guidance: B, L, C
+ """
+ q = self.q(torch.cat([x, guidance], dim=-1)) if guidance is not None else self.q(x)
+ k = self.k(torch.cat([x, guidance], dim=-1)) if guidance is not None else self.k(x)
+ v = self.v(x)
+
+ q = rearrange(q, 'B L (H D) -> B L H D', H=self.nheads)
+ k = rearrange(k, 'B S (H D) -> B S H D', H=self.nheads)
+ v = rearrange(v, 'B S (H D) -> B S H D', H=self.nheads)
+
+ out = self.attention(q, k, v)
+ out = rearrange(out, 'B L H D -> B L (H D)')
+ return out
+
+
+class ClassTransformerLayer(nn.Module):
+ def __init__(self, hidden_dim=64, guidance_dim=64, nheads=8, attention_type='linear', pooling_size=(4, 4)) -> None:
+ super().__init__()
+ self.pool = nn.AvgPool2d(pooling_size)
+ self.attention = AttentionLayer(hidden_dim, guidance_dim, nheads=nheads, attention_type=attention_type)
+ self.MLP = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim * 4),
+ nn.ReLU(),
+ nn.Linear(hidden_dim * 4, hidden_dim)
+ )
+
+ self.norm1 = nn.LayerNorm(hidden_dim)
+ self.norm2 = nn.LayerNorm(hidden_dim)
+
+ def pool_features(self, x):
+ """
+ Intermediate pooling layer for computational efficiency.
+ Arguments:
+ x: B, C, T, H, W
+ """
+ B = x.size(0)
+ x = rearrange(x, 'B C T H W -> (B T) C H W')
+ x = self.pool(x)
+ x = rearrange(x, '(B T) C H W -> B C T H W', B=B)
+ return x
+
+ def forward(self, x, guidance):
+ """
+ Arguments:
+ x: B, C, T, H, W
+ guidance: B, T, C
+ """
+ B, _, _, H, W = x.size()
+ x_pool = self.pool_features(x)
+ *_, H_pool, W_pool = x_pool.size()
+
+ x_pool = rearrange(x_pool, 'B C T H W -> (B H W) T C')
+ if guidance is not None:
+ guidance = repeat(guidance, 'B T C -> (B H W) T C', H=H_pool, W=W_pool)
+
+ x_pool = x_pool + self.attention(self.norm1(x_pool), guidance) # Attention
+ x_pool = x_pool + self.MLP(self.norm2(x_pool)) # MLP
+
+ x_pool = rearrange(x_pool, '(B H W) T C -> (B T) C H W', H=H_pool, W=W_pool)
+ x_pool = F.interpolate(x_pool, size=(H, W), mode='bilinear', align_corners=True)
+ x_pool = rearrange(x_pool, '(B T) C H W -> B C T H W', B=B)
+
+ x = x + x_pool # Residual
+ return x
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+ __constants__ = ['downsample']
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+ base_width=64, dilation=1, norm_layer=None):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class AggregatorLayer(nn.Module):
+ def __init__(self, hidden_dim=64, text_guidance_dim=512, appearance_guidance=512, nheads=4, input_resolution=(20, 20), pooling_size=(5, 5), window_size=(10, 10), attention_type='linear') -> None:
+ super().__init__()
+ self.swin_block = SwinTransformerBlockWrapper(hidden_dim, appearance_guidance, input_resolution, nheads, window_size)
+ self.attention = ClassTransformerLayer(hidden_dim, text_guidance_dim, nheads=nheads, attention_type=attention_type, pooling_size=pooling_size)
+
+
+ def forward(self, x, appearance_guidance, text_guidance):
+ """
+ Arguments:
+ x: B C T H W
+ """
+ x = self.swin_block(x, appearance_guidance)
+ x = self.attention(x, text_guidance)
+ return x
+
+
+class AggregatorResNetLayer(nn.Module):
+ def __init__(self, hidden_dim=64, appearance_guidance=512) -> None:
+ super().__init__()
+ self.conv_linear = nn.Conv2d(hidden_dim + appearance_guidance, hidden_dim, kernel_size=1, stride=1)
+ self.conv_layer = Bottleneck(hidden_dim, hidden_dim // 4)
+
+
+ def forward(self, x, appearance_guidance):
+ """
+ Arguments:
+ x: B C T H W
+ """
+ B, T = x.size(0), x.size(2)
+ x = rearrange(x, 'B C T H W -> (B T) C H W')
+ appearance_guidance = repeat(appearance_guidance, 'B C H W -> (B T) C H W', T=T)
+
+ x = self.conv_linear(torch.cat([x, appearance_guidance], dim=1))
+ x = self.conv_layer(x)
+ x = rearrange(x, '(B T) C H W -> B C T H W', B=B)
+ return x
+
+
+class DoubleConv(nn.Module):
+ """(convolution => [GN] => ReLU) * 2"""
+
+ def __init__(self, in_channels, out_channels, mid_channels=None):
+ super().__init__()
+ if not mid_channels:
+ mid_channels = out_channels
+ self.double_conv = nn.Sequential(
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+ nn.GroupNorm(mid_channels // 16, mid_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.GroupNorm(mid_channels // 16, mid_channels),
+ nn.ReLU(inplace=True)
+ )
+
+ def forward(self, x):
+ return self.double_conv(x)
+
+
+class Up(nn.Module):
+ """Upscaling then double conv"""
+
+ def __init__(self, in_channels, out_channels, guidance_channels):
+ super().__init__()
+
+ self.up = nn.ConvTranspose2d(in_channels, in_channels - guidance_channels, kernel_size=2, stride=2)
+ self.conv = DoubleConv(in_channels, out_channels)
+
+ def forward(self, x, guidance=None):
+ x = self.up(x)
+ if guidance is not None:
+ T = x.size(0) // guidance.size(0)
+ guidance = repeat(guidance, "B C H W -> (B T) C H W", T=T)
+ x = torch.cat([x, guidance], dim=1)
+ return self.conv(x)
+
+
+class Aggregator(nn.Module):
+ def __init__(self,
+ text_guidance_dim=512,
+ text_guidance_proj_dim=128,
+ appearance_guidance_dim=512,
+ appearance_guidance_proj_dim=128,
+ decoder_dims = (64, 32),
+ decoder_guidance_dims=(256, 128),
+ decoder_guidance_proj_dims=(32, 16),
+ num_layers=4,
+ nheads=4,
+ hidden_dim=128,
+ pooling_size=(6, 6),
+ feature_resolution=(24, 24),
+ window_size=12,
+ attention_type='linear',
+ prompt_channel=80,
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ self.hidden_dim = hidden_dim
+
+ self.layers = nn.ModuleList([
+ AggregatorLayer(
+ hidden_dim=hidden_dim, text_guidance_dim=text_guidance_proj_dim, appearance_guidance=appearance_guidance_proj_dim,
+ nheads=nheads, input_resolution=feature_resolution, pooling_size=pooling_size, window_size=window_size, attention_type=attention_type
+ ) for _ in range(num_layers)
+ ])
+
+ self.conv1 = nn.Conv2d(prompt_channel, hidden_dim, kernel_size=7, stride=1, padding=3)
+
+ self.guidance_projection = nn.Sequential(
+ nn.Conv2d(appearance_guidance_dim, appearance_guidance_proj_dim, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(),
+ ) if appearance_guidance_dim > 0 else None
+
+ self.text_guidance_projection = nn.Sequential(
+ nn.Linear(text_guidance_dim, text_guidance_proj_dim),
+ nn.ReLU(),
+ ) if text_guidance_dim > 0 else None
+
+ self.decoder_guidance_projection = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(d, dp, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(),
+ ) for d, dp in zip(decoder_guidance_dims, decoder_guidance_proj_dims)
+ ]) if decoder_guidance_dims[0] > 0 else None
+
+ self.decoder1 = Up(hidden_dim, decoder_dims[0], decoder_guidance_proj_dims[0])
+ self.decoder2 = Up(decoder_dims[0], decoder_dims[1], decoder_guidance_proj_dims[1])
+ self.head = nn.Conv2d(decoder_dims[1], 1, kernel_size=3, stride=1, padding=1)
+
+ def feature_map(self, img_feats, text_feats):
+ img_feats = F.normalize(img_feats, dim=1) # B C H W
+ img_feats = repeat(img_feats, "B C H W -> B C T H W", T=text_feats.shape[1])
+ text_feats = F.normalize(text_feats, dim=-1) # B T P C
+ text_feats = text_feats.mean(dim=-2)
+ text_feats = F.normalize(text_feats, dim=-1) # B T C
+ text_feats = repeat(text_feats, "B T C -> B C T H W", H=img_feats.shape[-2], W=img_feats.shape[-1])
+ return torch.cat((img_feats, text_feats), dim=1) # B 2C T H W
+
+ def correlation(self, img_feats, text_feats):
+ img_feats = F.normalize(img_feats, dim=1) # B C H W
+ text_feats = F.normalize(text_feats, dim=-1) # B T P C
+ corr = torch.einsum('bchw, btpc -> bpthw', img_feats, text_feats)
+ return corr
+
+ def corr_embed(self, x):
+ B = x.shape[0]
+ corr_embed = rearrange(x, 'B P T H W -> (B T) P H W')
+ corr_embed = self.conv1(corr_embed)
+ corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
+ return corr_embed
+
+ def corr_projection(self, x, proj):
+ corr_embed = rearrange(x, 'B C T H W -> B T H W C')
+ corr_embed = proj(corr_embed)
+ corr_embed = rearrange(corr_embed, 'B T H W C -> B C T H W')
+ return corr_embed
+
+ def upsample(self, x):
+ B = x.shape[0]
+ corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
+ corr_embed = F.interpolate(corr_embed, scale_factor=2, mode='bilinear', align_corners=True)
+ corr_embed = rearrange(corr_embed, '(B T) C H W -> B C T H W', B=B)
+ return corr_embed
+
+ def conv_decoder(self, x, guidance):
+ B = x.shape[0]
+ corr_embed = rearrange(x, 'B C T H W -> (B T) C H W')
+ corr_embed = self.decoder1(corr_embed, guidance[0])
+ corr_embed = self.decoder2(corr_embed, guidance[1])
+ corr_embed = self.head(corr_embed)
+ corr_embed = rearrange(corr_embed, '(B T) () H W -> B T H W', B=B)
+ return corr_embed
+
+ def forward(self, img_feats, text_feats, appearance_guidance):
+ """
+ Arguments:
+ img_feats: (B, C, H, W)
+ text_feats: (B, T, P, C)
+ apperance_guidance: tuple of (B, C, H, W)
+ """
+ corr = self.correlation(img_feats, text_feats)
+ #corr = self.feature_map(img_feats, text_feats)
+ corr_embed = self.corr_embed(corr)
+
+ projected_guidance, projected_text_guidance, projected_decoder_guidance = None, None, [None, None]
+ if self.guidance_projection is not None:
+ projected_guidance = self.guidance_projection(appearance_guidance[0])
+ if self.decoder_guidance_projection is not None:
+ projected_decoder_guidance = [proj(g) for proj, g in zip(self.decoder_guidance_projection, appearance_guidance[1:])]
+
+ if self.text_guidance_projection is not None:
+ text_feats = text_feats.mean(dim=-2)
+ text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
+ projected_text_guidance = self.text_guidance_projection(text_feats)
+
+ for layer in self.layers:
+ corr_embed = layer(corr_embed, projected_guidance, projected_text_guidance)
+
+ logit = self.conv_decoder(corr_embed, projected_decoder_guidance)
+
+ return logit
diff --git a/cat_seg/test_time_augmentation.py b/cat_seg/test_time_augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d250b6bb7792b54ddeaaab62cc6c170d74d3bb9
--- /dev/null
+++ b/cat_seg/test_time_augmentation.py
@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+from itertools import count
+
+import numpy as np
+import torch
+from fvcore.transforms import HFlipTransform
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel
+
+from detectron2.data.detection_utils import read_image
+from detectron2.modeling import DatasetMapperTTA
+
+__all__ = [
+ "SemanticSegmentorWithTTA",
+]
+
+
+class SemanticSegmentorWithTTA(nn.Module):
+ """
+ A SemanticSegmentor with test-time augmentation enabled.
+ Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.
+ """
+
+ def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
+ """
+ Args:
+ cfg (CfgNode):
+ model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
+ tta_mapper (callable): takes a dataset dict and returns a list of
+ augmented versions of the dataset dict. Defaults to
+ `DatasetMapperTTA(cfg)`.
+ batch_size (int): batch the augmented images into this batch size for inference.
+ """
+ super().__init__()
+ if isinstance(model, DistributedDataParallel):
+ model = model.module
+ self.cfg = cfg.clone()
+
+ self.model = model
+
+ if tta_mapper is None:
+ tta_mapper = DatasetMapperTTA(cfg)
+ self.tta_mapper = tta_mapper
+ self.batch_size = batch_size
+
+ def _batch_inference(self, batched_inputs):
+ """
+ Execute inference on a list of inputs,
+ using batch size = self.batch_size, instead of the length of the list.
+ Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward`
+ """
+ outputs = []
+ inputs = []
+ for idx, input in zip(count(), batched_inputs):
+ inputs.append(input)
+ if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
+ with torch.no_grad():
+ outputs.extend(self.model(inputs))
+ inputs = []
+ return outputs
+
+ def __call__(self, batched_inputs):
+ """
+ Same input/output format as :meth:`SemanticSegmentor.forward`
+ """
+
+ def _maybe_read_image(dataset_dict):
+ ret = copy.copy(dataset_dict)
+ if "image" not in ret:
+ image = read_image(ret.pop("file_name"), self.model.input_format)
+ image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW
+ ret["image"] = image
+ if "height" not in ret and "width" not in ret:
+ ret["height"] = image.shape[1]
+ ret["width"] = image.shape[2]
+ return ret
+
+ return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]
+
+ def _inference_one_image(self, input):
+ """
+ Args:
+ input (dict): one dataset dict with "image" field being a CHW tensor
+ Returns:
+ dict: one output dict
+ """
+ augmented_inputs, tfms = self._get_augmented_inputs(input)
+ # 1: forward with all augmented images
+ outputs = self._batch_inference(augmented_inputs)
+ # Delete now useless variables to avoid being out of memory
+ del augmented_inputs
+ # 2: merge the results
+ # handle flip specially
+ new_outputs = []
+ for output, tfm in zip(outputs, tfms):
+ if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
+ new_outputs.append(output.pop("sem_seg").flip(dims=[2]))
+ else:
+ new_outputs.append(output.pop("sem_seg"))
+ del outputs
+ # to avoid OOM with torch.stack
+ final_predictions = new_outputs[0]
+ for i in range(1, len(new_outputs)):
+ final_predictions += new_outputs[i]
+ final_predictions = final_predictions / len(new_outputs)
+ del new_outputs
+ return {"sem_seg": final_predictions}
+
+ def _get_augmented_inputs(self, input):
+ augmented_inputs = self.tta_mapper(input)
+ tfms = [x.pop("transforms") for x in augmented_inputs]
+ return augmented_inputs, tfms
diff --git a/cat_seg/third_party/__pycache__/clip.cpython-38.pyc b/cat_seg/third_party/__pycache__/clip.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7be311758d35ab88062241c96967527f82ae45c4
Binary files /dev/null and b/cat_seg/third_party/__pycache__/clip.cpython-38.pyc differ
diff --git a/cat_seg/third_party/__pycache__/imagenet_templates.cpython-38.pyc b/cat_seg/third_party/__pycache__/imagenet_templates.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04939b1522601493308fab31d13589a0ca873b8c
Binary files /dev/null and b/cat_seg/third_party/__pycache__/imagenet_templates.cpython-38.pyc differ
diff --git a/cat_seg/third_party/__pycache__/model_vpt.cpython-38.pyc b/cat_seg/third_party/__pycache__/model_vpt.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c4e8aa43d23b4bda0f86131d9138236b6f12a45
Binary files /dev/null and b/cat_seg/third_party/__pycache__/model_vpt.cpython-38.pyc differ
diff --git a/cat_seg/third_party/__pycache__/simple_tokenizer.cpython-38.pyc b/cat_seg/third_party/__pycache__/simple_tokenizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3c49d74378f05072fe337fdba8fd695b3a96d48
Binary files /dev/null and b/cat_seg/third_party/__pycache__/simple_tokenizer.cpython-38.pyc differ
diff --git a/cat_seg/third_party/bpe_simple_vocab_16e6.txt.gz b/cat_seg/third_party/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/cat_seg/third_party/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/cat_seg/third_party/clip.py b/cat_seg/third_party/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..916eb2745a411064f519592414150c408beb7204
--- /dev/null
+++ b/cat_seg/third_party/clip.py
@@ -0,0 +1,211 @@
+import hashlib
+import os
+import urllib
+import warnings
+from typing import Union, List
+
+import torch
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from tqdm import tqdm
+
+#from .model import build_model
+from .model_vpt import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+__all__ = ["available_models", "load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
+}
+
+
+def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def available_models():
+ return list(_MODELS.keys())
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, prompt_depth=0, prompt_length=0):
+ if name not in _MODELS:
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+ model_path = _download(_MODELS[name])
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ n_px = model.input_resolution.item()
+
+ transform = Compose([
+ Resize(n_px, interpolation=Image.BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert("RGB"),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ if not jit:
+ model = build_model(model.state_dict(), prompt_depth, prompt_length).to(device)
+ return model, transform
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if device == "cpu":
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, transform
+
+
+def load_custom(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True, n_px=224):
+ if name not in _MODELS:
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+ model_path = _download(_MODELS[name])
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ # n_px = model.input_resolution.item()
+
+ transform = Compose([
+ Resize(n_px, interpolation=Image.BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert("RGB"),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ if not jit:
+ model = build_model(model.state_dict()).to(device)
+ return model, transform
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if device == "cpu":
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, transform
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77):
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
diff --git a/cat_seg/third_party/imagenet_templates.py b/cat_seg/third_party/imagenet_templates.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7f9355568443efa458d0e4da58acd31a2c34002
--- /dev/null
+++ b/cat_seg/third_party/imagenet_templates.py
@@ -0,0 +1,445 @@
+# source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
+
+IMAGENET_TEMPLATES = [
+ 'a bad photo of a {}.',
+ 'a photo of many {}.',
+ 'a sculpture of a {}.',
+ 'a photo of the hard to see {}.',
+ 'a low resolution photo of the {}.',
+ 'a rendering of a {}.',
+ 'graffiti of a {}.',
+ 'a bad photo of the {}.',
+ 'a cropped photo of the {}.',
+ 'a tattoo of a {}.',
+ 'the embroidered {}.',
+ 'a photo of a hard to see {}.',
+ 'a bright photo of a {}.',
+ 'a photo of a clean {}.',
+ 'a photo of a dirty {}.',
+ 'a dark photo of the {}.',
+ 'a drawing of a {}.',
+ 'a photo of my {}.',
+ 'the plastic {}.',
+ 'a photo of the cool {}.',
+ 'a close-up photo of a {}.',
+ 'a black and white photo of the {}.',
+ 'a painting of the {}.',
+ 'a painting of a {}.',
+ 'a pixelated photo of the {}.',
+ 'a sculpture of the {}.',
+ 'a bright photo of the {}.',
+ 'a cropped photo of a {}.',
+ 'a plastic {}.',
+ 'a photo of the dirty {}.',
+ 'a jpeg corrupted photo of a {}.',
+ 'a blurry photo of the {}.',
+ 'a photo of the {}.',
+ 'a good photo of the {}.',
+ 'a rendering of the {}.',
+ 'a {} in a video game.',
+ 'a photo of one {}.',
+ 'a doodle of a {}.',
+ 'a close-up photo of the {}.',
+ 'a photo of a {}.',
+ 'the origami {}.',
+ 'the {} in a video game.',
+ 'a sketch of a {}.',
+ 'a doodle of the {}.',
+ 'a origami {}.',
+ 'a low resolution photo of a {}.',
+ 'the toy {}.',
+ 'a rendition of the {}.',
+ 'a photo of the clean {}.',
+ 'a photo of a large {}.',
+ 'a rendition of a {}.',
+ 'a photo of a nice {}.',
+ 'a photo of a weird {}.',
+ 'a blurry photo of a {}.',
+ 'a cartoon {}.',
+ 'art of a {}.',
+ 'a sketch of the {}.',
+ 'a embroidered {}.',
+ 'a pixelated photo of a {}.',
+ 'itap of the {}.',
+ 'a jpeg corrupted photo of the {}.',
+ 'a good photo of a {}.',
+ 'a plushie {}.',
+ 'a photo of the nice {}.',
+ 'a photo of the small {}.',
+ 'a photo of the weird {}.',
+ 'the cartoon {}.',
+ 'art of the {}.',
+ 'a drawing of the {}.',
+ 'a photo of the large {}.',
+ 'a black and white photo of a {}.',
+ 'the plushie {}.',
+ 'a dark photo of a {}.',
+ 'itap of a {}.',
+ 'graffiti of the {}.',
+ 'a toy {}.',
+ 'itap of my {}.',
+ 'a photo of a cool {}.',
+ 'a photo of a small {}.',
+ 'a tattoo of the {}.',
+ # 'A photo of a {} in the scene.',
+]
+
+# v1: 59.0875
+IMAGENET_TEMPLATES_SELECT = [
+ 'itap of a {}.',
+ 'a bad photo of the {}.',
+ 'a origami {}.',
+ 'a photo of the large {}.',
+ 'a {} in a video game.',
+ 'art of the {}.',
+ 'a photo of the small {}.',
+ 'A photo of a {} in the scene',
+]
+
+# v2: 58.2584
+# IMAGENET_TEMPLATES_SELECT = [
+# 'itap of a {}',
+# 'a bad photo of the {}',
+# 'a origami {}',
+# 'a photo of the large {}',
+# 'art of the {}',
+# 'a photo of the small {}',
+# 'A photo of a {} in the scene',
+# ]
+
+# v3: 59.1006
+# IMAGENET_TEMPLATES_SELECT = [
+# 'itap of a {}.',
+# 'a bad photo of the {}.',
+# 'a origami {}.',
+# 'a photo of the large {}.',
+# 'art of the {}.',
+# 'a photo of the small {}.',
+# 'a cropped photo of a {}.',
+# 'A photo of a {} in the scene',
+# 'itap of a {} in the scene',
+# 'a bad photo of the {} in the scene',
+# 'a origami {} in the scene',
+# 'a photo of the large {} in the scene',
+# 'art of the {} in the scene',
+# 'a photo of the small {} in the scene',
+# 'a cropped photo of a {} in the scene',
+# ]
+
+# v4: 59.8659
+# IMAGENET_TEMPLATES_SELECT = [
+# 'a bad photo of the {}.',
+# 'a photo of the large {}.',
+# 'art of the {}.',
+# 'a photo of the small {}.',
+# 'a cropped photo of a {}.',
+# 'A photo of a {} in the scene',
+# 'a bad photo of the {} in the scene',
+# 'a photo of the large {} in the scene',
+# 'art of the {} in the scene',
+# 'a photo of the small {} in the scene',
+# 'a cropped photo of a {} in the scene',
+# 'a photo of a masked {} in the scene',
+# ]
+
+# v5: 59.9346
+# IMAGENET_TEMPLATES_SELECT = [
+# 'a bad photo of the {}.',
+# 'a photo of the large {}.',
+# 'art of the {}.',
+# 'a photo of the small {}.',
+# 'a cropped photo of a {}.',
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+# 'A photo of a {} in the scene',
+# 'a bad photo of the {} in the scene',
+# 'a photo of the large {} in the scene',
+# 'art of the {} in the scene',
+# 'a photo of the small {} in the scene',
+# 'a cropped photo of a {} in the scene',
+# 'a photo of a masked {} in the scene',
+# 'There is a {} in the scene',
+# 'There is the {} in the scene',
+# 'This is a {} in the scene',
+# 'This is the {} in the scene',
+# 'This is one {} in the scene',
+# ]
+
+# v6: 60.6611
+# IMAGENET_TEMPLATES_SELECT = [
+# 'a bad photo of the {}.',
+# 'a photo of the large {}.',
+# 'art of the {}.',
+# 'a photo of the small {}.',
+# 'a cropped photo of a {}.',
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+# 'A photo of a {} in the scene',
+# 'a bad photo of the {} in the scene',
+# 'a photo of the large {} in the scene',
+# 'art of the {} in the scene',
+# 'a photo of the small {} in the scene',
+# 'a cropped photo of a {} in the scene',
+# 'a photo of a masked {} in the scene',
+# 'There is a {} in the scene',
+# 'There is the {} in the scene',
+# 'This is a {} in the scene',
+# 'This is the {} in the scene',
+# 'This is one {} in the scene',
+#
+# 'There is a masked {} in the scene',
+# 'There is the masked {} in the scene',
+# 'This is a masked {} in the scene',
+# 'This is the masked {} in the scene',
+# 'This is one masked {} in the scene',
+# ]
+
+# v7: 60.4529
+# IMAGENET_TEMPLATES_SELECT = [
+# 'a bad photo of the {}.',
+# 'a photo of the large {}.',
+# 'art of the {}.',
+# 'a photo of the small {}.',
+# 'a cropped photo of a {}.',
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+# 'A photo of a {} in the scene',
+# 'a bad photo of the {} in the scene',
+# 'a photo of the large {} in the scene',
+# 'art of the {} in the scene',
+# 'a photo of the small {} in the scene',
+# 'a cropped photo of a {} in the scene',
+# 'a photo of a masked {} in the scene',
+# 'There is a {} in the scene',
+# 'There is the {} in the scene',
+# 'This is a {} in the scene',
+# 'This is the {} in the scene',
+# 'This is one {} in the scene',
+#
+# 'There is a cropped {} in the scene',
+# 'There is the cropped {} in the scene',
+# 'This is a cropped {} in the scene',
+# 'This is the cropped {} in the scene',
+# 'This is one cropped {} in the scene',
+#
+# 'a cropped photo of the {}',
+# 'a cropped photo of a {}',
+# 'a cropped photo of one {}',
+#
+# 'There is a masked {} in the scene',
+# 'There is the masked {} in the scene',
+# 'This is a masked {} in the scene',
+# 'This is the masked {} in the scene',
+# 'This is one masked {} in the scene',
+# ]
+
+# v8: 60.7057
+# IMAGENET_TEMPLATES_SELECT = [
+# 'a bad photo of the {}.',
+# 'a photo of the large {}.',
+# 'a photo of the small {}.',
+# 'a cropped photo of a {}.',
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+#
+# 'This is a masked photo of a {}',
+# 'This is a masked photo of a small {}',
+# 'This is a masked photo of a medium {}',
+# 'This is a masked photo of a large {}',
+#
+# 'A photo of a {} in the scene',
+# 'a bad photo of the {} in the scene',
+# 'a photo of the large {} in the scene',
+# 'a photo of the small {} in the scene',
+# 'a cropped photo of a {} in the scene',
+# 'a photo of a masked {} in the scene',
+# 'There is a {} in the scene',
+# 'There is the {} in the scene',
+# 'This is a {} in the scene',
+# 'This is the {} in the scene',
+# 'This is one {} in the scene',
+#
+# 'There is a masked {} in the scene',
+# 'There is the masked {} in the scene',
+# 'This is a masked {} in the scene',
+# 'This is the masked {} in the scene',
+# 'This is one masked {} in the scene',
+# ]
+
+# v9: 60.8775
+# IMAGENET_TEMPLATES_SELECT = [
+# 'a bad photo of the {}.',
+# 'a photo of the large {}.',
+# 'a photo of the small {}.',
+# 'a cropped photo of a {}.',
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+#
+# 'This is a masked photo of a {}',
+# 'This is a masked photo of a small {}',
+# 'This is a masked photo of a medium {}',
+# 'This is a masked photo of a large {}',
+#
+# 'This is a cropped photo of a {}',
+# 'This is a cropped photo of a small {}',
+# 'This is a cropped photo of a medium {}',
+# 'This is a cropped photo of a large {}',
+#
+# 'A photo of a {} in the scene',
+# 'a bad photo of the {} in the scene',
+# 'a photo of the large {} in the scene',
+# 'a photo of the small {} in the scene',
+# 'a cropped photo of a {} in the scene',
+# 'a photo of a masked {} in the scene',
+# 'There is a {} in the scene',
+# 'There is the {} in the scene',
+# 'This is a {} in the scene',
+# 'This is the {} in the scene',
+# 'This is one {} in the scene',
+#
+# 'There is a masked {} in the scene',
+# 'There is the masked {} in the scene',
+# 'This is a masked {} in the scene',
+# 'This is the masked {} in the scene',
+# 'This is one masked {} in the scene',
+# ]
+
+# v9
+IMAGENET_TEMPLATES_SELECT_CLIP = [
+ 'a bad photo of the {}.',
+ 'a photo of the large {}.',
+ 'a photo of the small {}.',
+ 'a cropped photo of a {}.',
+ 'This is a photo of a {}',
+ 'This is a photo of a small {}',
+ 'This is a photo of a medium {}',
+ 'This is a photo of a large {}',
+
+ 'This is a masked photo of a {}',
+ 'This is a masked photo of a small {}',
+ 'This is a masked photo of a medium {}',
+ 'This is a masked photo of a large {}',
+
+ 'This is a cropped photo of a {}',
+ 'This is a cropped photo of a small {}',
+ 'This is a cropped photo of a medium {}',
+ 'This is a cropped photo of a large {}',
+
+ 'A photo of a {} in the scene',
+ 'a bad photo of the {} in the scene',
+ 'a photo of the large {} in the scene',
+ 'a photo of the small {} in the scene',
+ 'a cropped photo of a {} in the scene',
+ 'a photo of a masked {} in the scene',
+ 'There is a {} in the scene',
+ 'There is the {} in the scene',
+ 'This is a {} in the scene',
+ 'This is the {} in the scene',
+ 'This is one {} in the scene',
+
+ 'There is a masked {} in the scene',
+ 'There is the masked {} in the scene',
+ 'This is a masked {} in the scene',
+ 'This is the masked {} in the scene',
+ 'This is one masked {} in the scene',
+]
+
+# v10, for comparison
+# IMAGENET_TEMPLATES_SELECT_CLIP = [
+# 'a photo of a {}.',
+#
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+#
+# 'This is a photo of a {}',
+# 'This is a photo of a small {}',
+# 'This is a photo of a medium {}',
+# 'This is a photo of a large {}',
+#
+# 'a photo of a {} in the scene',
+# 'a photo of a {} in the scene',
+#
+# 'There is a {} in the scene',
+# 'There is the {} in the scene',
+# 'This is a {} in the scene',
+# 'This is the {} in the scene',
+# 'This is one {} in the scene',
+# ]
+
+ViLD_templates = [
+'There is {article} {category} in the scene.',
+'There is the {category} in the scene.',
+'a photo of {article} {category} in the scene.',
+'a photo of the {category} in the scene.',
+'a photo of one {category} in the scene.',
+'itap of {article} {category}.',
+'itap of my {category}.',
+'itap of the {category}.',
+'a photo of {article} {category}.',
+'a photo of my {category}.',
+'a photo of the {category}.',
+'a photo of one {category}.',
+'a photo of many {category}.',
+'a good photo of {article} {category}.',
+'a good photo of the {category}.',
+'a bad photo of {article} {category}.',
+'a bad photo of the {category}.',
+'a photo of a nice {category}.',
+'a photo of the nice {category}.',
+'a photo of a cool {category}.',
+'a photo of the cool {category}.',
+'a photo of a weird {category}.',
+'a photo of the weird {category}.',
+'a photo of a small {category}.',
+'a photo of the small {category}.',
+'a photo of a large {category}.',
+'a photo of the large {category}.',
+'a photo of a clean {category}.',
+'a photo of the clean {category}.',
+'a photo of a dirty {category}.',
+'a photo of the dirty {category}.',
+'a bright photo of {article} {category}.',
+'a bright photo of the {category}.',
+'a dark photo of {article} {category}.',
+'a dark photo of the {category}.',
+'a photo of a hard to see {category}.',
+'a photo of the hard to see {category}.',
+'a low resolution photo of {article} {category}.',
+'a low resolution photo of the {category}.',
+'a cropped photo of {article} {category}.',
+'a cropped photo of the {category}.',
+'a close-up photo of {article} {category}.',
+'a close-up photo of the {category}.',
+'a jpeg corrupted photo of {article} {category}.',
+'a jpeg corrupted photo of the {category}.',
+'a blurry photo of {article} {category}.',
+'a blurry photo of the {category}.',
+'a pixelated photo of {article} {category}.',
+'a pixelated photo of the {category}.',
+'a black and white photo of the {category}.',
+'a black and white photo of {article} {category}.',
+'a plastic {category}.',
+'the plastic {category}.',
+'a toy {category}.',
+'the toy {category}.',
+'a plushie {category}.',
+'the plushie {category}.',
+'a cartoon {category}.',
+'the cartoon {category}.',
+'an embroidered {category}.',
+'the embroidered {category}.',
+'a painting of the {category}.',
+'a painting of a {category}.'
+]
\ No newline at end of file
diff --git a/cat_seg/third_party/model.py b/cat_seg/third_party/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2f1467977139685bea3a39a2cab8b4c84bdaa5
--- /dev/null
+++ b/cat_seg/third_party/model.py
@@ -0,0 +1,456 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu2 = nn.ReLU(inplace=True)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu3 = nn.ReLU(inplace=True)
+
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu1(self.bn1(self.conv1(x)))
+ out = self.relu2(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu3(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x[:1], key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+ return x.squeeze(0)
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.avgpool = nn.AvgPool2d(2)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ def stem(x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+ self.mask_pre_mlp = True
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+ def forward_dense(self, x: torch.Tensor):
+ y = self.ln_1(x)
+ y = F.linear(y, self.attn.in_proj_weight, self.attn.in_proj_bias)
+ L, N, D = y.shape # L N 3D
+
+ y = y.reshape(L, N, 3, D // 3).permute(2, 1, 0, 3).reshape(3 * N, L, D // 3)
+ y = F.linear(y, self.attn.out_proj.weight, self.attn.out_proj.bias)
+
+ q, k, v = y.tensor_split(3, dim=0)
+ v = v.transpose(1, 0) + x # L N D
+
+ v = v + self.mlp(self.ln_2(v))
+ return v
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+ def forward(self, x: torch.Tensor, dense=False):
+ for i, resblock in enumerate(self.resblocks):
+ if i == self.layers - 1 and dense:
+ x = resblock.forward_dense(x)
+ else:
+ x = resblock(x)
+ return x
+
+
+class VisualTransformer(nn.Module):
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+ super().__init__()
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ self.patch_size = patch_size
+ self.input_resolution = input_resolution
+
+ def forward(self, x: torch.Tensor, dense=False):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+
+ if dense and (x.shape[1] != self.positional_embedding.shape[0]):
+ x = x + self.resized_pos_embed(self.input_resolution, x.shape[1]).to(x.dtype)
+ else:
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, dense)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ if dense:
+ x = self.ln_post(x[:, :, :])
+ else:
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+ def resized_pos_embed(self, in_res, tgt_res, mode="bicubic"):
+ #assert L == (input_resolution // self.patch_size) ** 2 + 1
+ L, D = self.positional_embedding.shape
+
+ in_side = in_res // self.patch_size
+ #tgt_side = tgt_res // self.patch_size
+ tgt_side = int((tgt_res - 1) ** 0.5)
+
+ cls_pos = self.positional_embedding[0].unsqueeze(0) # 1 D
+ pos_embed = self.positional_embedding[1:].reshape(1, in_side, in_side, D).permute(0, 3, 1, 2) # L-1 D -> 1 D S S
+ resized_pos_embed = F.interpolate(pos_embed, size=(tgt_side, tgt_side), mode=mode, align_corners=False,) # 1 D S S -> 1 D S' S'
+ resized_pos_embed = resized_pos_embed.squeeze(0).reshape(D, -1).T # L'-1 D
+
+ return torch.cat((cls_pos, resized_pos_embed), dim=0)
+
+
+class CLIP(nn.Module):
+ def __init__(self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+
+ self.image_resolution = image_resolution
+
+
+ if isinstance(vision_layers, (tuple, list)):
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width
+ )
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisualTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask()
+ )
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]))
+
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ return self.visual.conv1.weight.dtype
+
+
+ def encode_image(self, image, masks=None, pool_mask=None, dense=False):
+ if pool_mask is not None:
+ return self.visual(image.type(self.dtype), mask=pool_mask, dense=dense)
+ if masks == None:
+ return self.visual(image.type(self.dtype), dense=dense)
+ else:
+ return self.visual(image.type(self.dtype), masks.type(self.dtype))
+
+ def encode_text(self, text):
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+ # import pdb; pdb.set_trace()
+ # normalized features
+ # image_features shape: [1, 1024]
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_iamge = logit_scale * image_features @ text_features.t()
+ logits_per_text = logit_scale * text_features @ image_features.t()
+
+ # shape = [global_batch_size, global_batch_size]
+ return logits_per_iamge, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ model = CLIP(
+ embed_dim,
+ image_resolution, vision_layers, vision_width, vision_patch_size,
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict)
+ return model.eval()
diff --git a/cat_seg/third_party/model_vpt.py b/cat_seg/third_party/model_vpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e958112828e81b788418ab573ea1962684667b7
--- /dev/null
+++ b/cat_seg/third_party/model_vpt.py
@@ -0,0 +1,477 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x[:1], key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+ return x.squeeze(0)
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.avgpool = nn.AvgPool2d(2)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ def stem(x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+ self.mask_pre_mlp = True
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+ def forward_dense(self, x: torch.Tensor):
+ y = self.ln_1(x)
+ y = F.linear(y, self.attn.in_proj_weight, self.attn.in_proj_bias)
+ L, N, D = y.shape # L N 3D
+
+ y = y.reshape(L, N, 3, D // 3).permute(2, 1, 0, 3).reshape(3 * N, L, D // 3)
+ y = F.linear(y, self.attn.out_proj.weight, self.attn.out_proj.bias)
+
+ q, k, v = y.tensor_split(3, dim=0)
+ #v = v.transpose(1, 0) + x # L N D
+ v = v.transpose(1, 0) + x[:1] # L N D
+
+ v = v + self.mlp(self.ln_2(v))
+ return v
+
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, prompt_length=0, prompt_depth=0):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+ self.prompt_length = prompt_length
+ self.prompt_depth = prompt_depth
+ self.prompt_tokens = nn.Parameter(torch.zeros(prompt_depth, prompt_length, width)) if prompt_length > 0 else None
+ if self.prompt_tokens is not None:
+ nn.init.xavier_uniform_(self.prompt_tokens)
+
+ def forward(self, x: torch.Tensor, dense=False):
+ for i, resblock in enumerate(self.resblocks):
+ if self.prompt_length > 0 and i < self.prompt_depth:
+ l = self.prompt_length + 1 if i > 0 else 1
+ x = torch.cat((x[0:1, :, :], self.prompt_tokens[i].repeat(x.shape[1], 1, 1).permute(1, 0, 2) ,x[l:, :, :]))
+
+ if i == self.layers - 1 and dense:
+ x = resblock.forward_dense(x)
+ x = torch.cat((x[0:1, :, :], x[self.prompt_length + 1: :, :]), dim=0)
+ else:
+ x = resblock(x)
+
+ return x
+
+
+class VisualTransformer(nn.Module):
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, prompt_depth: int, prompt_length: int):
+ super().__init__()
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads, prompt_depth=prompt_depth, prompt_length=prompt_length)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ self.patch_size = patch_size
+ self.input_resolution = input_resolution
+
+ def forward(self, x: torch.Tensor, dense=False):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+
+ if dense and (x.shape[1] != self.positional_embedding.shape[0]):
+ x = x + self.resized_pos_embed(self.input_resolution, x.shape[1]).to(x.dtype)
+ else:
+ x = x + self.positional_embedding.to(x.dtype)
+
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, dense)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ if dense:
+ x = self.ln_post(x[:, :, :])
+ else:
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+ def resized_pos_embed(self, in_res, tgt_res, mode="bicubic"):
+ #assert L == (input_resolution // self.patch_size) ** 2 + 1
+ L, D = self.positional_embedding.shape
+
+ in_side = in_res // self.patch_size
+ #tgt_side = tgt_res // self.patch_size
+ tgt_side = int((tgt_res - 1) ** 0.5)
+
+ cls_pos = self.positional_embedding[0].unsqueeze(0) # 1 D
+ pos_embed = self.positional_embedding[1:].reshape(1, in_side, in_side, D).permute(0, 3, 1, 2) # L-1 D -> 1 D S S
+ resized_pos_embed = F.interpolate(pos_embed, size=(tgt_side, tgt_side), mode=mode, align_corners=False,) # 1 D S S -> 1 D S' S'
+ resized_pos_embed = resized_pos_embed.squeeze(0).reshape(D, -1).T # L'-1 D
+
+ return torch.cat((cls_pos, resized_pos_embed), dim=0)
+
+
+class CLIP(nn.Module):
+ def __init__(self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int,
+ # prompt
+ prompt_depth: int=0,
+ prompt_length: int=0,
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+
+ self.image_resolution = image_resolution
+
+
+ if isinstance(vision_layers, (tuple, list)):
+ assert prompt_length == 0 and prompt_depth==0
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width
+ )
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisualTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim,
+ prompt_depth=prompt_depth,
+ prompt_length=prompt_length,
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask()
+ )
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]))
+
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ return self.visual.conv1.weight.dtype
+
+
+ def encode_image(self, image, masks=None, pool_mask=None, dense=False):
+ if pool_mask is not None:
+ return self.visual(image.type(self.dtype), mask=pool_mask, dense=dense)
+ if masks == None:
+ return self.visual(image.type(self.dtype), dense=dense)
+ else:
+ return self.visual(image.type(self.dtype), masks.type(self.dtype))
+
+ def encode_text(self, text):
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+ # import pdb; pdb.set_trace()
+ # normalized features
+ # image_features shape: [1, 1024]
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_iamge = logit_scale * image_features @ text_features.t()
+ logits_per_text = logit_scale * text_features @ image_features.t()
+
+ # shape = [global_batch_size, global_batch_size]
+ return logits_per_iamge, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict, prompt_depth=0, prompt_length=0):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ model = CLIP(
+ embed_dim,
+ image_resolution, vision_layers, vision_width, vision_patch_size,
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
+ prompt_depth=prompt_depth, prompt_length=prompt_length,
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict, strict=False)
+ return model.eval()
diff --git a/cat_seg/third_party/simple_tokenizer.py b/cat_seg/third_party/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91
--- /dev/null
+++ b/cat_seg/third_party/simple_tokenizer.py
@@ -0,0 +1,132 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+''
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
+ return text
diff --git a/cat_seg/utils/__init__.py b/cat_seg/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/cat_seg/utils/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/cat_seg/utils/__pycache__/__init__.cpython-38.pyc b/cat_seg/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92c5486d433f3efbd342186838496ff7fae6f3c6
Binary files /dev/null and b/cat_seg/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cat_seg/utils/__pycache__/misc.cpython-38.pyc b/cat_seg/utils/__pycache__/misc.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d631b2048813a12b45009ecae83ab7d0c14f6c14
Binary files /dev/null and b/cat_seg/utils/__pycache__/misc.cpython-38.pyc differ
diff --git a/cat_seg/utils/misc.py b/cat_seg/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..874d9805b482f52bbffc1be620e36e0cffc07c46
--- /dev/null
+++ b/cat_seg/utils/misc.py
@@ -0,0 +1,111 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+import torchvision
+from torch import Tensor
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], : img.shape[2]] = False
+ else:
+ raise ValueError("not supported")
+ return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
+ ).to(torch.int64)
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ tensor = torch.stack(padded_imgs)
+ mask = torch.stack(padded_masks)
+
+ return NestedTensor(tensor, mask=mask)
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
diff --git a/configs/config.yaml b/configs/config.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..dbe5c0a473815474570ecd11a119026d19414672
--- /dev/null
+++ b/configs/config.yaml
@@ -0,0 +1,85 @@
+MODEL:
+ META_ARCHITECTURE: "CATSeg"
+ BACKBONE:
+ FREEZE_AT: 0
+ NAME: "D2SwinTransformer"
+ SWIN:
+ EMBED_DIM: 192
+ DEPTHS: [2, 2, 18, 2]
+ NUM_HEADS: [4, 8, 16, 32]
+ WINDOW_SIZE: 12
+ APE: False
+ DROP_PATH_RATE: 0.3
+ PATCH_NORM: True
+ PRETRAIN_IMG_SIZE: 384
+ OUT_FEATURES: ["res2", "res3", "res4"]
+ WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ SEM_SEG_HEAD:
+ NAME: "OpenVocabHead"
+ IN_FEATURES: ["res2", "res3", "res4"]
+ IGNORE_VALUE: 255
+ NUM_CLASSES: 171
+ TRAIN_CLASS_JSON: "datasets/coco.json"
+ TEST_CLASS_JSON: "datasets/coco.json"
+ CLIP_PRETRAINED: "ViT-L/14@336px"
+ PROMPT_DEPTH: 0
+ PROMPT_LENGTH: 0
+ TEXT_AFFINITY_DIM: 768
+ TEXT_AFFINITY_PROJ_DIM: 128
+ APPEARANCE_AFFINITY_DIM: 768
+ APPEARANCE_AFFINITY_PROJ_DIM: 128
+ DECODER_DIMS: [64, 32]
+ DECODER_AFFINITY_DIMS: [256, 128]
+ DECODER_AFFINITY_PROJ_DIMS: [32, 16]
+ NUM_LAYERS: 4
+ NUM_HEADS: 4
+ HIDDEN_DIMS: 128
+ POOLING_SIZES: [6, 6]
+ FEATURE_RESOLUTION: [24, 24]
+ WINDOW_SIZES: 12
+ ATTENTION_TYPE: "linear"
+ CLIP_FINETUNE: "attention"
+ PROMPT_ENSEMBLE_TYPE: "imagenet"
+DATASETS:
+ TRAIN: ("coco_2017_train_stuff_all_sem_seg",)
+ TEST: ("coco_2017_test_stuff_all_sem_seg",)
+SOLVER:
+ IMS_PER_BATCH: 4
+ BASE_LR: 0.0002
+ MAX_ITER: 80000
+ WARMUP_FACTOR: 1.0
+ WARMUP_ITERS: 0
+ WEIGHT_DECAY: 0.0001
+ OPTIMIZER: "ADAMW"
+ LR_SCHEDULER_NAME: "WarmupCosineLR"
+ BACKBONE_MULTIPLIER: 0.01
+ CLIP_MULTIPLIER: 0.01
+ CLIP_GRADIENTS:
+ ENABLED: True
+ CLIP_TYPE: "full_model"
+ CLIP_VALUE: 0.01
+ NORM_TYPE: 2.0
+INPUT:
+ MIN_SIZE_TRAIN: (384, )
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
+ MIN_SIZE_TEST: 640
+ MAX_SIZE_TEST: 2560
+ CROP:
+ ENABLED: True
+ TYPE: "absolute"
+ SIZE: (384, 384)
+ SINGLE_CATEGORY_MAX_AREA: 1.0
+ COLOR_AUG_SSD: True
+ SIZE_DIVISIBILITY: 384
+ FORMAT: "RGB"
+ DATASET_MAPPER_NAME: "mask_former_semantic"
+TEST:
+ EVAL_PERIOD: 5000
+ SLIDING_WINDOW: False
+DATALOADER:
+ FILTER_EMPTY_ANNOTATIONS: True
+ NUM_WORKERS: 8
+VERSION: 2
+CUDNN_BENCHMARK: True
\ No newline at end of file
diff --git a/configs/vitb_r101_384.yaml b/configs/vitb_r101_384.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..13611c32e264f1cadf5aee8f810760e262cc52a8
--- /dev/null
+++ b/configs/vitb_r101_384.yaml
@@ -0,0 +1,43 @@
+_BASE_: config.yaml
+MODEL:
+ META_ARCHITECTURE: "CATSeg"
+ BACKBONE:
+ FREEZE_AT: 0
+ NAME: "build_resnet_backbone"
+ WEIGHTS: "R-101.pkl"
+ RESNETS:
+ DEPTH: 101
+ STEM_TYPE: "basic"
+ STEM_OUT_CHANNELS: 64
+ STRIDE_IN_1X1: False
+ OUT_FEATURES: ["res2", "res3", "res4"]
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ SEM_SEG_HEAD:
+ NAME: "CATSegHead"
+ IN_FEATURES: ["res2", "res3", "res4"]
+ IGNORE_VALUE: 255
+ NUM_CLASSES: 171
+ TRAIN_CLASS_JSON: "datasets/coco.json"
+ TEST_CLASS_JSON: "datasets/coco.json"
+ CLIP_PRETRAINED: "ViT-B/16"
+ PROMPT_DEPTH: 0
+ PROMPT_LENGTH: 0
+ TEXT_AFFINITY_DIM: 512
+ TEXT_AFFINITY_PROJ_DIM: 128
+ APPEARANCE_AFFINITY_DIM: 1024
+ APPEARANCE_AFFINITY_PROJ_DIM: 128
+ DECODER_DIMS: [64, 32]
+ DECODER_AFFINITY_DIMS: [512, 256]
+ DECODER_AFFINITY_PROJ_DIMS: [32, 16]
+ NUM_LAYERS: 2
+ NUM_HEADS: 4
+ HIDDEN_DIMS: 128
+ POOLING_SIZES: [2, 2]
+ FEATURE_RESOLUTION: [24, 24]
+ WINDOW_SIZES: 12
+ ATTENTION_TYPE: "linear"
+ CLIP_FINETUNE: "attention"
+ PROMPT_ENSEMBLE_TYPE: "imagenet"
+SOLVER:
+ BACKBONE_MULTIPLIER: 0.01
\ No newline at end of file
diff --git a/configs/vitl_swinb_384.yaml b/configs/vitl_swinb_384.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1c30e76ec5f343c772d6ce352191898854cfbcb3
--- /dev/null
+++ b/configs/vitl_swinb_384.yaml
@@ -0,0 +1,66 @@
+_BASE_: config.yaml
+MODEL:
+ META_ARCHITECTURE: "CATSeg"
+ BACKBONE:
+ FREEZE_AT: 0
+ NAME: "D2SwinTransformer"
+ SWIN:
+ EMBED_DIM: 128
+ DEPTHS: [2, 2, 18]
+ NUM_HEADS: [4, 8, 16]
+ WINDOW_SIZE: 12
+ APE: False
+ DROP_PATH_RATE: 0.3
+ PATCH_NORM: True
+ PRETRAIN_IMG_SIZE: 384
+ OUT_FEATURES: ["res2", "res3", "res4"]
+ WEIGHTS: "swin_base_patch4_window12_384_22k.pkl"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ SEM_SEG_HEAD:
+ NAME: "CATSegHead"
+ IN_FEATURES: ["res2", "res3", "res4"]
+ IGNORE_VALUE: 255
+ NUM_CLASSES: 171
+ TRAIN_CLASS_JSON: "datasets/coco.json"
+ TEST_CLASS_JSON: "datasets/coco.json"
+ CLIP_PRETRAINED: "ViT-L/14@336px"
+ PROMPT_DEPTH: 0
+ PROMPT_LENGTH: 0
+ TEXT_AFFINITY_DIM: 768
+ TEXT_AFFINITY_PROJ_DIM: 128
+ APPEARANCE_AFFINITY_DIM: 512
+ APPEARANCE_AFFINITY_PROJ_DIM: 128
+ DECODER_DIMS: [64, 32]
+ DECODER_AFFINITY_DIMS: [256, 128]
+ DECODER_AFFINITY_PROJ_DIMS: [32, 16]
+ NUM_LAYERS: 2
+ NUM_HEADS: 4
+ HIDDEN_DIMS: 128
+ POOLING_SIZES: [2, 2]
+ FEATURE_RESOLUTION: [24, 24]
+ WINDOW_SIZES: 12
+ ATTENTION_TYPE: "linear"
+ CLIP_FINETUNE: "attention"
+ PROMPT_ENSEMBLE_TYPE: "imagenet"
+INPUT:
+ MIN_SIZE_TRAIN: (384, )
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
+ MIN_SIZE_TEST: 640
+ CROP:
+ ENABLED: True
+ TYPE: "absolute"
+ SIZE: (384, 384)
+ SIZE_DIVISIBILITY: 384
+ FORMAT: "RGB"
+ DATASET_MAPPER_NAME: "mask_former_semantic"
+SOLVER:
+ IMS_PER_BATCH: 4
+ LR_SCHEDULER_NAME: WarmupCosineLR
+ BASE_LR: 0.0002
+ MAX_ITER: 80000
+ BACKBONE_MULTIPLIER: 0.0
+ CLIP_MULTIPLIER: 0.01
+TEST:
+ EVAL_PERIOD: 5000
+
\ No newline at end of file
diff --git a/configs/vitl_swinb_384_ade150.yaml b/configs/vitl_swinb_384_ade150.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..6c413e4c8d19627f51030eb64e434fc839b90583
--- /dev/null
+++ b/configs/vitl_swinb_384_ade150.yaml
@@ -0,0 +1,7 @@
+_BASE_: vitl_swinb_384.yaml
+MODEL:
+ META_ARCHITECTURE: "CATSeg"
+ SEM_SEG_HEAD:
+ TEST_CLASS_JSON: "datasets/ade150.json"
+DATASETS:
+ TEST: ("ade20k_150_test_sem_seg",)
\ No newline at end of file
diff --git a/configs/vitl_swinb_384_ade847.yaml b/configs/vitl_swinb_384_ade847.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..133de94a46d512786d6f4eb55fca4d09804668eb
--- /dev/null
+++ b/configs/vitl_swinb_384_ade847.yaml
@@ -0,0 +1,7 @@
+_BASE_: vitl_swinb_384.yaml
+MODEL:
+ META_ARCHITECTURE: "CATSeg"
+ SEM_SEG_HEAD:
+ TEST_CLASS_JSON: "datasets/ade847.json"
+DATASETS:
+ TEST: ("ade20k_full_sem_seg_freq_val_all",)
\ No newline at end of file
diff --git a/configs/vitl_swinb_384_pas20.yaml b/configs/vitl_swinb_384_pas20.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..c2793af52dad23b59ef07c1e665529b08e4e7c3e
--- /dev/null
+++ b/configs/vitl_swinb_384_pas20.yaml
@@ -0,0 +1,7 @@
+_BASE_: vitl_swinb_384.yaml
+MODEL:
+ META_ARCHITECTURE: "CATSeg"
+ SEM_SEG_HEAD:
+ TEST_CLASS_JSON: "datasets/voc20.json"
+DATASETS:
+ TEST: ("voc_2012_test_sem_seg",)
\ No newline at end of file
diff --git a/configs/vitl_swinb_384_pas20b.yaml b/configs/vitl_swinb_384_pas20b.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..48274ce06079912a0e4696ce73e97bd31e8e3a87
--- /dev/null
+++ b/configs/vitl_swinb_384_pas20b.yaml
@@ -0,0 +1,7 @@
+_BASE_: vitl_swinb_384.yaml
+MODEL:
+ META_ARCHITECTURE: "CATSeg"
+ SEM_SEG_HEAD:
+ TEST_CLASS_JSON: "datasets/voc20b.json"
+DATASETS:
+ TEST: ("voc_2012_test_background_sem_seg",)
\ No newline at end of file
diff --git a/configs/vitl_swinb_384_pas459.yaml b/configs/vitl_swinb_384_pas459.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..b3b4704c8e334f083ffb19140679db4b5c6fdc6e
--- /dev/null
+++ b/configs/vitl_swinb_384_pas459.yaml
@@ -0,0 +1,7 @@
+_BASE_: vitl_swinb_384.yaml
+MODEL:
+ META_ARCHITECTURE: "CATSeg"
+ SEM_SEG_HEAD:
+ TEST_CLASS_JSON: "datasets/pc459.json"
+DATASETS:
+ TEST: ("context_459_test_sem_seg",)
\ No newline at end of file
diff --git a/configs/vitl_swinb_384_pas59.yaml b/configs/vitl_swinb_384_pas59.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..39a4014b6c49add0e24a3dfa6eda028ba18ac435
--- /dev/null
+++ b/configs/vitl_swinb_384_pas59.yaml
@@ -0,0 +1,7 @@
+_BASE_: vitl_swinb_384.yaml
+MODEL:
+ META_ARCHITECTURE: "CATSeg"
+ SEM_SEG_HEAD:
+ TEST_CLASS_JSON: "datasets/pc59.json"
+DATASETS:
+ TEST: ("context_59_test_sem_seg",)
\ No newline at end of file
diff --git a/datasets/README.md b/datasets/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..db2642a9b39eab0d02857ac2dafb15b4658e7cad
--- /dev/null
+++ b/datasets/README.md
@@ -0,0 +1,167 @@
+# Prepare Datasets for CAT-Seg
+
+A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog)
+for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc).
+This document explains how to setup the builtin datasets so they can be used by the above APIs.
+[Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`,
+and how to add new datasets to them.
+
+CAT-Seg has builtin support for a few datasets.
+The datasets are assumed to exist in a directory specified by the environment variable
+`DETECTRON2_DATASETS`.
+Under this directory, detectron2 will look for datasets in the structure described below, if needed.
+```
+$DETECTRON2_DATASETS/
+ coco/ # COCO-Stuff
+ ADEChallengeData2016/ # ADE20K-150
+ ADE20K_2021_17_01/ # ADE20K-847
+ VOCdevkit/
+ VOC2010/ # PASCAL Context
+ VOC2012/ # PASCAL VOC
+```
+
+You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`.
+If left unset, the default is `./datasets` relative to your current working directory.
+
+## Prepare data for [COCO-Stuff](https://github.com/nightrome/cocostuff):
+
+### Expected data structure
+
+```
+coco-stuff/
+ annotations/
+ train2017/
+ val2017/
+ images/
+ train2017/
+ val2017/
+ # below are generated by prepare_coco_stuff.py
+ annotations_detectron2/
+ train2017/
+ val2017/
+```
+Download the COCO (2017) images from https://cocodataset.org/
+
+```bash
+wget http://images.cocodataset.org/zips/train2017.zip
+wget http://images.cocodataset.org/zips/val2017.zip
+```
+
+Download the COCO-Stuff annotation from https://github.com/nightrome/cocostuff.
+```bash
+wget http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip
+```
+Unzip `train2017.zip`, `val2017.zip`, and `stuffthingmaps_trainval2017.zip`. Then put them to the correct location listed above.
+
+Generate the labels for training and testing.
+
+```
+python datasets/prepare_coco_stuff.py
+```
+
+
+
+## Prepare data for [ADE20K-150](http://sceneparsing.csail.mit.edu):
+
+### Expected data structure
+```
+ADEChallengeData2016/
+ annotations/
+ validation/
+ images/
+ validation/
+ # below are generated by prepare_ade20k_150.py
+ annotations_detectron2/
+ validation/
+```
+Download the data of ADE20K-150 from http://sceneparsing.csail.mit.edu.
+```
+wget http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip
+```
+Unzip `ADEChallengeData2016.zip` and generate the labels for testing.
+```
+python datasets/prepare_ade20k_150.py
+```
+## Prepare data for [ADE20k-847](https://groups.csail.mit.edu/vision/datasets/ADE20K/):
+
+### Expected data structure
+```
+ADE20K_2021_17_01/
+ images/
+ ADE/
+ validation/
+ index_ade20k.mat
+ index_ade20k.pkl
+ # below are generated by prepare_ade20k_847.py
+ annotations_detectron2/
+ validation/
+```
+Download the data of ADE20k-Full from https://groups.csail.mit.edu/vision/datasets/ADE20K/request_data/
+Unzip the dataset and generate the labels for testing.
+```
+python datasets/prepare_ade20k_847.py
+```
+
+## Prepare data for [PASCAL VOC 2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/#devkit):
+
+
+### Expected data structure
+```
+VOCdevkit/
+ VOC2012/
+ Annotations/
+ ImageSets/
+ JPEGImages/
+ SegmentationClass/
+ SegmentationClassAug/
+ SegmentationObject/
+ # below are generated by prepare_voc.py
+ annotations_detectron2
+ annotations_detectron2_bg
+
+```
+Download the data of PASCAL VOC from http://host.robots.ox.ac.uk/pascal/VOC/voc2012/#devkit.
+
+We use SBD augmentated training data as SegmentationClassAug following [Deeplab](https://github.com/kazuto1011/deeplab-pytorch/blob/master/data/datasets/voc12/README.md).
+```
+wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
+wget https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip
+```
+Unzip `VOCtrainval_11-May-2012.tar` and `SegmentationClassAug.zip`. Then put them to the correct location listed above and generate the labels for testing.
+```
+python datasets/prepare_voc.py
+```
+
+
+## Prepare data for [PASCAL Context](https://www.cs.stanford.edu/~roozbeh/pascal-context/):
+
+
+### Expected data structure
+```
+VOCdevkit/
+ VOC2010/
+ Annotations/
+ ImageSets/
+ JPEGImages/
+ SegmentationClass/
+ SegmentationObject/
+ trainval/
+ labels.txt
+ 59_labels.txt
+ pascalcontext_val.txt
+ # below are generated by prepare_pascal_context.py
+ annotations_detectron2/
+ pc459_val
+ pc59_val
+```
+Download the data of PASCAL VOC 2010 from https://www.cs.stanford.edu/~roozbeh/pascal-context/.
+
+```
+wget http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar
+wget https://www.cs.stanford.edu/~roozbeh/pascal-context/trainval.tar.gz
+wget https://www.cs.stanford.edu/~roozbeh/pascal-context/59_labels.txt
+```
+Unzip `VOCtrainval_03-May-2010.tar` and `trainval.tar.gz`. Then put them to the correct location listed above and generate the labels for testing.
+```
+python datasets/prepare_pascal_context.py
+```
\ No newline at end of file
diff --git a/datasets/ade150.json b/datasets/ade150.json
new file mode 100755
index 0000000000000000000000000000000000000000..58772475641879daf8ada3e5c5204773b576780a
--- /dev/null
+++ b/datasets/ade150.json
@@ -0,0 +1 @@
+["wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed ", "windowpane", "grass", "cabinet", "sidewalk", "person", "earth", "door", "table", "mountain", "plant", "curtain", "chair", "car", "water", "painting", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub", "railing", "cushion", "base", "box", "column", "signboard", "chest of drawers", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator", "grandstand", "path", "stairs", "runway", "case", "pool table", "pillow", "screen door", "stairway", "river", "bridge", "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill", "bench", "countertop", "stove", "palm", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel", "bus", "towel", "light", "truck", "tower", "chandelier", "awning", "streetlight", "booth", "television receiver", "airplane", "dirt track", "apparel", "pole", "land", "bannister", "escalator", "ottoman", "bottle", "buffet", "poster", "stage", "van", "ship", "fountain", "conveyer belt", "canopy", "washer", "plaything", "swimming pool", "stool", "barrel", "basket", "waterfall", "tent", "bag", "minibike", "cradle", "oven", "ball", "food", "step", "tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket", "sculpture", "hood", "sconce", "vase", "traffic light", "tray", "ashcan", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass", "clock", "flag"]
diff --git a/datasets/ade847.json b/datasets/ade847.json
new file mode 100755
index 0000000000000000000000000000000000000000..5150122783f7330795ac28ca0dfeeb08bab8e3f3
--- /dev/null
+++ b/datasets/ade847.json
@@ -0,0 +1 @@
+["wall", "building, edifice", "sky", "tree", "road, route", "floor, flooring", "ceiling", "bed", "sidewalk, pavement", "earth, ground", "cabinet", "person, individual, someone, somebody, mortal, soul", "grass", "windowpane, window", "car, auto, automobile, machine, motorcar", "mountain, mount", "plant, flora, plant life", "table", "chair", "curtain, drape, drapery, mantle, pall", "door", "sofa, couch, lounge", "sea", "painting, picture", "water", "mirror", "house", "rug, carpet, carpeting", "shelf", "armchair", "fence, fencing", "field", "lamp", "rock, stone", "seat", "river", "desk", "bathtub, bathing tub, bath, tub", "railing, rail", "signboard, sign", "cushion", "path", "work surface", "stairs, steps", "column, pillar", "sink", "wardrobe, closet, press", "snow", "refrigerator, icebox", "base, pedestal, stand", "bridge, span", "blind, screen", "runway", "cliff, drop, drop-off", "sand", "fireplace, hearth, open fireplace", "pillow", "screen door, screen", "toilet, can, commode, crapper, pot, potty, stool, throne", "skyscraper", "grandstand, covered stand", "box", "pool table, billiard table, snooker table", "palm, palm tree", "double door", "coffee table, cocktail table", "counter", "countertop", "chest of drawers, chest, bureau, dresser", "kitchen island", "boat", "waterfall, falls", "stove, kitchen stove, range, kitchen range, cooking stove", "flower", "bookcase", "controls", "book", "stairway, staircase", "streetlight, street lamp", "computer, computing machine, computing device, data processor, electronic computer, information processing system", "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle", "swivel chair", "light, light source", "bench", "case, display case, showcase, vitrine", "towel", "fountain", "embankment", "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box", "van", "hill", "awning, sunshade, sunblind", "poster, posting, placard, notice, bill, card", "truck, motortruck", "airplane, aeroplane, plane", "pole", "tower", "court", "ball", "aircraft carrier, carrier, flattop, attack aircraft carrier", "buffet, counter, sideboard", "hovel, hut, hutch, shack, shanty", "apparel, wearing apparel, dress, clothes", "minibike, motorbike", "animal, animate being, beast, brute, creature, fauna", "chandelier, pendant, pendent", "step, stair", "booth, cubicle, stall, kiosk", "bicycle, bike, wheel, cycle", "doorframe, doorcase", "sconce", "pond", "trade name, brand name, brand, marque", "bannister, banister, balustrade, balusters, handrail", "bag", "traffic light, traffic signal, stoplight", "gazebo", "escalator, moving staircase, moving stairway", "land, ground, soil", "board, plank", "arcade machine", "eiderdown, duvet, continental quilt", "bar", "stall, stand, sales booth", "playground", "ship", "ottoman, pouf, pouffe, puff, hassock", "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", "bottle", "cradle", "pot, flowerpot", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "train, railroad train", "stool", "lake", "tank, storage tank", "ice, water ice", "basket, handbasket", "manhole", "tent, collapsible shelter", "canopy", "microwave, microwave oven", "barrel, cask", "dirt track", "beam", "dishwasher, dish washer, dishwashing machine", "plate", "screen, crt screen", "ruins", "washer, automatic washer, washing machine", "blanket, cover", "plaything, toy", "food, solid food", "screen, silver screen, projection screen", "oven", "stage", "beacon, lighthouse, beacon light, pharos", "umbrella", "sculpture", "aqueduct", "container", "scaffolding, staging", "hood, exhaust hood", "curb, curbing, kerb", "roller coaster", "horse, equus caballus", "catwalk", "glass, drinking glass", "vase", "central reservation", "carousel", "radiator", "closet", "machine", "pier, wharf, wharfage, dock", "fan", "inflatable bounce game", "pitch", "paper", "arcade, colonnade", "hot tub", "helicopter", "tray", "partition, divider", "vineyard", "bowl", "bullring", "flag", "pot", "footbridge, overcrossing, pedestrian bridge", "shower", "bag, traveling bag, travelling bag, grip, suitcase", "bulletin board, notice board", "confessional booth", "trunk, tree trunk, bole", "forest", "elevator door", "laptop, laptop computer", "instrument panel", "bucket, pail", "tapestry, tapis", "platform", "jacket", "gate", "monitor, monitoring device", "telephone booth, phone booth, call box, telephone box, telephone kiosk", "spotlight, spot", "ring", "control panel", "blackboard, chalkboard", "air conditioner, air conditioning", "chest", "clock", "sand dune", "pipe, pipage, piping", "vault", "table football", "cannon", "swimming pool, swimming bath, natatorium", "fluorescent, fluorescent fixture", "statue", "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", "exhibitor", "ladder", "carport", "dam", "pulpit", "skylight, fanlight", "water tower", "grill, grille, grillwork", "display board", "pane, pane of glass, window glass", "rubbish, trash, scrap", "ice rink", "fruit", "patio", "vending machine", "telephone, phone, telephone set", "net", "backpack, back pack, knapsack, packsack, rucksack, haversack", "jar", "track", "magazine", "shutter", "roof", "banner, streamer", "landfill", "post", "altarpiece, reredos", "hat, chapeau, lid", "arch, archway", "table game", "bag, handbag, pocketbook, purse", "document, written document, papers", "dome", "pier", "shanties", "forecourt", "crane", "dog, domestic dog, canis familiaris", "piano, pianoforte, forte-piano", "drawing", "cabin", "ad, advertisement, advertizement, advertising, advertizing, advert", "amphitheater, amphitheatre, coliseum", "monument", "henhouse", "cockpit", "heater, warmer", "windmill, aerogenerator, wind generator", "pool", "elevator, lift", "decoration, ornament, ornamentation", "labyrinth", "text, textual matter", "printer", "mezzanine, first balcony", "mattress", "straw", "stalls", "patio, terrace", "billboard, hoarding", "bus stop", "trouser, pant", "console table, console", "rack", "notebook", "shrine", "pantry", "cart", "steam shovel", "porch", "postbox, mailbox, letter box", "figurine, statuette", "recycling bin", "folding screen", "telescope", "deck chair, beach chair", "kennel", "coffee maker", "altar, communion table, lord's table", "fish", "easel", "artificial golf green", "iceberg", "candlestick, candle holder", "shower stall, shower bath", "television stand", "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle", "skeleton", "grand piano, grand", "candy, confect", "grille door", "pedestal, plinth, footstall", "jersey, t-shirt, tee shirt", "shoe", "gravestone, headstone, tombstone", "shanty", "structure", "rocking chair, rocker", "bird", "place mat", "tomb", "big top", "gas pump, gasoline pump, petrol pump, island dispenser", "lockers", "cage", "finger", "bleachers", "ferris wheel", "hairdresser chair", "mat", "stands", "aquarium, fish tank, marine museum", "streetcar, tram, tramcar, trolley, trolley car", "napkin, table napkin, serviette", "dummy", "booklet, brochure, folder, leaflet, pamphlet", "sand trap", "shop, store", "table cloth", "service station", "coffin", "drawer", "cages", "slot machine, coin machine", "balcony", "volleyball court", "table tennis", "control table", "shirt", "merchandise, ware, product", "railway", "parterre", "chimney", "can, tin, tin can", "tanks", "fabric, cloth, material, textile", "alga, algae", "system", "map", "greenhouse", "mug", "barbecue", "trailer", "toilet tissue, toilet paper, bathroom tissue", "organ", "dishrag, dishcloth", "island", "keyboard", "trench", "basket, basketball hoop, hoop", "steering wheel, wheel", "pitcher, ewer", "goal", "bread, breadstuff, staff of life", "beds", "wood", "file cabinet", "newspaper, paper", "motorboat", "rope", "guitar", "rubble", "scarf", "barrels", "cap", "leaves", "control tower", "dashboard", "bandstand", "lectern", "switch, electric switch, electrical switch", "baseboard, mopboard, skirting board", "shower room", "smoke", "faucet, spigot", "bulldozer", "saucepan", "shops", "meter", "crevasse", "gear", "candelabrum, candelabra", "sofa bed", "tunnel", "pallet", "wire, conducting wire", "kettle, boiler", "bidet", "baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher", "music stand", "pipe, tube", "cup", "parking meter", "ice hockey rink", "shelter", "weeds", "temple", "patty, cake", "ski slope", "panel", "wallet", "wheel", "towel rack, towel horse", "roundabout", "canister, cannister, tin", "rod", "soap dispenser", "bell", "canvas", "box office, ticket office, ticket booth", "teacup", "trellis", "workbench", "valley, vale", "toaster", "knife", "podium", "ramp", "tumble dryer", "fireplug, fire hydrant, plug", "gym shoe, sneaker, tennis shoe", "lab bench", "equipment", "rocky formation", "plastic", "calendar", "caravan", "check-in-desk", "ticket counter", "brush", "mill", "covered bridge", "bowling alley", "hanger", "excavator", "trestle", "revolving door", "blast furnace", "scale, weighing machine", "projector", "soap", "locker", "tractor", "stretcher", "frame", "grating", "alembic", "candle, taper, wax light", "barrier", "cardboard", "cave", "puddle", "tarp", "price tag", "watchtower", "meters", "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb", "tracks", "hair dryer", "skirt", "viaduct", "paper towel", "coat", "sheet", "fire extinguisher, extinguisher, asphyxiator", "water wheel", "pottery, clayware", "magazine rack", "teapot", "microphone, mike", "support", "forklift", "canyon", "cash register, register", "leaf, leafage, foliage", "remote control, remote", "soap dish", "windshield, windscreen", "cat", "cue, cue stick, pool cue, pool stick", "vent, venthole, vent-hole, blowhole", "videos", "shovel", "eaves", "antenna, aerial, transmitting aerial", "shipyard", "hen, biddy", "traffic cone", "washing machines", "truck crane", "cds", "niche", "scoreboard", "briefcase", "boot", "sweater, jumper", "hay", "pack", "bottle rack", "glacier", "pergola", "building materials", "television camera", "first floor", "rifle", "tennis table", "stadium", "safety belt", "cover", "dish rack", "synthesizer", "pumpkin", "gutter", "fruit stand", "ice floe, floe", "handle, grip, handgrip, hold", "wheelchair", "mousepad, mouse mat", "diploma", "fairground ride", "radio", "hotplate", "junk", "wheelbarrow", "stream", "toll plaza", "punching bag", "trough", "throne", "chair desk", "weighbridge", "extractor fan", "hanging clothes", "dish, dish aerial, dish antenna, saucer", "alarm clock, alarm", "ski lift", "chain", "garage", "mechanical shovel", "wine rack", "tramway", "treadmill", "menu", "block", "well", "witness stand", "branch", "duck", "casserole", "frying pan", "desk organizer", "mast", "spectacles, specs, eyeglasses, glasses", "service elevator", "dollhouse", "hammock", "clothes hanging", "photocopier", "notepad", "golf cart", "footpath", "cross", "baptismal font", "boiler", "skip", "rotisserie", "tables", "water mill", "helmet", "cover curtain", "brick", "table runner", "ashtray", "street box", "stick", "hangers", "cells", "urinal", "centerpiece", "portable fridge", "dvds", "golf club", "skirting board", "water cooler", "clipboard", "camera, photographic camera", "pigeonhole", "chips", "food processor", "post box", "lid", "drum", "blender", "cave entrance", "dental chair", "obelisk", "canoe", "mobile", "monitors", "pool ball", "cue rack", "baggage carts", "shore", "fork", "paper filer", "bicycle rack", "coat rack", "garland", "sports bag", "fish tank", "towel dispenser", "carriage", "brochure", "plaque", "stringer", "iron", "spoon", "flag pole", "toilet brush", "book stand", "water faucet, water tap, tap, hydrant", "ticket office", "broom", "dvd", "ice bucket", "carapace, shell, cuticle, shield", "tureen", "folders", "chess", "root", "sewing machine", "model", "pen", "violin", "sweatshirt", "recycling materials", "mitten", "chopping board, cutting board", "mask", "log", "mouse, computer mouse", "grill", "hole", "target", "trash bag", "chalk", "sticks", "balloon", "score", "hair spray", "roll", "runner", "engine", "inflatable glove", "games", "pallets", "baskets", "coop", "dvd player", "rocking horse", "buckets", "bread rolls", "shawl", "watering can", "spotlights", "post-it", "bowls", "security camera", "runner cloth", "lock", "alarm, warning device, alarm system", "side", "roulette", "bone", "cutlery", "pool balls", "wheels", "spice rack", "plant pots", "towel ring", "bread box", "video", "funfair", "breads", "tripod", "ironing board", "skimmer", "hollow", "scratching post", "tricycle", "file box", "mountain pass", "tombstones", "cooker", "card game, cards", "golf bag", "towel paper", "chaise lounge", "sun", "toilet paper holder", "rake", "key", "umbrella stand", "dartboard", "transformer", "fireplace utensils", "sweatshirts", "cellular telephone, cellular phone, cellphone, cell, mobile phone", "tallboy", "stapler", "sauna", "test tube", "palette", "shopping carts", "tools", "push button, push, button", "star", "roof rack", "barbed wire", "spray", "ear", "sponge", "racket", "tins", "eyeglasses", "file", "scarfs", "sugar bowl", "flip flop", "headstones", "laptop bag", "leash", "climbing frame", "suit hanger", "floor spotlight", "plate rack", "sewer", "hard drive", "sprinkler", "tools box", "necklace", "bulbs", "steel industry", "club", "jack", "door bars", "control panel, instrument panel, control board, board, panel", "hairbrush", "napkin holder", "office", "smoke detector", "utensils", "apron", "scissors", "terminal", "grinder", "entry phone", "newspaper stand", "pepper shaker", "onions", "central processing unit, cpu, c p u , central processor, processor, mainframe", "tape", "bat", "coaster", "calculator", "potatoes", "luggage rack", "salt", "street number", "viewpoint", "sword", "cd", "rowing machine", "plug", "andiron, firedog, dog, dog-iron", "pepper", "tongs", "bonfire", "dog dish", "belt", "dumbbells", "videocassette recorder, vcr", "hook", "envelopes", "shower faucet", "watch", "padlock", "swimming pool ladder", "spanners", "gravy boat", "notice board", "trash bags", "fire alarm", "ladle", "stethoscope", "rocket", "funnel", "bowling pins", "valve", "thermometer", "cups", "spice jar", "night light", "soaps", "games table", "slotted spoon", "reel", "scourer", "sleeping robe", "desk mat", "dumbbell", "hammer", "tie", "typewriter", "shaker", "cheese dish", "sea star", "racquet", "butane gas cylinder", "paper weight", "shaving brush", "sunglasses", "gear shift", "towel rail", "adding machine, totalizer, totaliser"]
\ No newline at end of file
diff --git a/datasets/coco.json b/datasets/coco.json
new file mode 100755
index 0000000000000000000000000000000000000000..8feed8cb0c0a79d26879f3156ba22ad9557604cf
--- /dev/null
+++ b/datasets/coco.json
@@ -0,0 +1 @@
+["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "banner", "blanket", "branch", "bridge", "building-other", "bush", "cabinet", "cage", "cardboard", "carpet", "ceiling-other", "ceiling-tile", "cloth", "clothes", "clouds", "counter", "cupboard", "curtain", "desk-stuff", "dirt", "door-stuff", "fence", "floor-marble", "floor-other", "floor-stone", "floor-tile", "floor-wood", "flower", "fog", "food-other", "fruit", "furniture-other", "grass", "gravel", "ground-other", "hill", "house", "leaves", "light", "mat", "metal", "mirror-stuff", "moss", "mountain", "mud", "napkin", "net", "paper", "pavement", "pillow", "plant-other", "plastic", "platform", "playingfield", "railing", "railroad", "river", "road", "rock", "roof", "rug", "salad", "sand", "sea", "shelf", "sky-other", "skyscraper", "snow", "solid-other", "stairs", "stone", "straw", "structural-other", "table", "tent", "textile-other", "towel", "tree", "vegetable", "wall-brick", "wall-concrete", "wall-other", "wall-panel", "wall-stone", "wall-tile", "wall-wood", "water-other", "waterdrops", "window-blind", "window-other", "wood"]
\ No newline at end of file
diff --git a/datasets/pc459.json b/datasets/pc459.json
new file mode 100755
index 0000000000000000000000000000000000000000..a7554da8933c55cb57e2352466beccefb6ac2a89
--- /dev/null
+++ b/datasets/pc459.json
@@ -0,0 +1,2 @@
+["accordion", "aeroplane", "airconditioner", "antenna", "artillery", "ashtray", "atrium", "babycarriage", "bag", "ball", "balloon", "bambooweaving", "barrel", "baseballbat", "basket", "basketballbackboard", "bathtub", "bed", "bedclothes", "beer", "bell", "bench", "bicycle", "binoculars", "bird", "birdcage", "birdfeeder", "birdnest", "blackboard", "board", "boat", "bone", "book", "bottle", "bottleopener", "bowl", "box", "bracelet", "brick", "bridge", "broom", "brush", "bucket", "building", "bus", "cabinet", "cabinetdoor", "cage", "cake", "calculator", "calendar", "camel", "camera", "cameralens", "can", "candle", "candleholder", "cap", "car", "card", "cart", "case", "casetterecorder", "cashregister", "cat", "cd", "cdplayer", "ceiling", "cellphone", "cello", "chain", "chair", "chessboard", "chicken", "chopstick", "clip", "clippers", "clock", "closet", "cloth", "clothestree", "coffee", "coffeemachine", "comb", "computer", "concrete", "cone", "container", "controlbooth", "controller", "cooker", "copyingmachine", "coral", "cork", "corkscrew", "counter", "court", "cow", "crabstick", "crane", "crate", "cross", "crutch", "cup", "curtain", "cushion", "cuttingboard", "dais", "disc", "disccase", "dishwasher", "dock", "dog", "dolphin", "door", "drainer", "dray", "drinkdispenser", "drinkingmachine", "drop", "drug", "drum", "drumkit", "duck", "dumbbell", "earphone", "earrings", "egg", "electricfan", "electriciron", "electricpot", "electricsaw", "electronickeyboard", "engine", "envelope", "equipment", "escalator", "exhibitionbooth", "extinguisher", "eyeglass", "fan", "faucet", "faxmachine", "fence", "ferriswheel", "fireextinguisher", "firehydrant", "fireplace", "fish", "fishtank", "fishbowl", "fishingnet", "fishingpole", "flag", "flagstaff", "flame", "flashlight", "floor", "flower", "fly", "foam", "food", "footbridge", "forceps", "fork", "forklift", "fountain", "fox", "frame", "fridge", "frog", "fruit", "funnel", "furnace", "gamecontroller", "gamemachine", "gascylinder", "gashood", "gasstove", "giftbox", "glass", "glassmarble", "globe", "glove", "goal", "grandstand", "grass", "gravestone", "ground", "guardrail", "guitar", "gun", "hammer", "handcart", "handle", "handrail", "hanger", "harddiskdrive", "hat", "hay", "headphone", "heater", "helicopter", "helmet", "holder", "hook", "horse", "horse-drawncarriage", "hot-airballoon", "hydrovalve", "ice", "inflatorpump", "ipod", "iron", "ironingboard", "jar", "kart", "kettle", "key", "keyboard", "kitchenrange", "kite", "knife", "knifeblock", "ladder", "laddertruck", "ladle", "laptop", "leaves", "lid", "lifebuoy", "light", "lightbulb", "lighter", "line", "lion", "lobster", "lock", "machine", "mailbox", "mannequin", "map", "mask", "mat", "matchbook", "mattress", "menu", "metal", "meterbox", "microphone", "microwave", "mirror", "missile", "model", "money", "monkey", "mop", "motorbike", "mountain", "mouse", "mousepad", "musicalinstrument", "napkin", "net", "newspaper", "oar", "ornament", "outlet", "oven", "oxygenbottle", "pack", "pan", "paper", "paperbox", "papercutter", "parachute", "parasol", "parterre", "patio", "pelage", "pen", "pencontainer", "pencil", "person", "photo", "piano", "picture", "pig", "pillar", "pillow", "pipe", "pitcher", "plant", "plastic", "plate", "platform", "player", "playground", "pliers", "plume", "poker", "pokerchip", "pole", "pooltable", "postcard", "poster", "pot", "pottedplant", "printer", "projector", "pumpkin", "rabbit", "racket", "radiator", "radio", "rail", "rake", "ramp", "rangehood", "receiver", "recorder", "recreationalmachines", "remotecontrol", "road", "robot", "rock", "rocket", "rockinghorse", "rope", "rug", "ruler", "runway", "saddle", "sand", "saw", "scale", "scanner", "scissors", "scoop", "screen", "screwdriver", "sculpture", "scythe", "sewer", "sewingmachine", "shed", "sheep", "shell", "shelves", "shoe", "shoppingcart", "shovel", "sidecar", "sidewalk", "sign", "signallight", "sink", "skateboard", "ski", "sky", "sled", "slippers", "smoke", "snail", "snake", "snow", "snowmobiles", "sofa", "spanner", "spatula", "speaker", "speedbump", "spicecontainer", "spoon", "sprayer", "squirrel", "stage", "stair", "stapler", "stick", "stickynote", "stone", "stool", "stove", "straw", "stretcher", "sun", "sunglass", "sunshade", "surveillancecamera", "swan", "sweeper", "swimring", "swimmingpool", "swing", "switch", "table", "tableware", "tank", "tap", "tape", "tarp", "telephone", "telephonebooth", "tent", "tire", "toaster", "toilet", "tong", "tool", "toothbrush", "towel", "toy", "toycar", "track", "train", "trampoline", "trashbin", "tray", "tree", "tricycle", "tripod", "trophy", "truck", "tube", "turtle", "tvmonitor", "tweezers", "typewriter", "umbrella", "unknown", "vacuumcleaner", "vendingmachine", "videocamera", "videogameconsole", "videoplayer", "videotape", "violin", "wakeboard", "wall", "wallet", "wardrobe", "washingmachine", "watch", "water", "waterdispenser", "waterpipe", "waterskateboard", "watermelon", "whale", "wharf", "wheel", "wheelchair", "window", "windowblinds", "wineglass", "wire", "wood", "wool"]
+
diff --git a/datasets/pc59.json b/datasets/pc59.json
new file mode 100755
index 0000000000000000000000000000000000000000..cb0bd21a8958bed63c6d2f6a3cb270b9c4e99635
--- /dev/null
+++ b/datasets/pc59.json
@@ -0,0 +1 @@
+["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor", "bag", "bed", "bench", "book", "building", "cabinet", "ceiling", "cloth", "computer", "cup", "door", "fence", "floor", "flower", "food", "grass", "ground", "keyboard", "light", "mountain", "mouse", "curtain", "platform", "sign", "plate", "road", "rock", "shelves", "sidewalk", "sky", "snow", "bedclothes", "track", "tree", "truck", "wall", "water", "window", "wood"]
\ No newline at end of file
diff --git a/datasets/prepare_ade20k_150.py b/datasets/prepare_ade20k_150.py
new file mode 100644
index 0000000000000000000000000000000000000000..c001db4bdf17a1b03693aaa60b8ced153e081c6c
--- /dev/null
+++ b/datasets/prepare_ade20k_150.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+from pathlib import Path
+
+import numpy as np
+import tqdm
+from PIL import Image
+
+
+def convert(input, output):
+ img = np.asarray(Image.open(input))
+ assert img.dtype == np.uint8
+ img = img - 1 # 0 (ignore) becomes 255. others are shifted by 1
+ Image.fromarray(img).save(output)
+
+
+if __name__ == "__main__":
+ dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ADEChallengeData2016"
+ for name in ["validation"]:
+ annotation_dir = dataset_dir / "annotations" / name
+ output_dir = dataset_dir / "annotations_detectron2" / name
+ output_dir.mkdir(parents=True, exist_ok=True)
+ for file in tqdm.tqdm(list(annotation_dir.iterdir())):
+ output_file = output_dir / file.name
+ convert(file, output_file)
\ No newline at end of file
diff --git a/datasets/prepare_ade20k_full.py b/datasets/prepare_ade20k_full.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a55e039549ff0aaf928a4dddee7a94ea8d0f6bf
--- /dev/null
+++ b/datasets/prepare_ade20k_full.py
@@ -0,0 +1,1011 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import os
+import pickle as pkl
+from pathlib import Path
+
+import cv2
+import numpy as np
+import tqdm
+from PIL import Image
+
+ADE20K_SEM_SEG_FULL_CATEGORIES = [
+ {"name": "wall", "id": 2978, "trainId": 0},
+ {"name": "building, edifice", "id": 312, "trainId": 1},
+ {"name": "sky", "id": 2420, "trainId": 2},
+ {"name": "tree", "id": 2855, "trainId": 3},
+ {"name": "road, route", "id": 2131, "trainId": 4},
+ {"name": "floor, flooring", "id": 976, "trainId": 5},
+ {"name": "ceiling", "id": 447, "trainId": 6},
+ {"name": "bed", "id": 165, "trainId": 7},
+ {"name": "sidewalk, pavement", "id": 2377, "trainId": 8},
+ {"name": "earth, ground", "id": 838, "trainId": 9},
+ {"name": "cabinet", "id": 350, "trainId": 10},
+ {"name": "person, individual, someone, somebody, mortal, soul", "id": 1831, "trainId": 11},
+ {"name": "grass", "id": 1125, "trainId": 12},
+ {"name": "windowpane, window", "id": 3055, "trainId": 13},
+ {"name": "car, auto, automobile, machine, motorcar", "id": 401, "trainId": 14},
+ {"name": "mountain, mount", "id": 1610, "trainId": 15},
+ {"name": "plant, flora, plant life", "id": 1910, "trainId": 16},
+ {"name": "table", "id": 2684, "trainId": 17},
+ {"name": "chair", "id": 471, "trainId": 18},
+ {"name": "curtain, drape, drapery, mantle, pall", "id": 687, "trainId": 19},
+ {"name": "door", "id": 774, "trainId": 20},
+ {"name": "sofa, couch, lounge", "id": 2473, "trainId": 21},
+ {"name": "sea", "id": 2264, "trainId": 22},
+ {"name": "painting, picture", "id": 1735, "trainId": 23},
+ {"name": "water", "id": 2994, "trainId": 24},
+ {"name": "mirror", "id": 1564, "trainId": 25},
+ {"name": "house", "id": 1276, "trainId": 26},
+ {"name": "rug, carpet, carpeting", "id": 2178, "trainId": 27},
+ {"name": "shelf", "id": 2329, "trainId": 28},
+ {"name": "armchair", "id": 57, "trainId": 29},
+ {"name": "fence, fencing", "id": 907, "trainId": 30},
+ {"name": "field", "id": 913, "trainId": 31},
+ {"name": "lamp", "id": 1395, "trainId": 32},
+ {"name": "rock, stone", "id": 2138, "trainId": 33},
+ {"name": "seat", "id": 2272, "trainId": 34},
+ {"name": "river", "id": 2128, "trainId": 35},
+ {"name": "desk", "id": 724, "trainId": 36},
+ {"name": "bathtub, bathing tub, bath, tub", "id": 155, "trainId": 37},
+ {"name": "railing, rail", "id": 2053, "trainId": 38},
+ {"name": "signboard, sign", "id": 2380, "trainId": 39},
+ {"name": "cushion", "id": 689, "trainId": 40},
+ {"name": "path", "id": 1788, "trainId": 41},
+ {"name": "work surface", "id": 3087, "trainId": 42},
+ {"name": "stairs, steps", "id": 2530, "trainId": 43},
+ {"name": "column, pillar", "id": 581, "trainId": 44},
+ {"name": "sink", "id": 2388, "trainId": 45},
+ {"name": "wardrobe, closet, press", "id": 2985, "trainId": 46},
+ {"name": "snow", "id": 2454, "trainId": 47},
+ {"name": "refrigerator, icebox", "id": 2096, "trainId": 48},
+ {"name": "base, pedestal, stand", "id": 137, "trainId": 49},
+ {"name": "bridge, span", "id": 294, "trainId": 50},
+ {"name": "blind, screen", "id": 212, "trainId": 51},
+ {"name": "runway", "id": 2185, "trainId": 52},
+ {"name": "cliff, drop, drop-off", "id": 524, "trainId": 53},
+ {"name": "sand", "id": 2212, "trainId": 54},
+ {"name": "fireplace, hearth, open fireplace", "id": 943, "trainId": 55},
+ {"name": "pillow", "id": 1869, "trainId": 56},
+ {"name": "screen door, screen", "id": 2251, "trainId": 57},
+ {"name": "toilet, can, commode, crapper, pot, potty, stool, throne", "id": 2793, "trainId": 58},
+ {"name": "skyscraper", "id": 2423, "trainId": 59},
+ {"name": "grandstand, covered stand", "id": 1121, "trainId": 60},
+ {"name": "box", "id": 266, "trainId": 61},
+ {"name": "pool table, billiard table, snooker table", "id": 1948, "trainId": 62},
+ {"name": "palm, palm tree", "id": 1744, "trainId": 63},
+ {"name": "double door", "id": 783, "trainId": 64},
+ {"name": "coffee table, cocktail table", "id": 571, "trainId": 65},
+ {"name": "counter", "id": 627, "trainId": 66},
+ {"name": "countertop", "id": 629, "trainId": 67},
+ {"name": "chest of drawers, chest, bureau, dresser", "id": 491, "trainId": 68},
+ {"name": "kitchen island", "id": 1374, "trainId": 69},
+ {"name": "boat", "id": 223, "trainId": 70},
+ {"name": "waterfall, falls", "id": 3016, "trainId": 71},
+ {
+ "name": "stove, kitchen stove, range, kitchen range, cooking stove",
+ "id": 2598,
+ "trainId": 72,
+ },
+ {"name": "flower", "id": 978, "trainId": 73},
+ {"name": "bookcase", "id": 239, "trainId": 74},
+ {"name": "controls", "id": 608, "trainId": 75},
+ {"name": "book", "id": 236, "trainId": 76},
+ {"name": "stairway, staircase", "id": 2531, "trainId": 77},
+ {"name": "streetlight, street lamp", "id": 2616, "trainId": 78},
+ {
+ "name": "computer, computing machine, computing device, data processor, electronic computer, information processing system",
+ "id": 591,
+ "trainId": 79,
+ },
+ {
+ "name": "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle",
+ "id": 327,
+ "trainId": 80,
+ },
+ {"name": "swivel chair", "id": 2679, "trainId": 81},
+ {"name": "light, light source", "id": 1451, "trainId": 82},
+ {"name": "bench", "id": 181, "trainId": 83},
+ {"name": "case, display case, showcase, vitrine", "id": 420, "trainId": 84},
+ {"name": "towel", "id": 2821, "trainId": 85},
+ {"name": "fountain", "id": 1023, "trainId": 86},
+ {"name": "embankment", "id": 855, "trainId": 87},
+ {
+ "name": "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box",
+ "id": 2733,
+ "trainId": 88,
+ },
+ {"name": "van", "id": 2928, "trainId": 89},
+ {"name": "hill", "id": 1240, "trainId": 90},
+ {"name": "awning, sunshade, sunblind", "id": 77, "trainId": 91},
+ {"name": "poster, posting, placard, notice, bill, card", "id": 1969, "trainId": 92},
+ {"name": "truck, motortruck", "id": 2880, "trainId": 93},
+ {"name": "airplane, aeroplane, plane", "id": 14, "trainId": 94},
+ {"name": "pole", "id": 1936, "trainId": 95},
+ {"name": "tower", "id": 2828, "trainId": 96},
+ {"name": "court", "id": 631, "trainId": 97},
+ {"name": "ball", "id": 103, "trainId": 98},
+ {
+ "name": "aircraft carrier, carrier, flattop, attack aircraft carrier",
+ "id": 3144,
+ "trainId": 99,
+ },
+ {"name": "buffet, counter, sideboard", "id": 308, "trainId": 100},
+ {"name": "hovel, hut, hutch, shack, shanty", "id": 1282, "trainId": 101},
+ {"name": "apparel, wearing apparel, dress, clothes", "id": 38, "trainId": 102},
+ {"name": "minibike, motorbike", "id": 1563, "trainId": 103},
+ {"name": "animal, animate being, beast, brute, creature, fauna", "id": 29, "trainId": 104},
+ {"name": "chandelier, pendant, pendent", "id": 480, "trainId": 105},
+ {"name": "step, stair", "id": 2569, "trainId": 106},
+ {"name": "booth, cubicle, stall, kiosk", "id": 247, "trainId": 107},
+ {"name": "bicycle, bike, wheel, cycle", "id": 187, "trainId": 108},
+ {"name": "doorframe, doorcase", "id": 778, "trainId": 109},
+ {"name": "sconce", "id": 2243, "trainId": 110},
+ {"name": "pond", "id": 1941, "trainId": 111},
+ {"name": "trade name, brand name, brand, marque", "id": 2833, "trainId": 112},
+ {"name": "bannister, banister, balustrade, balusters, handrail", "id": 120, "trainId": 113},
+ {"name": "bag", "id": 95, "trainId": 114},
+ {"name": "traffic light, traffic signal, stoplight", "id": 2836, "trainId": 115},
+ {"name": "gazebo", "id": 1087, "trainId": 116},
+ {"name": "escalator, moving staircase, moving stairway", "id": 868, "trainId": 117},
+ {"name": "land, ground, soil", "id": 1401, "trainId": 118},
+ {"name": "board, plank", "id": 220, "trainId": 119},
+ {"name": "arcade machine", "id": 47, "trainId": 120},
+ {"name": "eiderdown, duvet, continental quilt", "id": 843, "trainId": 121},
+ {"name": "bar", "id": 123, "trainId": 122},
+ {"name": "stall, stand, sales booth", "id": 2537, "trainId": 123},
+ {"name": "playground", "id": 1927, "trainId": 124},
+ {"name": "ship", "id": 2337, "trainId": 125},
+ {"name": "ottoman, pouf, pouffe, puff, hassock", "id": 1702, "trainId": 126},
+ {
+ "name": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
+ "id": 64,
+ "trainId": 127,
+ },
+ {"name": "bottle", "id": 249, "trainId": 128},
+ {"name": "cradle", "id": 642, "trainId": 129},
+ {"name": "pot, flowerpot", "id": 1981, "trainId": 130},
+ {
+ "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
+ "id": 609,
+ "trainId": 131,
+ },
+ {"name": "train, railroad train", "id": 2840, "trainId": 132},
+ {"name": "stool", "id": 2586, "trainId": 133},
+ {"name": "lake", "id": 1393, "trainId": 134},
+ {"name": "tank, storage tank", "id": 2704, "trainId": 135},
+ {"name": "ice, water ice", "id": 1304, "trainId": 136},
+ {"name": "basket, handbasket", "id": 146, "trainId": 137},
+ {"name": "manhole", "id": 1494, "trainId": 138},
+ {"name": "tent, collapsible shelter", "id": 2739, "trainId": 139},
+ {"name": "canopy", "id": 389, "trainId": 140},
+ {"name": "microwave, microwave oven", "id": 1551, "trainId": 141},
+ {"name": "barrel, cask", "id": 131, "trainId": 142},
+ {"name": "dirt track", "id": 738, "trainId": 143},
+ {"name": "beam", "id": 161, "trainId": 144},
+ {"name": "dishwasher, dish washer, dishwashing machine", "id": 747, "trainId": 145},
+ {"name": "plate", "id": 1919, "trainId": 146},
+ {"name": "screen, crt screen", "id": 3109, "trainId": 147},
+ {"name": "ruins", "id": 2179, "trainId": 148},
+ {"name": "washer, automatic washer, washing machine", "id": 2989, "trainId": 149},
+ {"name": "blanket, cover", "id": 206, "trainId": 150},
+ {"name": "plaything, toy", "id": 1930, "trainId": 151},
+ {"name": "food, solid food", "id": 1002, "trainId": 152},
+ {"name": "screen, silver screen, projection screen", "id": 2254, "trainId": 153},
+ {"name": "oven", "id": 1708, "trainId": 154},
+ {"name": "stage", "id": 2526, "trainId": 155},
+ {"name": "beacon, lighthouse, beacon light, pharos", "id": 160, "trainId": 156},
+ {"name": "umbrella", "id": 2901, "trainId": 157},
+ {"name": "sculpture", "id": 2262, "trainId": 158},
+ {"name": "aqueduct", "id": 44, "trainId": 159},
+ {"name": "container", "id": 597, "trainId": 160},
+ {"name": "scaffolding, staging", "id": 2235, "trainId": 161},
+ {"name": "hood, exhaust hood", "id": 1260, "trainId": 162},
+ {"name": "curb, curbing, kerb", "id": 682, "trainId": 163},
+ {"name": "roller coaster", "id": 2151, "trainId": 164},
+ {"name": "horse, equus caballus", "id": 3107, "trainId": 165},
+ {"name": "catwalk", "id": 432, "trainId": 166},
+ {"name": "glass, drinking glass", "id": 1098, "trainId": 167},
+ {"name": "vase", "id": 2932, "trainId": 168},
+ {"name": "central reservation", "id": 461, "trainId": 169},
+ {"name": "carousel", "id": 410, "trainId": 170},
+ {"name": "radiator", "id": 2046, "trainId": 171},
+ {"name": "closet", "id": 533, "trainId": 172},
+ {"name": "machine", "id": 1481, "trainId": 173},
+ {"name": "pier, wharf, wharfage, dock", "id": 1858, "trainId": 174},
+ {"name": "fan", "id": 894, "trainId": 175},
+ {"name": "inflatable bounce game", "id": 1322, "trainId": 176},
+ {"name": "pitch", "id": 1891, "trainId": 177},
+ {"name": "paper", "id": 1756, "trainId": 178},
+ {"name": "arcade, colonnade", "id": 49, "trainId": 179},
+ {"name": "hot tub", "id": 1272, "trainId": 180},
+ {"name": "helicopter", "id": 1229, "trainId": 181},
+ {"name": "tray", "id": 2850, "trainId": 182},
+ {"name": "partition, divider", "id": 1784, "trainId": 183},
+ {"name": "vineyard", "id": 2962, "trainId": 184},
+ {"name": "bowl", "id": 259, "trainId": 185},
+ {"name": "bullring", "id": 319, "trainId": 186},
+ {"name": "flag", "id": 954, "trainId": 187},
+ {"name": "pot", "id": 1974, "trainId": 188},
+ {"name": "footbridge, overcrossing, pedestrian bridge", "id": 1013, "trainId": 189},
+ {"name": "shower", "id": 2356, "trainId": 190},
+ {"name": "bag, traveling bag, travelling bag, grip, suitcase", "id": 97, "trainId": 191},
+ {"name": "bulletin board, notice board", "id": 318, "trainId": 192},
+ {"name": "confessional booth", "id": 592, "trainId": 193},
+ {"name": "trunk, tree trunk, bole", "id": 2885, "trainId": 194},
+ {"name": "forest", "id": 1017, "trainId": 195},
+ {"name": "elevator door", "id": 851, "trainId": 196},
+ {"name": "laptop, laptop computer", "id": 1407, "trainId": 197},
+ {"name": "instrument panel", "id": 1332, "trainId": 198},
+ {"name": "bucket, pail", "id": 303, "trainId": 199},
+ {"name": "tapestry, tapis", "id": 2714, "trainId": 200},
+ {"name": "platform", "id": 1924, "trainId": 201},
+ {"name": "jacket", "id": 1346, "trainId": 202},
+ {"name": "gate", "id": 1081, "trainId": 203},
+ {"name": "monitor, monitoring device", "id": 1583, "trainId": 204},
+ {
+ "name": "telephone booth, phone booth, call box, telephone box, telephone kiosk",
+ "id": 2727,
+ "trainId": 205,
+ },
+ {"name": "spotlight, spot", "id": 2509, "trainId": 206},
+ {"name": "ring", "id": 2123, "trainId": 207},
+ {"name": "control panel", "id": 602, "trainId": 208},
+ {"name": "blackboard, chalkboard", "id": 202, "trainId": 209},
+ {"name": "air conditioner, air conditioning", "id": 10, "trainId": 210},
+ {"name": "chest", "id": 490, "trainId": 211},
+ {"name": "clock", "id": 530, "trainId": 212},
+ {"name": "sand dune", "id": 2213, "trainId": 213},
+ {"name": "pipe, pipage, piping", "id": 1884, "trainId": 214},
+ {"name": "vault", "id": 2934, "trainId": 215},
+ {"name": "table football", "id": 2687, "trainId": 216},
+ {"name": "cannon", "id": 387, "trainId": 217},
+ {"name": "swimming pool, swimming bath, natatorium", "id": 2668, "trainId": 218},
+ {"name": "fluorescent, fluorescent fixture", "id": 982, "trainId": 219},
+ {"name": "statue", "id": 2547, "trainId": 220},
+ {
+ "name": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
+ "id": 1474,
+ "trainId": 221,
+ },
+ {"name": "exhibitor", "id": 877, "trainId": 222},
+ {"name": "ladder", "id": 1391, "trainId": 223},
+ {"name": "carport", "id": 414, "trainId": 224},
+ {"name": "dam", "id": 698, "trainId": 225},
+ {"name": "pulpit", "id": 2019, "trainId": 226},
+ {"name": "skylight, fanlight", "id": 2422, "trainId": 227},
+ {"name": "water tower", "id": 3010, "trainId": 228},
+ {"name": "grill, grille, grillwork", "id": 1139, "trainId": 229},
+ {"name": "display board", "id": 753, "trainId": 230},
+ {"name": "pane, pane of glass, window glass", "id": 1747, "trainId": 231},
+ {"name": "rubbish, trash, scrap", "id": 2175, "trainId": 232},
+ {"name": "ice rink", "id": 1301, "trainId": 233},
+ {"name": "fruit", "id": 1033, "trainId": 234},
+ {"name": "patio", "id": 1789, "trainId": 235},
+ {"name": "vending machine", "id": 2939, "trainId": 236},
+ {"name": "telephone, phone, telephone set", "id": 2730, "trainId": 237},
+ {"name": "net", "id": 1652, "trainId": 238},
+ {
+ "name": "backpack, back pack, knapsack, packsack, rucksack, haversack",
+ "id": 90,
+ "trainId": 239,
+ },
+ {"name": "jar", "id": 1349, "trainId": 240},
+ {"name": "track", "id": 2830, "trainId": 241},
+ {"name": "magazine", "id": 1485, "trainId": 242},
+ {"name": "shutter", "id": 2370, "trainId": 243},
+ {"name": "roof", "id": 2155, "trainId": 244},
+ {"name": "banner, streamer", "id": 118, "trainId": 245},
+ {"name": "landfill", "id": 1402, "trainId": 246},
+ {"name": "post", "id": 1957, "trainId": 247},
+ {"name": "altarpiece, reredos", "id": 3130, "trainId": 248},
+ {"name": "hat, chapeau, lid", "id": 1197, "trainId": 249},
+ {"name": "arch, archway", "id": 52, "trainId": 250},
+ {"name": "table game", "id": 2688, "trainId": 251},
+ {"name": "bag, handbag, pocketbook, purse", "id": 96, "trainId": 252},
+ {"name": "document, written document, papers", "id": 762, "trainId": 253},
+ {"name": "dome", "id": 772, "trainId": 254},
+ {"name": "pier", "id": 1857, "trainId": 255},
+ {"name": "shanties", "id": 2315, "trainId": 256},
+ {"name": "forecourt", "id": 1016, "trainId": 257},
+ {"name": "crane", "id": 643, "trainId": 258},
+ {"name": "dog, domestic dog, canis familiaris", "id": 3105, "trainId": 259},
+ {"name": "piano, pianoforte, forte-piano", "id": 1849, "trainId": 260},
+ {"name": "drawing", "id": 791, "trainId": 261},
+ {"name": "cabin", "id": 349, "trainId": 262},
+ {
+ "name": "ad, advertisement, advertizement, advertising, advertizing, advert",
+ "id": 6,
+ "trainId": 263,
+ },
+ {"name": "amphitheater, amphitheatre, coliseum", "id": 3114, "trainId": 264},
+ {"name": "monument", "id": 1587, "trainId": 265},
+ {"name": "henhouse", "id": 1233, "trainId": 266},
+ {"name": "cockpit", "id": 559, "trainId": 267},
+ {"name": "heater, warmer", "id": 1223, "trainId": 268},
+ {"name": "windmill, aerogenerator, wind generator", "id": 3049, "trainId": 269},
+ {"name": "pool", "id": 1943, "trainId": 270},
+ {"name": "elevator, lift", "id": 853, "trainId": 271},
+ {"name": "decoration, ornament, ornamentation", "id": 709, "trainId": 272},
+ {"name": "labyrinth", "id": 1390, "trainId": 273},
+ {"name": "text, textual matter", "id": 2748, "trainId": 274},
+ {"name": "printer", "id": 2007, "trainId": 275},
+ {"name": "mezzanine, first balcony", "id": 1546, "trainId": 276},
+ {"name": "mattress", "id": 1513, "trainId": 277},
+ {"name": "straw", "id": 2600, "trainId": 278},
+ {"name": "stalls", "id": 2538, "trainId": 279},
+ {"name": "patio, terrace", "id": 1790, "trainId": 280},
+ {"name": "billboard, hoarding", "id": 194, "trainId": 281},
+ {"name": "bus stop", "id": 326, "trainId": 282},
+ {"name": "trouser, pant", "id": 2877, "trainId": 283},
+ {"name": "console table, console", "id": 594, "trainId": 284},
+ {"name": "rack", "id": 2036, "trainId": 285},
+ {"name": "notebook", "id": 1662, "trainId": 286},
+ {"name": "shrine", "id": 2366, "trainId": 287},
+ {"name": "pantry", "id": 1754, "trainId": 288},
+ {"name": "cart", "id": 418, "trainId": 289},
+ {"name": "steam shovel", "id": 2553, "trainId": 290},
+ {"name": "porch", "id": 1951, "trainId": 291},
+ {"name": "postbox, mailbox, letter box", "id": 1963, "trainId": 292},
+ {"name": "figurine, statuette", "id": 918, "trainId": 293},
+ {"name": "recycling bin", "id": 2086, "trainId": 294},
+ {"name": "folding screen", "id": 997, "trainId": 295},
+ {"name": "telescope", "id": 2731, "trainId": 296},
+ {"name": "deck chair, beach chair", "id": 704, "trainId": 297},
+ {"name": "kennel", "id": 1365, "trainId": 298},
+ {"name": "coffee maker", "id": 569, "trainId": 299},
+ {"name": "altar, communion table, lord's table", "id": 3108, "trainId": 300},
+ {"name": "fish", "id": 948, "trainId": 301},
+ {"name": "easel", "id": 839, "trainId": 302},
+ {"name": "artificial golf green", "id": 63, "trainId": 303},
+ {"name": "iceberg", "id": 1305, "trainId": 304},
+ {"name": "candlestick, candle holder", "id": 378, "trainId": 305},
+ {"name": "shower stall, shower bath", "id": 2362, "trainId": 306},
+ {"name": "television stand", "id": 2734, "trainId": 307},
+ {
+ "name": "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle",
+ "id": 2982,
+ "trainId": 308,
+ },
+ {"name": "skeleton", "id": 2398, "trainId": 309},
+ {"name": "grand piano, grand", "id": 1119, "trainId": 310},
+ {"name": "candy, confect", "id": 382, "trainId": 311},
+ {"name": "grille door", "id": 1141, "trainId": 312},
+ {"name": "pedestal, plinth, footstall", "id": 1805, "trainId": 313},
+ {"name": "jersey, t-shirt, tee shirt", "id": 3102, "trainId": 314},
+ {"name": "shoe", "id": 2341, "trainId": 315},
+ {"name": "gravestone, headstone, tombstone", "id": 1131, "trainId": 316},
+ {"name": "shanty", "id": 2316, "trainId": 317},
+ {"name": "structure", "id": 2626, "trainId": 318},
+ {"name": "rocking chair, rocker", "id": 3104, "trainId": 319},
+ {"name": "bird", "id": 198, "trainId": 320},
+ {"name": "place mat", "id": 1896, "trainId": 321},
+ {"name": "tomb", "id": 2800, "trainId": 322},
+ {"name": "big top", "id": 190, "trainId": 323},
+ {"name": "gas pump, gasoline pump, petrol pump, island dispenser", "id": 3131, "trainId": 324},
+ {"name": "lockers", "id": 1463, "trainId": 325},
+ {"name": "cage", "id": 357, "trainId": 326},
+ {"name": "finger", "id": 929, "trainId": 327},
+ {"name": "bleachers", "id": 209, "trainId": 328},
+ {"name": "ferris wheel", "id": 912, "trainId": 329},
+ {"name": "hairdresser chair", "id": 1164, "trainId": 330},
+ {"name": "mat", "id": 1509, "trainId": 331},
+ {"name": "stands", "id": 2539, "trainId": 332},
+ {"name": "aquarium, fish tank, marine museum", "id": 3116, "trainId": 333},
+ {"name": "streetcar, tram, tramcar, trolley, trolley car", "id": 2615, "trainId": 334},
+ {"name": "napkin, table napkin, serviette", "id": 1644, "trainId": 335},
+ {"name": "dummy", "id": 818, "trainId": 336},
+ {"name": "booklet, brochure, folder, leaflet, pamphlet", "id": 242, "trainId": 337},
+ {"name": "sand trap", "id": 2217, "trainId": 338},
+ {"name": "shop, store", "id": 2347, "trainId": 339},
+ {"name": "table cloth", "id": 2686, "trainId": 340},
+ {"name": "service station", "id": 2300, "trainId": 341},
+ {"name": "coffin", "id": 572, "trainId": 342},
+ {"name": "drawer", "id": 789, "trainId": 343},
+ {"name": "cages", "id": 358, "trainId": 344},
+ {"name": "slot machine, coin machine", "id": 2443, "trainId": 345},
+ {"name": "balcony", "id": 101, "trainId": 346},
+ {"name": "volleyball court", "id": 2969, "trainId": 347},
+ {"name": "table tennis", "id": 2692, "trainId": 348},
+ {"name": "control table", "id": 606, "trainId": 349},
+ {"name": "shirt", "id": 2339, "trainId": 350},
+ {"name": "merchandise, ware, product", "id": 1533, "trainId": 351},
+ {"name": "railway", "id": 2060, "trainId": 352},
+ {"name": "parterre", "id": 1782, "trainId": 353},
+ {"name": "chimney", "id": 495, "trainId": 354},
+ {"name": "can, tin, tin can", "id": 371, "trainId": 355},
+ {"name": "tanks", "id": 2707, "trainId": 356},
+ {"name": "fabric, cloth, material, textile", "id": 889, "trainId": 357},
+ {"name": "alga, algae", "id": 3156, "trainId": 358},
+ {"name": "system", "id": 2683, "trainId": 359},
+ {"name": "map", "id": 1499, "trainId": 360},
+ {"name": "greenhouse", "id": 1135, "trainId": 361},
+ {"name": "mug", "id": 1619, "trainId": 362},
+ {"name": "barbecue", "id": 125, "trainId": 363},
+ {"name": "trailer", "id": 2838, "trainId": 364},
+ {"name": "toilet tissue, toilet paper, bathroom tissue", "id": 2792, "trainId": 365},
+ {"name": "organ", "id": 1695, "trainId": 366},
+ {"name": "dishrag, dishcloth", "id": 746, "trainId": 367},
+ {"name": "island", "id": 1343, "trainId": 368},
+ {"name": "keyboard", "id": 1370, "trainId": 369},
+ {"name": "trench", "id": 2858, "trainId": 370},
+ {"name": "basket, basketball hoop, hoop", "id": 145, "trainId": 371},
+ {"name": "steering wheel, wheel", "id": 2565, "trainId": 372},
+ {"name": "pitcher, ewer", "id": 1892, "trainId": 373},
+ {"name": "goal", "id": 1103, "trainId": 374},
+ {"name": "bread, breadstuff, staff of life", "id": 286, "trainId": 375},
+ {"name": "beds", "id": 170, "trainId": 376},
+ {"name": "wood", "id": 3073, "trainId": 377},
+ {"name": "file cabinet", "id": 922, "trainId": 378},
+ {"name": "newspaper, paper", "id": 1655, "trainId": 379},
+ {"name": "motorboat", "id": 1602, "trainId": 380},
+ {"name": "rope", "id": 2160, "trainId": 381},
+ {"name": "guitar", "id": 1151, "trainId": 382},
+ {"name": "rubble", "id": 2176, "trainId": 383},
+ {"name": "scarf", "id": 2239, "trainId": 384},
+ {"name": "barrels", "id": 132, "trainId": 385},
+ {"name": "cap", "id": 394, "trainId": 386},
+ {"name": "leaves", "id": 1424, "trainId": 387},
+ {"name": "control tower", "id": 607, "trainId": 388},
+ {"name": "dashboard", "id": 700, "trainId": 389},
+ {"name": "bandstand", "id": 116, "trainId": 390},
+ {"name": "lectern", "id": 1425, "trainId": 391},
+ {"name": "switch, electric switch, electrical switch", "id": 2676, "trainId": 392},
+ {"name": "baseboard, mopboard, skirting board", "id": 141, "trainId": 393},
+ {"name": "shower room", "id": 2360, "trainId": 394},
+ {"name": "smoke", "id": 2449, "trainId": 395},
+ {"name": "faucet, spigot", "id": 897, "trainId": 396},
+ {"name": "bulldozer", "id": 317, "trainId": 397},
+ {"name": "saucepan", "id": 2228, "trainId": 398},
+ {"name": "shops", "id": 2351, "trainId": 399},
+ {"name": "meter", "id": 1543, "trainId": 400},
+ {"name": "crevasse", "id": 656, "trainId": 401},
+ {"name": "gear", "id": 1088, "trainId": 402},
+ {"name": "candelabrum, candelabra", "id": 373, "trainId": 403},
+ {"name": "sofa bed", "id": 2472, "trainId": 404},
+ {"name": "tunnel", "id": 2892, "trainId": 405},
+ {"name": "pallet", "id": 1740, "trainId": 406},
+ {"name": "wire, conducting wire", "id": 3067, "trainId": 407},
+ {"name": "kettle, boiler", "id": 1367, "trainId": 408},
+ {"name": "bidet", "id": 188, "trainId": 409},
+ {
+ "name": "baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher",
+ "id": 79,
+ "trainId": 410,
+ },
+ {"name": "music stand", "id": 1633, "trainId": 411},
+ {"name": "pipe, tube", "id": 1885, "trainId": 412},
+ {"name": "cup", "id": 677, "trainId": 413},
+ {"name": "parking meter", "id": 1779, "trainId": 414},
+ {"name": "ice hockey rink", "id": 1297, "trainId": 415},
+ {"name": "shelter", "id": 2334, "trainId": 416},
+ {"name": "weeds", "id": 3027, "trainId": 417},
+ {"name": "temple", "id": 2735, "trainId": 418},
+ {"name": "patty, cake", "id": 1791, "trainId": 419},
+ {"name": "ski slope", "id": 2405, "trainId": 420},
+ {"name": "panel", "id": 1748, "trainId": 421},
+ {"name": "wallet", "id": 2983, "trainId": 422},
+ {"name": "wheel", "id": 3035, "trainId": 423},
+ {"name": "towel rack, towel horse", "id": 2824, "trainId": 424},
+ {"name": "roundabout", "id": 2168, "trainId": 425},
+ {"name": "canister, cannister, tin", "id": 385, "trainId": 426},
+ {"name": "rod", "id": 2148, "trainId": 427},
+ {"name": "soap dispenser", "id": 2465, "trainId": 428},
+ {"name": "bell", "id": 175, "trainId": 429},
+ {"name": "canvas", "id": 390, "trainId": 430},
+ {"name": "box office, ticket office, ticket booth", "id": 268, "trainId": 431},
+ {"name": "teacup", "id": 2722, "trainId": 432},
+ {"name": "trellis", "id": 2857, "trainId": 433},
+ {"name": "workbench", "id": 3088, "trainId": 434},
+ {"name": "valley, vale", "id": 2926, "trainId": 435},
+ {"name": "toaster", "id": 2782, "trainId": 436},
+ {"name": "knife", "id": 1378, "trainId": 437},
+ {"name": "podium", "id": 1934, "trainId": 438},
+ {"name": "ramp", "id": 2072, "trainId": 439},
+ {"name": "tumble dryer", "id": 2889, "trainId": 440},
+ {"name": "fireplug, fire hydrant, plug", "id": 944, "trainId": 441},
+ {"name": "gym shoe, sneaker, tennis shoe", "id": 1158, "trainId": 442},
+ {"name": "lab bench", "id": 1383, "trainId": 443},
+ {"name": "equipment", "id": 867, "trainId": 444},
+ {"name": "rocky formation", "id": 2145, "trainId": 445},
+ {"name": "plastic", "id": 1915, "trainId": 446},
+ {"name": "calendar", "id": 361, "trainId": 447},
+ {"name": "caravan", "id": 402, "trainId": 448},
+ {"name": "check-in-desk", "id": 482, "trainId": 449},
+ {"name": "ticket counter", "id": 2761, "trainId": 450},
+ {"name": "brush", "id": 300, "trainId": 451},
+ {"name": "mill", "id": 1554, "trainId": 452},
+ {"name": "covered bridge", "id": 636, "trainId": 453},
+ {"name": "bowling alley", "id": 260, "trainId": 454},
+ {"name": "hanger", "id": 1186, "trainId": 455},
+ {"name": "excavator", "id": 871, "trainId": 456},
+ {"name": "trestle", "id": 2859, "trainId": 457},
+ {"name": "revolving door", "id": 2103, "trainId": 458},
+ {"name": "blast furnace", "id": 208, "trainId": 459},
+ {"name": "scale, weighing machine", "id": 2236, "trainId": 460},
+ {"name": "projector", "id": 2012, "trainId": 461},
+ {"name": "soap", "id": 2462, "trainId": 462},
+ {"name": "locker", "id": 1462, "trainId": 463},
+ {"name": "tractor", "id": 2832, "trainId": 464},
+ {"name": "stretcher", "id": 2617, "trainId": 465},
+ {"name": "frame", "id": 1024, "trainId": 466},
+ {"name": "grating", "id": 1129, "trainId": 467},
+ {"name": "alembic", "id": 18, "trainId": 468},
+ {"name": "candle, taper, wax light", "id": 376, "trainId": 469},
+ {"name": "barrier", "id": 134, "trainId": 470},
+ {"name": "cardboard", "id": 407, "trainId": 471},
+ {"name": "cave", "id": 434, "trainId": 472},
+ {"name": "puddle", "id": 2017, "trainId": 473},
+ {"name": "tarp", "id": 2717, "trainId": 474},
+ {"name": "price tag", "id": 2005, "trainId": 475},
+ {"name": "watchtower", "id": 2993, "trainId": 476},
+ {"name": "meters", "id": 1545, "trainId": 477},
+ {
+ "name": "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb",
+ "id": 1445,
+ "trainId": 478,
+ },
+ {"name": "tracks", "id": 2831, "trainId": 479},
+ {"name": "hair dryer", "id": 1161, "trainId": 480},
+ {"name": "skirt", "id": 2411, "trainId": 481},
+ {"name": "viaduct", "id": 2949, "trainId": 482},
+ {"name": "paper towel", "id": 1769, "trainId": 483},
+ {"name": "coat", "id": 552, "trainId": 484},
+ {"name": "sheet", "id": 2327, "trainId": 485},
+ {"name": "fire extinguisher, extinguisher, asphyxiator", "id": 939, "trainId": 486},
+ {"name": "water wheel", "id": 3013, "trainId": 487},
+ {"name": "pottery, clayware", "id": 1986, "trainId": 488},
+ {"name": "magazine rack", "id": 1486, "trainId": 489},
+ {"name": "teapot", "id": 2723, "trainId": 490},
+ {"name": "microphone, mike", "id": 1549, "trainId": 491},
+ {"name": "support", "id": 2649, "trainId": 492},
+ {"name": "forklift", "id": 1020, "trainId": 493},
+ {"name": "canyon", "id": 392, "trainId": 494},
+ {"name": "cash register, register", "id": 422, "trainId": 495},
+ {"name": "leaf, leafage, foliage", "id": 1419, "trainId": 496},
+ {"name": "remote control, remote", "id": 2099, "trainId": 497},
+ {"name": "soap dish", "id": 2464, "trainId": 498},
+ {"name": "windshield, windscreen", "id": 3058, "trainId": 499},
+ {"name": "cat", "id": 430, "trainId": 500},
+ {"name": "cue, cue stick, pool cue, pool stick", "id": 675, "trainId": 501},
+ {"name": "vent, venthole, vent-hole, blowhole", "id": 2941, "trainId": 502},
+ {"name": "videos", "id": 2955, "trainId": 503},
+ {"name": "shovel", "id": 2355, "trainId": 504},
+ {"name": "eaves", "id": 840, "trainId": 505},
+ {"name": "antenna, aerial, transmitting aerial", "id": 32, "trainId": 506},
+ {"name": "shipyard", "id": 2338, "trainId": 507},
+ {"name": "hen, biddy", "id": 1232, "trainId": 508},
+ {"name": "traffic cone", "id": 2834, "trainId": 509},
+ {"name": "washing machines", "id": 2991, "trainId": 510},
+ {"name": "truck crane", "id": 2879, "trainId": 511},
+ {"name": "cds", "id": 444, "trainId": 512},
+ {"name": "niche", "id": 1657, "trainId": 513},
+ {"name": "scoreboard", "id": 2246, "trainId": 514},
+ {"name": "briefcase", "id": 296, "trainId": 515},
+ {"name": "boot", "id": 245, "trainId": 516},
+ {"name": "sweater, jumper", "id": 2661, "trainId": 517},
+ {"name": "hay", "id": 1202, "trainId": 518},
+ {"name": "pack", "id": 1714, "trainId": 519},
+ {"name": "bottle rack", "id": 251, "trainId": 520},
+ {"name": "glacier", "id": 1095, "trainId": 521},
+ {"name": "pergola", "id": 1828, "trainId": 522},
+ {"name": "building materials", "id": 311, "trainId": 523},
+ {"name": "television camera", "id": 2732, "trainId": 524},
+ {"name": "first floor", "id": 947, "trainId": 525},
+ {"name": "rifle", "id": 2115, "trainId": 526},
+ {"name": "tennis table", "id": 2738, "trainId": 527},
+ {"name": "stadium", "id": 2525, "trainId": 528},
+ {"name": "safety belt", "id": 2194, "trainId": 529},
+ {"name": "cover", "id": 634, "trainId": 530},
+ {"name": "dish rack", "id": 740, "trainId": 531},
+ {"name": "synthesizer", "id": 2682, "trainId": 532},
+ {"name": "pumpkin", "id": 2020, "trainId": 533},
+ {"name": "gutter", "id": 1156, "trainId": 534},
+ {"name": "fruit stand", "id": 1036, "trainId": 535},
+ {"name": "ice floe, floe", "id": 1295, "trainId": 536},
+ {"name": "handle, grip, handgrip, hold", "id": 1181, "trainId": 537},
+ {"name": "wheelchair", "id": 3037, "trainId": 538},
+ {"name": "mousepad, mouse mat", "id": 1614, "trainId": 539},
+ {"name": "diploma", "id": 736, "trainId": 540},
+ {"name": "fairground ride", "id": 893, "trainId": 541},
+ {"name": "radio", "id": 2047, "trainId": 542},
+ {"name": "hotplate", "id": 1274, "trainId": 543},
+ {"name": "junk", "id": 1361, "trainId": 544},
+ {"name": "wheelbarrow", "id": 3036, "trainId": 545},
+ {"name": "stream", "id": 2606, "trainId": 546},
+ {"name": "toll plaza", "id": 2797, "trainId": 547},
+ {"name": "punching bag", "id": 2022, "trainId": 548},
+ {"name": "trough", "id": 2876, "trainId": 549},
+ {"name": "throne", "id": 2758, "trainId": 550},
+ {"name": "chair desk", "id": 472, "trainId": 551},
+ {"name": "weighbridge", "id": 3028, "trainId": 552},
+ {"name": "extractor fan", "id": 882, "trainId": 553},
+ {"name": "hanging clothes", "id": 1189, "trainId": 554},
+ {"name": "dish, dish aerial, dish antenna, saucer", "id": 743, "trainId": 555},
+ {"name": "alarm clock, alarm", "id": 3122, "trainId": 556},
+ {"name": "ski lift", "id": 2401, "trainId": 557},
+ {"name": "chain", "id": 468, "trainId": 558},
+ {"name": "garage", "id": 1061, "trainId": 559},
+ {"name": "mechanical shovel", "id": 1523, "trainId": 560},
+ {"name": "wine rack", "id": 3059, "trainId": 561},
+ {"name": "tramway", "id": 2843, "trainId": 562},
+ {"name": "treadmill", "id": 2853, "trainId": 563},
+ {"name": "menu", "id": 1529, "trainId": 564},
+ {"name": "block", "id": 214, "trainId": 565},
+ {"name": "well", "id": 3032, "trainId": 566},
+ {"name": "witness stand", "id": 3071, "trainId": 567},
+ {"name": "branch", "id": 277, "trainId": 568},
+ {"name": "duck", "id": 813, "trainId": 569},
+ {"name": "casserole", "id": 426, "trainId": 570},
+ {"name": "frying pan", "id": 1039, "trainId": 571},
+ {"name": "desk organizer", "id": 727, "trainId": 572},
+ {"name": "mast", "id": 1508, "trainId": 573},
+ {"name": "spectacles, specs, eyeglasses, glasses", "id": 2490, "trainId": 574},
+ {"name": "service elevator", "id": 2299, "trainId": 575},
+ {"name": "dollhouse", "id": 768, "trainId": 576},
+ {"name": "hammock", "id": 1172, "trainId": 577},
+ {"name": "clothes hanging", "id": 537, "trainId": 578},
+ {"name": "photocopier", "id": 1847, "trainId": 579},
+ {"name": "notepad", "id": 1664, "trainId": 580},
+ {"name": "golf cart", "id": 1110, "trainId": 581},
+ {"name": "footpath", "id": 1014, "trainId": 582},
+ {"name": "cross", "id": 662, "trainId": 583},
+ {"name": "baptismal font", "id": 121, "trainId": 584},
+ {"name": "boiler", "id": 227, "trainId": 585},
+ {"name": "skip", "id": 2410, "trainId": 586},
+ {"name": "rotisserie", "id": 2165, "trainId": 587},
+ {"name": "tables", "id": 2696, "trainId": 588},
+ {"name": "water mill", "id": 3005, "trainId": 589},
+ {"name": "helmet", "id": 1231, "trainId": 590},
+ {"name": "cover curtain", "id": 635, "trainId": 591},
+ {"name": "brick", "id": 292, "trainId": 592},
+ {"name": "table runner", "id": 2690, "trainId": 593},
+ {"name": "ashtray", "id": 65, "trainId": 594},
+ {"name": "street box", "id": 2607, "trainId": 595},
+ {"name": "stick", "id": 2574, "trainId": 596},
+ {"name": "hangers", "id": 1188, "trainId": 597},
+ {"name": "cells", "id": 456, "trainId": 598},
+ {"name": "urinal", "id": 2913, "trainId": 599},
+ {"name": "centerpiece", "id": 459, "trainId": 600},
+ {"name": "portable fridge", "id": 1955, "trainId": 601},
+ {"name": "dvds", "id": 827, "trainId": 602},
+ {"name": "golf club", "id": 1111, "trainId": 603},
+ {"name": "skirting board", "id": 2412, "trainId": 604},
+ {"name": "water cooler", "id": 2997, "trainId": 605},
+ {"name": "clipboard", "id": 528, "trainId": 606},
+ {"name": "camera, photographic camera", "id": 366, "trainId": 607},
+ {"name": "pigeonhole", "id": 1863, "trainId": 608},
+ {"name": "chips", "id": 500, "trainId": 609},
+ {"name": "food processor", "id": 1001, "trainId": 610},
+ {"name": "post box", "id": 1958, "trainId": 611},
+ {"name": "lid", "id": 1441, "trainId": 612},
+ {"name": "drum", "id": 809, "trainId": 613},
+ {"name": "blender", "id": 210, "trainId": 614},
+ {"name": "cave entrance", "id": 435, "trainId": 615},
+ {"name": "dental chair", "id": 718, "trainId": 616},
+ {"name": "obelisk", "id": 1674, "trainId": 617},
+ {"name": "canoe", "id": 388, "trainId": 618},
+ {"name": "mobile", "id": 1572, "trainId": 619},
+ {"name": "monitors", "id": 1584, "trainId": 620},
+ {"name": "pool ball", "id": 1944, "trainId": 621},
+ {"name": "cue rack", "id": 674, "trainId": 622},
+ {"name": "baggage carts", "id": 99, "trainId": 623},
+ {"name": "shore", "id": 2352, "trainId": 624},
+ {"name": "fork", "id": 1019, "trainId": 625},
+ {"name": "paper filer", "id": 1763, "trainId": 626},
+ {"name": "bicycle rack", "id": 185, "trainId": 627},
+ {"name": "coat rack", "id": 554, "trainId": 628},
+ {"name": "garland", "id": 1066, "trainId": 629},
+ {"name": "sports bag", "id": 2508, "trainId": 630},
+ {"name": "fish tank", "id": 951, "trainId": 631},
+ {"name": "towel dispenser", "id": 2822, "trainId": 632},
+ {"name": "carriage", "id": 415, "trainId": 633},
+ {"name": "brochure", "id": 297, "trainId": 634},
+ {"name": "plaque", "id": 1914, "trainId": 635},
+ {"name": "stringer", "id": 2619, "trainId": 636},
+ {"name": "iron", "id": 1338, "trainId": 637},
+ {"name": "spoon", "id": 2505, "trainId": 638},
+ {"name": "flag pole", "id": 955, "trainId": 639},
+ {"name": "toilet brush", "id": 2786, "trainId": 640},
+ {"name": "book stand", "id": 238, "trainId": 641},
+ {"name": "water faucet, water tap, tap, hydrant", "id": 3000, "trainId": 642},
+ {"name": "ticket office", "id": 2763, "trainId": 643},
+ {"name": "broom", "id": 299, "trainId": 644},
+ {"name": "dvd", "id": 822, "trainId": 645},
+ {"name": "ice bucket", "id": 1288, "trainId": 646},
+ {"name": "carapace, shell, cuticle, shield", "id": 3101, "trainId": 647},
+ {"name": "tureen", "id": 2894, "trainId": 648},
+ {"name": "folders", "id": 992, "trainId": 649},
+ {"name": "chess", "id": 489, "trainId": 650},
+ {"name": "root", "id": 2157, "trainId": 651},
+ {"name": "sewing machine", "id": 2309, "trainId": 652},
+ {"name": "model", "id": 1576, "trainId": 653},
+ {"name": "pen", "id": 1810, "trainId": 654},
+ {"name": "violin", "id": 2964, "trainId": 655},
+ {"name": "sweatshirt", "id": 2662, "trainId": 656},
+ {"name": "recycling materials", "id": 2087, "trainId": 657},
+ {"name": "mitten", "id": 1569, "trainId": 658},
+ {"name": "chopping board, cutting board", "id": 503, "trainId": 659},
+ {"name": "mask", "id": 1505, "trainId": 660},
+ {"name": "log", "id": 1468, "trainId": 661},
+ {"name": "mouse, computer mouse", "id": 1613, "trainId": 662},
+ {"name": "grill", "id": 1138, "trainId": 663},
+ {"name": "hole", "id": 1256, "trainId": 664},
+ {"name": "target", "id": 2715, "trainId": 665},
+ {"name": "trash bag", "id": 2846, "trainId": 666},
+ {"name": "chalk", "id": 477, "trainId": 667},
+ {"name": "sticks", "id": 2576, "trainId": 668},
+ {"name": "balloon", "id": 108, "trainId": 669},
+ {"name": "score", "id": 2245, "trainId": 670},
+ {"name": "hair spray", "id": 1162, "trainId": 671},
+ {"name": "roll", "id": 2149, "trainId": 672},
+ {"name": "runner", "id": 2183, "trainId": 673},
+ {"name": "engine", "id": 858, "trainId": 674},
+ {"name": "inflatable glove", "id": 1324, "trainId": 675},
+ {"name": "games", "id": 1055, "trainId": 676},
+ {"name": "pallets", "id": 1741, "trainId": 677},
+ {"name": "baskets", "id": 149, "trainId": 678},
+ {"name": "coop", "id": 615, "trainId": 679},
+ {"name": "dvd player", "id": 825, "trainId": 680},
+ {"name": "rocking horse", "id": 2143, "trainId": 681},
+ {"name": "buckets", "id": 304, "trainId": 682},
+ {"name": "bread rolls", "id": 283, "trainId": 683},
+ {"name": "shawl", "id": 2322, "trainId": 684},
+ {"name": "watering can", "id": 3017, "trainId": 685},
+ {"name": "spotlights", "id": 2510, "trainId": 686},
+ {"name": "post-it", "id": 1960, "trainId": 687},
+ {"name": "bowls", "id": 265, "trainId": 688},
+ {"name": "security camera", "id": 2282, "trainId": 689},
+ {"name": "runner cloth", "id": 2184, "trainId": 690},
+ {"name": "lock", "id": 1461, "trainId": 691},
+ {"name": "alarm, warning device, alarm system", "id": 3113, "trainId": 692},
+ {"name": "side", "id": 2372, "trainId": 693},
+ {"name": "roulette", "id": 2166, "trainId": 694},
+ {"name": "bone", "id": 232, "trainId": 695},
+ {"name": "cutlery", "id": 693, "trainId": 696},
+ {"name": "pool balls", "id": 1945, "trainId": 697},
+ {"name": "wheels", "id": 3039, "trainId": 698},
+ {"name": "spice rack", "id": 2494, "trainId": 699},
+ {"name": "plant pots", "id": 1908, "trainId": 700},
+ {"name": "towel ring", "id": 2827, "trainId": 701},
+ {"name": "bread box", "id": 280, "trainId": 702},
+ {"name": "video", "id": 2950, "trainId": 703},
+ {"name": "funfair", "id": 1044, "trainId": 704},
+ {"name": "breads", "id": 288, "trainId": 705},
+ {"name": "tripod", "id": 2863, "trainId": 706},
+ {"name": "ironing board", "id": 1342, "trainId": 707},
+ {"name": "skimmer", "id": 2409, "trainId": 708},
+ {"name": "hollow", "id": 1258, "trainId": 709},
+ {"name": "scratching post", "id": 2249, "trainId": 710},
+ {"name": "tricycle", "id": 2862, "trainId": 711},
+ {"name": "file box", "id": 920, "trainId": 712},
+ {"name": "mountain pass", "id": 1607, "trainId": 713},
+ {"name": "tombstones", "id": 2802, "trainId": 714},
+ {"name": "cooker", "id": 610, "trainId": 715},
+ {"name": "card game, cards", "id": 3129, "trainId": 716},
+ {"name": "golf bag", "id": 1108, "trainId": 717},
+ {"name": "towel paper", "id": 2823, "trainId": 718},
+ {"name": "chaise lounge", "id": 476, "trainId": 719},
+ {"name": "sun", "id": 2641, "trainId": 720},
+ {"name": "toilet paper holder", "id": 2788, "trainId": 721},
+ {"name": "rake", "id": 2070, "trainId": 722},
+ {"name": "key", "id": 1368, "trainId": 723},
+ {"name": "umbrella stand", "id": 2903, "trainId": 724},
+ {"name": "dartboard", "id": 699, "trainId": 725},
+ {"name": "transformer", "id": 2844, "trainId": 726},
+ {"name": "fireplace utensils", "id": 942, "trainId": 727},
+ {"name": "sweatshirts", "id": 2663, "trainId": 728},
+ {
+ "name": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
+ "id": 457,
+ "trainId": 729,
+ },
+ {"name": "tallboy", "id": 2701, "trainId": 730},
+ {"name": "stapler", "id": 2540, "trainId": 731},
+ {"name": "sauna", "id": 2231, "trainId": 732},
+ {"name": "test tube", "id": 2746, "trainId": 733},
+ {"name": "palette", "id": 1738, "trainId": 734},
+ {"name": "shopping carts", "id": 2350, "trainId": 735},
+ {"name": "tools", "id": 2808, "trainId": 736},
+ {"name": "push button, push, button", "id": 2025, "trainId": 737},
+ {"name": "star", "id": 2541, "trainId": 738},
+ {"name": "roof rack", "id": 2156, "trainId": 739},
+ {"name": "barbed wire", "id": 126, "trainId": 740},
+ {"name": "spray", "id": 2512, "trainId": 741},
+ {"name": "ear", "id": 831, "trainId": 742},
+ {"name": "sponge", "id": 2503, "trainId": 743},
+ {"name": "racket", "id": 2039, "trainId": 744},
+ {"name": "tins", "id": 2774, "trainId": 745},
+ {"name": "eyeglasses", "id": 886, "trainId": 746},
+ {"name": "file", "id": 919, "trainId": 747},
+ {"name": "scarfs", "id": 2240, "trainId": 748},
+ {"name": "sugar bowl", "id": 2636, "trainId": 749},
+ {"name": "flip flop", "id": 963, "trainId": 750},
+ {"name": "headstones", "id": 1218, "trainId": 751},
+ {"name": "laptop bag", "id": 1406, "trainId": 752},
+ {"name": "leash", "id": 1420, "trainId": 753},
+ {"name": "climbing frame", "id": 526, "trainId": 754},
+ {"name": "suit hanger", "id": 2639, "trainId": 755},
+ {"name": "floor spotlight", "id": 975, "trainId": 756},
+ {"name": "plate rack", "id": 1921, "trainId": 757},
+ {"name": "sewer", "id": 2305, "trainId": 758},
+ {"name": "hard drive", "id": 1193, "trainId": 759},
+ {"name": "sprinkler", "id": 2517, "trainId": 760},
+ {"name": "tools box", "id": 2809, "trainId": 761},
+ {"name": "necklace", "id": 1647, "trainId": 762},
+ {"name": "bulbs", "id": 314, "trainId": 763},
+ {"name": "steel industry", "id": 2560, "trainId": 764},
+ {"name": "club", "id": 545, "trainId": 765},
+ {"name": "jack", "id": 1345, "trainId": 766},
+ {"name": "door bars", "id": 775, "trainId": 767},
+ {
+ "name": "control panel, instrument panel, control board, board, panel",
+ "id": 603,
+ "trainId": 768,
+ },
+ {"name": "hairbrush", "id": 1163, "trainId": 769},
+ {"name": "napkin holder", "id": 1641, "trainId": 770},
+ {"name": "office", "id": 1678, "trainId": 771},
+ {"name": "smoke detector", "id": 2450, "trainId": 772},
+ {"name": "utensils", "id": 2915, "trainId": 773},
+ {"name": "apron", "id": 42, "trainId": 774},
+ {"name": "scissors", "id": 2242, "trainId": 775},
+ {"name": "terminal", "id": 2741, "trainId": 776},
+ {"name": "grinder", "id": 1143, "trainId": 777},
+ {"name": "entry phone", "id": 862, "trainId": 778},
+ {"name": "newspaper stand", "id": 1654, "trainId": 779},
+ {"name": "pepper shaker", "id": 1826, "trainId": 780},
+ {"name": "onions", "id": 1689, "trainId": 781},
+ {
+ "name": "central processing unit, cpu, c p u , central processor, processor, mainframe",
+ "id": 3124,
+ "trainId": 782,
+ },
+ {"name": "tape", "id": 2710, "trainId": 783},
+ {"name": "bat", "id": 152, "trainId": 784},
+ {"name": "coaster", "id": 549, "trainId": 785},
+ {"name": "calculator", "id": 360, "trainId": 786},
+ {"name": "potatoes", "id": 1982, "trainId": 787},
+ {"name": "luggage rack", "id": 1478, "trainId": 788},
+ {"name": "salt", "id": 2203, "trainId": 789},
+ {"name": "street number", "id": 2612, "trainId": 790},
+ {"name": "viewpoint", "id": 2956, "trainId": 791},
+ {"name": "sword", "id": 2681, "trainId": 792},
+ {"name": "cd", "id": 437, "trainId": 793},
+ {"name": "rowing machine", "id": 2171, "trainId": 794},
+ {"name": "plug", "id": 1933, "trainId": 795},
+ {"name": "andiron, firedog, dog, dog-iron", "id": 3110, "trainId": 796},
+ {"name": "pepper", "id": 1824, "trainId": 797},
+ {"name": "tongs", "id": 2803, "trainId": 798},
+ {"name": "bonfire", "id": 234, "trainId": 799},
+ {"name": "dog dish", "id": 764, "trainId": 800},
+ {"name": "belt", "id": 177, "trainId": 801},
+ {"name": "dumbbells", "id": 817, "trainId": 802},
+ {"name": "videocassette recorder, vcr", "id": 3145, "trainId": 803},
+ {"name": "hook", "id": 1262, "trainId": 804},
+ {"name": "envelopes", "id": 864, "trainId": 805},
+ {"name": "shower faucet", "id": 2359, "trainId": 806},
+ {"name": "watch", "id": 2992, "trainId": 807},
+ {"name": "padlock", "id": 1725, "trainId": 808},
+ {"name": "swimming pool ladder", "id": 2667, "trainId": 809},
+ {"name": "spanners", "id": 2484, "trainId": 810},
+ {"name": "gravy boat", "id": 1133, "trainId": 811},
+ {"name": "notice board", "id": 1667, "trainId": 812},
+ {"name": "trash bags", "id": 2847, "trainId": 813},
+ {"name": "fire alarm", "id": 932, "trainId": 814},
+ {"name": "ladle", "id": 1392, "trainId": 815},
+ {"name": "stethoscope", "id": 2573, "trainId": 816},
+ {"name": "rocket", "id": 2140, "trainId": 817},
+ {"name": "funnel", "id": 1046, "trainId": 818},
+ {"name": "bowling pins", "id": 264, "trainId": 819},
+ {"name": "valve", "id": 2927, "trainId": 820},
+ {"name": "thermometer", "id": 2752, "trainId": 821},
+ {"name": "cups", "id": 679, "trainId": 822},
+ {"name": "spice jar", "id": 2493, "trainId": 823},
+ {"name": "night light", "id": 1658, "trainId": 824},
+ {"name": "soaps", "id": 2466, "trainId": 825},
+ {"name": "games table", "id": 1057, "trainId": 826},
+ {"name": "slotted spoon", "id": 2444, "trainId": 827},
+ {"name": "reel", "id": 2093, "trainId": 828},
+ {"name": "scourer", "id": 2248, "trainId": 829},
+ {"name": "sleeping robe", "id": 2432, "trainId": 830},
+ {"name": "desk mat", "id": 726, "trainId": 831},
+ {"name": "dumbbell", "id": 816, "trainId": 832},
+ {"name": "hammer", "id": 1171, "trainId": 833},
+ {"name": "tie", "id": 2766, "trainId": 834},
+ {"name": "typewriter", "id": 2900, "trainId": 835},
+ {"name": "shaker", "id": 2313, "trainId": 836},
+ {"name": "cheese dish", "id": 488, "trainId": 837},
+ {"name": "sea star", "id": 2265, "trainId": 838},
+ {"name": "racquet", "id": 2043, "trainId": 839},
+ {"name": "butane gas cylinder", "id": 332, "trainId": 840},
+ {"name": "paper weight", "id": 1771, "trainId": 841},
+ {"name": "shaving brush", "id": 2320, "trainId": 842},
+ {"name": "sunglasses", "id": 2646, "trainId": 843},
+ {"name": "gear shift", "id": 1089, "trainId": 844},
+ {"name": "towel rail", "id": 2826, "trainId": 845},
+ {"name": "adding machine, totalizer, totaliser", "id": 3148, "trainId": 846},
+]
+
+
+def loadAde20K(file):
+ fileseg = file.replace(".jpg", "_seg.png")
+ with Image.open(fileseg) as io:
+ seg = np.array(io)
+
+ R = seg[:, :, 0]
+ G = seg[:, :, 1]
+ ObjectClassMasks = (R / 10).astype(np.int32) * 256 + (G.astype(np.int32))
+
+ return {"img_name": file, "segm_name": fileseg, "class_mask": ObjectClassMasks}
+
+
+if __name__ == "__main__":
+ dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets"))
+ index_file = dataset_dir / "ADE20K_2021_17_01" / "index_ade20k.pkl"
+ print('Caution: we only generate the validation set!')
+ with open(index_file, "rb") as f:
+ index_ade20k = pkl.load(f)
+
+ id_map = {}
+ for cat in ADE20K_SEM_SEG_FULL_CATEGORIES:
+ id_map[cat["id"]] = cat["trainId"]
+
+ # make output dir
+ for name in ["training", "validation"]:
+ image_dir = dataset_dir / "ADE20K_2021_17_01" / "images_detectron2" / name
+ image_dir.mkdir(parents=True, exist_ok=True)
+ annotation_dir = dataset_dir / "ADE20K_2021_17_01" / "annotations_detectron2" / name
+ annotation_dir.mkdir(parents=True, exist_ok=True)
+
+ # process image and gt
+ for i, (folder_name, file_name) in tqdm.tqdm(
+ enumerate(zip(index_ade20k["folder"], index_ade20k["filename"])),
+ total=len(index_ade20k["filename"]),
+ ):
+ split = "validation" if file_name.split("_")[1] == "val" else "training"
+ if split == 'training':
+ # FIXME: If you want to generate training set, delete this condition
+ continue
+ info = loadAde20K(str(dataset_dir / folder_name / file_name))
+
+ # resize image and label
+ img = np.asarray(Image.open(info["img_name"]))
+ lab = np.asarray(info["class_mask"])
+
+ h, w = img.shape[0], img.shape[1]
+ max_size = 512
+ resize = True
+ if w >= h > max_size:
+ h_new, w_new = max_size, round(w / float(h) * max_size)
+ elif h >= w > max_size:
+ h_new, w_new = round(h / float(w) * max_size), max_size
+ else:
+ resize = False
+
+ if resize:
+ img = cv2.resize(img, (w_new, h_new), interpolation=cv2.INTER_LINEAR)
+ lab = cv2.resize(lab, (w_new, h_new), interpolation=cv2.INTER_NEAREST)
+
+ assert img.dtype == np.uint8
+ assert lab.dtype == np.int32
+
+ # apply label conversion and save into uint16 images
+ output = np.zeros_like(lab, dtype=np.uint16) + 65535
+ for obj_id in np.unique(lab):
+ if obj_id in id_map:
+ output[lab == obj_id] = id_map[obj_id]
+
+ output_img = dataset_dir / "ADE20K_2021_17_01" / "images_detectron2" / split / file_name
+ output_lab = (
+ dataset_dir
+ / "ADE20K_2021_17_01"
+ / "annotations_detectron2"
+ / split
+ / file_name.replace(".jpg", ".tif")
+ )
+ Image.fromarray(img).save(output_img)
+
+ assert output.dtype == np.uint16
+ Image.fromarray(output).save(output_lab)
\ No newline at end of file
diff --git a/datasets/prepare_coco_stuff.py b/datasets/prepare_coco_stuff.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba1a3671ff982c409b12d3e7d21f438a11f530e9
--- /dev/null
+++ b/datasets/prepare_coco_stuff.py
@@ -0,0 +1,205 @@
+import os
+import os.path as osp
+from pathlib import Path
+import tqdm
+from glob import glob
+
+import numpy as np
+from PIL import Image
+
+COCO_CATEGORIES = [{'color': [220, 20, 60], 'isthing': 1, 'id': 0, 'name': 'person', 'trainId': 0},
+ {'color': [119, 11, 32], 'isthing': 1, 'id': 1, 'name': 'bicycle', 'trainId': 1},
+ {'color': [0, 0, 142], 'isthing': 1, 'id': 2, 'name': 'car', 'trainId': 2},
+ {'color': [0, 0, 230], 'isthing': 1, 'id': 3, 'name': 'motorcycle', 'trainId': 3},
+ {'color': [106, 0, 228], 'isthing': 1, 'id': 4, 'name': 'airplane', 'trainId': 4},
+ {'color': [0, 60, 100], 'isthing': 1, 'id': 5, 'name': 'bus', 'trainId': 5},
+ {'color': [0, 80, 100], 'isthing': 1, 'id': 6, 'name': 'train', 'trainId': 6},
+ {'color': [0, 0, 70], 'isthing': 1, 'id': 7, 'name': 'truck', 'trainId': 7},
+ {'color': [0, 0, 192], 'isthing': 1, 'id': 8, 'name': 'boat', 'trainId': 8},
+ {'color': [250, 170, 30], 'isthing': 1, 'id': 9, 'name': 'traffic light', 'trainId': 9},
+ {'color': [100, 170, 30], 'isthing': 1, 'id': 10, 'name': 'fire hydrant', 'trainId': 10},
+ {'color': [220, 220, 0], 'isthing': 1, 'id': 12, 'name': 'stop sign', 'trainId': 11},
+ {'color': [175, 116, 175], 'isthing': 1, 'id': 13, 'name': 'parking meter', 'trainId': 12},
+ {'color': [250, 0, 30], 'isthing': 1, 'id': 14, 'name': 'bench', 'trainId': 13},
+ {'color': [165, 42, 42], 'isthing': 1, 'id': 15, 'name': 'bird', 'trainId': 14},
+ {'color': [255, 77, 255], 'isthing': 1, 'id': 16, 'name': 'cat', 'trainId': 15},
+ {'color': [0, 226, 252], 'isthing': 1, 'id': 17, 'name': 'dog', 'trainId': 16},
+ {'color': [182, 182, 255], 'isthing': 1, 'id': 18, 'name': 'horse', 'trainId': 17},
+ {'color': [0, 82, 0], 'isthing': 1, 'id': 19, 'name': 'sheep', 'trainId': 18},
+ {'color': [120, 166, 157], 'isthing': 1, 'id': 20, 'name': 'cow', 'trainId': 19},
+ {'color': [110, 76, 0], 'isthing': 1, 'id': 21, 'name': 'elephant', 'trainId': 20},
+ {'color': [174, 57, 255], 'isthing': 1, 'id': 22, 'name': 'bear', 'trainId': 21},
+ {'color': [199, 100, 0], 'isthing': 1, 'id': 23, 'name': 'zebra', 'trainId': 22},
+ {'color': [72, 0, 118], 'isthing': 1, 'id': 24, 'name': 'giraffe', 'trainId': 23},
+ {'color': [255, 179, 240], 'isthing': 1, 'id': 26, 'name': 'backpack', 'trainId': 24},
+ {'color': [0, 125, 92], 'isthing': 1, 'id': 27, 'name': 'umbrella', 'trainId': 25},
+ {'color': [209, 0, 151], 'isthing': 1, 'id': 30, 'name': 'handbag', 'trainId': 26},
+ {'color': [188, 208, 182], 'isthing': 1, 'id': 31, 'name': 'tie', 'trainId': 27},
+ {'color': [0, 220, 176], 'isthing': 1, 'id': 32, 'name': 'suitcase', 'trainId': 28},
+ {'color': [255, 99, 164], 'isthing': 1, 'id': 33, 'name': 'frisbee', 'trainId': 29},
+ {'color': [92, 0, 73], 'isthing': 1, 'id': 34, 'name': 'skis', 'trainId': 30},
+ {'color': [133, 129, 255], 'isthing': 1, 'id': 35, 'name': 'snowboard', 'trainId': 31},
+ {'color': [78, 180, 255], 'isthing': 1, 'id': 36, 'name': 'sports ball', 'trainId': 32},
+ {'color': [0, 228, 0], 'isthing': 1, 'id': 37, 'name': 'kite', 'trainId': 33},
+ {'color': [174, 255, 243], 'isthing': 1, 'id': 38, 'name': 'baseball bat', 'trainId': 34},
+ {'color': [45, 89, 255], 'isthing': 1, 'id': 39, 'name': 'baseball glove', 'trainId': 35},
+ {'color': [134, 134, 103], 'isthing': 1, 'id': 40, 'name': 'skateboard', 'trainId': 36},
+ {'color': [145, 148, 174], 'isthing': 1, 'id': 41, 'name': 'surfboard', 'trainId': 37},
+ {'color': [255, 208, 186], 'isthing': 1, 'id': 42, 'name': 'tennis racket', 'trainId': 38},
+ {'color': [197, 226, 255], 'isthing': 1, 'id': 43, 'name': 'bottle', 'trainId': 39},
+ {'color': [171, 134, 1], 'isthing': 1, 'id': 45, 'name': 'wine glass', 'trainId': 40},
+ {'color': [109, 63, 54], 'isthing': 1, 'id': 46, 'name': 'cup', 'trainId': 41},
+ {'color': [207, 138, 255], 'isthing': 1, 'id': 47, 'name': 'fork', 'trainId': 42},
+ {'color': [151, 0, 95], 'isthing': 1, 'id': 48, 'name': 'knife', 'trainId': 43},
+ {'color': [9, 80, 61], 'isthing': 1, 'id': 49, 'name': 'spoon', 'trainId': 44},
+ {'color': [84, 105, 51], 'isthing': 1, 'id': 50, 'name': 'bowl', 'trainId': 45},
+ {'color': [74, 65, 105], 'isthing': 1, 'id': 51, 'name': 'banana', 'trainId': 46},
+ {'color': [166, 196, 102], 'isthing': 1, 'id': 52, 'name': 'apple', 'trainId': 47},
+ {'color': [208, 195, 210], 'isthing': 1, 'id': 53, 'name': 'sandwich', 'trainId': 48},
+ {'color': [255, 109, 65], 'isthing': 1, 'id': 54, 'name': 'orange', 'trainId': 49},
+ {'color': [0, 143, 149], 'isthing': 1, 'id': 55, 'name': 'broccoli', 'trainId': 50},
+ {'color': [179, 0, 194], 'isthing': 1, 'id': 56, 'name': 'carrot', 'trainId': 51},
+ {'color': [209, 99, 106], 'isthing': 1, 'id': 57, 'name': 'hot dog', 'trainId': 52},
+ {'color': [5, 121, 0], 'isthing': 1, 'id': 58, 'name': 'pizza', 'trainId': 53},
+ {'color': [227, 255, 205], 'isthing': 1, 'id': 59, 'name': 'donut', 'trainId': 54},
+ {'color': [147, 186, 208], 'isthing': 1, 'id': 60, 'name': 'cake', 'trainId': 55},
+ {'color': [153, 69, 1], 'isthing': 1, 'id': 61, 'name': 'chair', 'trainId': 56},
+ {'color': [3, 95, 161], 'isthing': 1, 'id': 62, 'name': 'couch', 'trainId': 57},
+ {'color': [163, 255, 0], 'isthing': 1, 'id': 63, 'name': 'potted plant', 'trainId': 58},
+ {'color': [119, 0, 170], 'isthing': 1, 'id': 64, 'name': 'bed', 'trainId': 59},
+ {'color': [0, 182, 199], 'isthing': 1, 'id': 66, 'name': 'dining table', 'trainId': 60},
+ {'color': [0, 165, 120], 'isthing': 1, 'id': 69, 'name': 'toilet', 'trainId': 61},
+ {'color': [183, 130, 88], 'isthing': 1, 'id': 71, 'name': 'tv', 'trainId': 62},
+ {'color': [95, 32, 0], 'isthing': 1, 'id': 72, 'name': 'laptop', 'trainId': 63},
+ {'color': [130, 114, 135], 'isthing': 1, 'id': 73, 'name': 'mouse', 'trainId': 64},
+ {'color': [110, 129, 133], 'isthing': 1, 'id': 74, 'name': 'remote', 'trainId': 65},
+ {'color': [166, 74, 118], 'isthing': 1, 'id': 75, 'name': 'keyboard', 'trainId': 66},
+ {'color': [219, 142, 185], 'isthing': 1, 'id': 76, 'name': 'cell phone', 'trainId': 67},
+ {'color': [79, 210, 114], 'isthing': 1, 'id': 77, 'name': 'microwave', 'trainId': 68},
+ {'color': [178, 90, 62], 'isthing': 1, 'id': 78, 'name': 'oven', 'trainId': 69},
+ {'color': [65, 70, 15], 'isthing': 1, 'id': 79, 'name': 'toaster', 'trainId': 70},
+ {'color': [127, 167, 115], 'isthing': 1, 'id': 80, 'name': 'sink', 'trainId': 71},
+ {'color': [59, 105, 106], 'isthing': 1, 'id': 81, 'name': 'refrigerator', 'trainId': 72},
+ {'color': [142, 108, 45], 'isthing': 1, 'id': 83, 'name': 'book', 'trainId': 73},
+ {'color': [196, 172, 0], 'isthing': 1, 'id': 84, 'name': 'clock', 'trainId': 74},
+ {'color': [95, 54, 80], 'isthing': 1, 'id': 85, 'name': 'vase', 'trainId': 75},
+ {'color': [128, 76, 255], 'isthing': 1, 'id': 86, 'name': 'scissors', 'trainId': 76},
+ {'color': [201, 57, 1], 'isthing': 1, 'id': 87, 'name': 'teddy bear', 'trainId': 77},
+ {'color': [246, 0, 122], 'isthing': 1, 'id': 88, 'name': 'hair drier', 'trainId': 78},
+ {'color': [191, 162, 208], 'isthing': 1, 'id': 89, 'name': 'toothbrush', 'trainId': 79},
+ {'id': 91, 'name': 'banner', 'supercategory': 'textile', 'trainId': 80},
+ {'id': 92, 'name': 'blanket', 'supercategory': 'textile', 'trainId': 81},
+ {'id': 93, 'name': 'branch', 'supercategory': 'plant', 'trainId': 82},
+ {'id': 94, 'name': 'bridge', 'supercategory': 'building', 'trainId': 83},
+ {'id': 95, 'name': 'building-other', 'supercategory': 'building', 'trainId': 84},
+ {'id': 96, 'name': 'bush', 'supercategory': 'plant', 'trainId': 85},
+ {'id': 97, 'name': 'cabinet', 'supercategory': 'furniture-stuff', 'trainId': 86},
+ {'id': 98, 'name': 'cage', 'supercategory': 'structural', 'trainId': 87},
+ {'id': 99, 'name': 'cardboard', 'supercategory': 'raw-material', 'trainId': 88},
+ {'id': 100, 'name': 'carpet', 'supercategory': 'floor', 'trainId': 89},
+ {'id': 101, 'name': 'ceiling-other', 'supercategory': 'ceiling', 'trainId': 90},
+ {'id': 102, 'name': 'ceiling-tile', 'supercategory': 'ceiling', 'trainId': 91},
+ {'id': 103, 'name': 'cloth', 'supercategory': 'textile', 'trainId': 92},
+ {'id': 104, 'name': 'clothes', 'supercategory': 'textile', 'trainId': 93},
+ {'id': 105, 'name': 'clouds', 'supercategory': 'sky', 'trainId': 94},
+ {'id': 106, 'name': 'counter', 'supercategory': 'furniture-stuff', 'trainId': 95},
+ {'id': 107, 'name': 'cupboard', 'supercategory': 'furniture-stuff', 'trainId': 96},
+ {'id': 108, 'name': 'curtain', 'supercategory': 'textile', 'trainId': 97},
+ {'id': 109, 'name': 'desk-stuff', 'supercategory': 'furniture-stuff', 'trainId': 98},
+ {'id': 110, 'name': 'dirt', 'supercategory': 'ground', 'trainId': 99},
+ {'id': 111, 'name': 'door-stuff', 'supercategory': 'furniture-stuff', 'trainId': 100},
+ {'id': 112, 'name': 'fence', 'supercategory': 'structural', 'trainId': 101},
+ {'id': 113, 'name': 'floor-marble', 'supercategory': 'floor', 'trainId': 102},
+ {'id': 114, 'name': 'floor-other', 'supercategory': 'floor', 'trainId': 103},
+ {'id': 115, 'name': 'floor-stone', 'supercategory': 'floor', 'trainId': 104},
+ {'id': 116, 'name': 'floor-tile', 'supercategory': 'floor', 'trainId': 105},
+ {'id': 117, 'name': 'floor-wood', 'supercategory': 'floor', 'trainId': 106},
+ {'id': 118, 'name': 'flower', 'supercategory': 'plant', 'trainId': 107},
+ {'id': 119, 'name': 'fog', 'supercategory': 'water', 'trainId': 108},
+ {'id': 120, 'name': 'food-other', 'supercategory': 'food-stuff', 'trainId': 109},
+ {'id': 121, 'name': 'fruit', 'supercategory': 'food-stuff', 'trainId': 110},
+ {'id': 122, 'name': 'furniture-other', 'supercategory': 'furniture-stuff', 'trainId': 111},
+ {'id': 123, 'name': 'grass', 'supercategory': 'plant', 'trainId': 112},
+ {'id': 124, 'name': 'gravel', 'supercategory': 'ground', 'trainId': 113},
+ {'id': 125, 'name': 'ground-other', 'supercategory': 'ground', 'trainId': 114},
+ {'id': 126, 'name': 'hill', 'supercategory': 'solid', 'trainId': 115},
+ {'id': 127, 'name': 'house', 'supercategory': 'building', 'trainId': 116},
+ {'id': 128, 'name': 'leaves', 'supercategory': 'plant', 'trainId': 117},
+ {'id': 129, 'name': 'light', 'supercategory': 'furniture-stuff', 'trainId': 118},
+ {'id': 130, 'name': 'mat', 'supercategory': 'textile', 'trainId': 119},
+ {'id': 131, 'name': 'metal', 'supercategory': 'raw-material', 'trainId': 120},
+ {'id': 132, 'name': 'mirror-stuff', 'supercategory': 'furniture-stuff', 'trainId': 121},
+ {'id': 133, 'name': 'moss', 'supercategory': 'plant', 'trainId': 122},
+ {'id': 134, 'name': 'mountain', 'supercategory': 'solid', 'trainId': 123},
+ {'id': 135, 'name': 'mud', 'supercategory': 'ground', 'trainId': 124},
+ {'id': 136, 'name': 'napkin', 'supercategory': 'textile', 'trainId': 125},
+ {'id': 137, 'name': 'net', 'supercategory': 'structural', 'trainId': 126},
+ {'id': 138, 'name': 'paper', 'supercategory': 'raw-material', 'trainId': 127},
+ {'id': 139, 'name': 'pavement', 'supercategory': 'ground', 'trainId': 128},
+ {'id': 140, 'name': 'pillow', 'supercategory': 'textile', 'trainId': 129},
+ {'id': 141, 'name': 'plant-other', 'supercategory': 'plant', 'trainId': 130},
+ {'id': 142, 'name': 'plastic', 'supercategory': 'raw-material', 'trainId': 131},
+ {'id': 143, 'name': 'platform', 'supercategory': 'ground', 'trainId': 132},
+ {'id': 144, 'name': 'playingfield', 'supercategory': 'ground', 'trainId': 133},
+ {'id': 145, 'name': 'railing', 'supercategory': 'structural', 'trainId': 134},
+ {'id': 146, 'name': 'railroad', 'supercategory': 'ground', 'trainId': 135},
+ {'id': 147, 'name': 'river', 'supercategory': 'water', 'trainId': 136},
+ {'id': 148, 'name': 'road', 'supercategory': 'ground', 'trainId': 137},
+ {'id': 149, 'name': 'rock', 'supercategory': 'solid', 'trainId': 138},
+ {'id': 150, 'name': 'roof', 'supercategory': 'building', 'trainId': 139},
+ {'id': 151, 'name': 'rug', 'supercategory': 'textile', 'trainId': 140},
+ {'id': 152, 'name': 'salad', 'supercategory': 'food-stuff', 'trainId': 141},
+ {'id': 153, 'name': 'sand', 'supercategory': 'ground', 'trainId': 142},
+ {'id': 154, 'name': 'sea', 'supercategory': 'water', 'trainId': 143},
+ {'id': 155, 'name': 'shelf', 'supercategory': 'furniture-stuff', 'trainId': 144},
+ {'id': 156, 'name': 'sky-other', 'supercategory': 'sky', 'trainId': 145},
+ {'id': 157, 'name': 'skyscraper', 'supercategory': 'building', 'trainId': 146},
+ {'id': 158, 'name': 'snow', 'supercategory': 'ground', 'trainId': 147},
+ {'id': 159, 'name': 'solid-other', 'supercategory': 'solid', 'trainId': 148},
+ {'id': 160, 'name': 'stairs', 'supercategory': 'furniture-stuff', 'trainId': 149},
+ {'id': 161, 'name': 'stone', 'supercategory': 'solid', 'trainId': 150},
+ {'id': 162, 'name': 'straw', 'supercategory': 'plant', 'trainId': 151},
+ {'id': 163, 'name': 'structural-other', 'supercategory': 'structural', 'trainId': 152},
+ {'id': 164, 'name': 'table', 'supercategory': 'furniture-stuff', 'trainId': 153},
+ {'id': 165, 'name': 'tent', 'supercategory': 'building', 'trainId': 154},
+ {'id': 166, 'name': 'textile-other', 'supercategory': 'textile', 'trainId': 155},
+ {'id': 167, 'name': 'towel', 'supercategory': 'textile', 'trainId': 156},
+ {'id': 168, 'name': 'tree', 'supercategory': 'plant', 'trainId': 157},
+ {'id': 169, 'name': 'vegetable', 'supercategory': 'food-stuff', 'trainId': 158},
+ {'id': 170, 'name': 'wall-brick', 'supercategory': 'wall', 'trainId': 159},
+ {'id': 171, 'name': 'wall-concrete', 'supercategory': 'wall', 'trainId': 160},
+ {'id': 172, 'name': 'wall-other', 'supercategory': 'wall', 'trainId': 161},
+ {'id': 173, 'name': 'wall-panel', 'supercategory': 'wall', 'trainId': 162},
+ {'id': 174, 'name': 'wall-stone', 'supercategory': 'wall', 'trainId': 163},
+ {'id': 175, 'name': 'wall-tile', 'supercategory': 'wall', 'trainId': 164},
+ {'id': 176, 'name': 'wall-wood', 'supercategory': 'wall', 'trainId': 165},
+ {'id': 177, 'name': 'water-other', 'supercategory': 'water', 'trainId': 166},
+ {'id': 178, 'name': 'waterdrops', 'supercategory': 'water', 'trainId': 167},
+ {'id': 179, 'name': 'window-blind', 'supercategory': 'window', 'trainId': 168},
+ {'id': 180, 'name': 'window-other', 'supercategory': 'window', 'trainId': 169},
+ {'id': 181, 'name': 'wood', 'supercategory': 'solid', 'trainId': 170}]
+
+
+if __name__ == "__main__":
+ dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "coco-stuff"
+
+ id_map = {}
+ for cat in COCO_CATEGORIES:
+ id_map[cat["id"]] = cat["trainId"]
+
+ for name in ["train2017", "val2017"]:
+ annotation_dir = dataset_dir / "annotations" / name
+ output_dir = dataset_dir / "annotations_detectron2" / name
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ for file in tqdm.tqdm(list(annotation_dir.iterdir())):
+ output_file = output_dir / file.name
+ lab = np.asarray(Image.open(file))
+ assert lab.dtype == np.uint8
+
+ output = np.zeros_like(lab, dtype=np.uint8) + 255
+ for obj_id in np.unique(lab):
+ if obj_id in id_map:
+ output[lab == obj_id] = id_map[obj_id]
+
+ Image.fromarray(output).save(output_file)
\ No newline at end of file
diff --git a/datasets/prepare_pascal_context.py b/datasets/prepare_pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3042c624187f1188e5835dc06bdc365e0d90755
--- /dev/null
+++ b/datasets/prepare_pascal_context.py
@@ -0,0 +1,69 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import tqdm
+import os
+import os.path as osp
+from pathlib import Path
+
+import numpy as np
+from PIL import Image
+import scipy.io
+
+def convert_pc59(mask_path, new_mask_path, pc59_dict):
+ mat = scipy.io.loadmat(mask_path)
+ mask = mat['LabelMap']
+
+ mask_copy = np.ones_like(mask, dtype=np.uint8) * 255
+ for trID, clsID in pc59_dict.items():
+ mask_copy[mask == clsID] = trID
+
+ min_value = np.amin(mask_copy)
+ assert min_value >= 0, print(min_value)
+ Image.fromarray(mask_copy).save(new_mask_path, "PNG")
+
+def convert_pc459(mask_path, new_mask_path):
+ mat = scipy.io.loadmat(mask_path)
+ mask = mat['LabelMap']
+ mask = mask - 1
+ min_value = np.amin(mask)
+ assert min_value >= 0, print(min_value)
+ Image.fromarray(mask).save(new_mask_path, "TIFF")
+
+
+if __name__ == "__main__":
+ dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets"))
+ print('Caution: we only generate the validation set!')
+ pc_path = dataset_dir / "VOCdevkit/VOC2010"
+
+ val_list = open(pc_path / "pascalcontext_val.txt", "r")
+ pc459_labels = open(pc_path / "labels.txt", "r")
+ pc59_labels = open(pc_path / "59_labels.txt", "r")
+
+ pc459_dict = {}
+ for line in pc459_labels.readlines():
+ if ':' in line:
+ idx, name = line.split(':')
+ idx = int(idx.strip())
+ name = name.strip()
+ pc459_dict[name] = idx
+
+ pc59_dict = {}
+ for i, line in enumerate(pc59_labels.readlines()):
+ name = line.split(':')[-1].strip()
+ if name is not '':
+ pc59_dict[i] = pc459_dict[name]
+
+ pc459_dir = pc_path / "annotations_detectron2" / "pc459_val"
+ pc459_dir.mkdir(parents=True, exist_ok=True)
+ pc59_dir = pc_path / "annotations_detectron2" / "pc59_val"
+ pc59_dir.mkdir(parents=True, exist_ok=True)
+
+ for line in tqdm.tqdm(val_list.readlines()):
+ fileid = line.strip()
+ ori_mask = f'{pc_path}/trainval/{fileid}.mat'
+ pc459_dst = f'{pc459_dir}/{fileid}.tif'
+ pc59_dst = f'{pc59_dir}/{fileid}.png'
+ if osp.exists(ori_mask):
+ convert_pc459(ori_mask, pc459_dst)
+ convert_pc59(ori_mask, pc59_dst, pc59_dict)
\ No newline at end of file
diff --git a/datasets/prepare_voc.py b/datasets/prepare_voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab2ca43ada301d72ec09df61c82bf30d2f20036
--- /dev/null
+++ b/datasets/prepare_voc.py
@@ -0,0 +1,77 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+# Modified by Feng Liang from https://github.com/MendelXu/zsseg.baseline/blob/master/datasets/prepare_voc_sem_seg.py
+# Modified by Heeseong Shin from https://github.com/facebookresearch/ov-seg/blob/main/datasets/prepare_voc_sem_seg.py
+
+import os
+import os.path as osp
+from pathlib import Path
+import tqdm
+
+import numpy as np
+from PIL import Image
+
+
+clsID_to_trID = {
+ 0: 255,
+ 1: 0,
+ 2: 1,
+ 3: 2,
+ 4: 3,
+ 5: 4,
+ 6: 5,
+ 7: 6,
+ 8: 7,
+ 9: 8,
+ 10: 9,
+ 11: 10,
+ 12: 11,
+ 13: 12,
+ 14: 13,
+ 15: 14,
+ 16: 15,
+ 17: 16,
+ 18: 17,
+ 19: 18,
+ 20: 19,
+ 255: 255,
+}
+clsID_to_trID_bg = clsID_to_trID.copy()
+clsID_to_trID_bg[0] = 20
+
+def convert_to_trainID(
+ maskpath, out_mask_dir, is_train, clsID_to_trID=clsID_to_trID, suffix=""
+):
+ mask = np.array(Image.open(maskpath))
+ mask_copy = np.ones_like(mask, dtype=np.uint8) * 255
+ for clsID, trID in clsID_to_trID.items():
+ mask_copy[mask == clsID] = trID
+ seg_filename = (
+ osp.join(out_mask_dir, "train" + suffix, osp.basename(maskpath))
+ if is_train
+ else osp.join(out_mask_dir, "val" + suffix, osp.basename(maskpath))
+ )
+ if len(np.unique(mask_copy)) == 1 and np.unique(mask_copy)[0] == 255:
+ return
+ Image.fromarray(mask_copy).save(seg_filename, "PNG")
+
+
+
+if __name__ == "__main__":
+ dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets"))
+ print('Caution: we only generate the validation set!')
+ voc_path = dataset_dir / "VOCdevkit" / "VOC2012"
+ out_mask_dir = voc_path / "annotations_detectron2"
+ out_mask_dir_bg = voc_path / "annotations_detectron2_bg"
+ #out_image_dir = voc_path / "images_detectron2"
+ for name in ["val"]:
+ os.makedirs((out_mask_dir / name), exist_ok=True)
+ os.makedirs((out_mask_dir_bg / name), exist_ok=True)
+ #os.makedirs((out_image_dir / name), exist_ok=True)
+ val_list = [
+ osp.join(voc_path, "SegmentationClassAug", f + ".png")
+ for f in np.loadtxt(osp.join(voc_path, "ImageSets/Segmentation/val.txt"), dtype=np.str).tolist()
+ ]
+ for file in tqdm.tqdm(val_list):
+ convert_to_trainID(file, out_mask_dir, is_train=False)
+ convert_to_trainID(file, out_mask_dir_bg, is_train=False, clsID_to_trID=clsID_to_trID_bg)
\ No newline at end of file
diff --git a/datasets/voc20.json b/datasets/voc20.json
new file mode 100755
index 0000000000000000000000000000000000000000..b38e07ba1bd78eab8d3589a5f396c632fcfc13ac
--- /dev/null
+++ b/datasets/voc20.json
@@ -0,0 +1,2 @@
+["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
+
diff --git a/datasets/voc20b.json b/datasets/voc20b.json
new file mode 100755
index 0000000000000000000000000000000000000000..37f95f77b88dd8c94fca5702c71351b7704ea10e
--- /dev/null
+++ b/datasets/voc20b.json
@@ -0,0 +1 @@
+["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor", "bag", "bed", "bench", "book", "building", "cabinet", "ceiling", "cloth", "computer", "cup", "door", "fence", "floor", "flower", "food", "grass", "ground", "keyboard", "light", "mountain", "mouse", "curtain", "platform", "sign", "plate", "road", "rock", "shelves", "sidewalk", "sky", "snow", "bedclothes", "track", "tree", "truck", "wall", "water", "window", "wood"]
diff --git a/demo/__pycache__/predictor.cpython-38.pyc b/demo/__pycache__/predictor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a70fa383473ff97b01cc4becf3f525a516352a31
Binary files /dev/null and b/demo/__pycache__/predictor.cpython-38.pyc differ
diff --git a/demo/demo.py b/demo/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..2105b06ba1aa7fcafac51965035df5905ec974d7
--- /dev/null
+++ b/demo/demo.py
@@ -0,0 +1,194 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
+import argparse
+import glob
+import multiprocessing as mp
+import os
+
+# fmt: off
+import sys
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+# fmt: on
+
+import tempfile
+import time
+import warnings
+
+import cv2
+import numpy as np
+import tqdm
+
+from detectron2.config import get_cfg
+from detectron2.data.detection_utils import read_image
+from detectron2.projects.deeplab import add_deeplab_config
+from detectron2.utils.logger import setup_logger
+
+from mask_former import add_mask_former_config
+from predictor import VisualizationDemo
+
+
+# constants
+WINDOW_NAME = "MaskFormer demo"
+
+
+def setup_cfg(args):
+ # load config from file and command-line arguments
+ cfg = get_cfg()
+ add_deeplab_config(cfg)
+ add_mask_former_config(cfg)
+ cfg.merge_from_file(args.config_file)
+ cfg.merge_from_list(args.opts)
+ cfg.freeze()
+ return cfg
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
+ parser.add_argument(
+ "--config-file",
+ default="configs/ade20k-150/maskformer_R50_bs16_160k.yaml",
+ metavar="FILE",
+ help="path to config file",
+ )
+ parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
+ parser.add_argument("--video-input", help="Path to video file.")
+ parser.add_argument(
+ "--input",
+ nargs="+",
+ help="A list of space separated input images; "
+ "or a single glob pattern such as 'directory/*.jpg'",
+ )
+ parser.add_argument(
+ "--output",
+ help="A file or directory to save output visualizations. "
+ "If not given, will show output in an OpenCV window.",
+ )
+
+ parser.add_argument(
+ "--confidence-threshold",
+ type=float,
+ default=0.5,
+ help="Minimum score for instance predictions to be shown",
+ )
+ parser.add_argument(
+ "--opts",
+ help="Modify config options using the command-line 'KEY VALUE' pairs",
+ default=[],
+ nargs=argparse.REMAINDER,
+ )
+ return parser
+
+
+def test_opencv_video_format(codec, file_ext):
+ with tempfile.TemporaryDirectory(prefix="video_format_test") as dir:
+ filename = os.path.join(dir, "test_file" + file_ext)
+ writer = cv2.VideoWriter(
+ filename=filename,
+ fourcc=cv2.VideoWriter_fourcc(*codec),
+ fps=float(30),
+ frameSize=(10, 10),
+ isColor=True,
+ )
+ [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)]
+ writer.release()
+ if os.path.isfile(filename):
+ return True
+ return False
+
+
+if __name__ == "__main__":
+ mp.set_start_method("spawn", force=True)
+ args = get_parser().parse_args()
+ setup_logger(name="fvcore")
+ logger = setup_logger()
+ logger.info("Arguments: " + str(args))
+
+ cfg = setup_cfg(args)
+
+ demo = VisualizationDemo(cfg)
+
+ if args.input:
+ if len(args.input) == 1:
+ args.input = glob.glob(os.path.expanduser(args.input[0]))
+ assert args.input, "The input path(s) was not found"
+ for path in tqdm.tqdm(args.input, disable=not args.output):
+ # use PIL, to be consistent with evaluation
+ img = read_image(path, format="BGR")
+ start_time = time.time()
+ predictions, visualized_output = demo.run_on_image(img)
+ logger.info(
+ "{}: {} in {:.2f}s".format(
+ path,
+ "detected {} instances".format(len(predictions["instances"]))
+ if "instances" in predictions
+ else "finished",
+ time.time() - start_time,
+ )
+ )
+
+ if args.output:
+ if os.path.isdir(args.output):
+ assert os.path.isdir(args.output), args.output
+ out_filename = os.path.join(args.output, os.path.basename(path))
+ else:
+ assert len(args.input) == 1, "Please specify a directory with args.output"
+ out_filename = args.output
+ visualized_output.save(out_filename)
+ else:
+ cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
+ cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
+ if cv2.waitKey(0) == 27:
+ break # esc to quit
+ elif args.webcam:
+ assert args.input is None, "Cannot have both --input and --webcam!"
+ assert args.output is None, "output not yet supported with --webcam!"
+ cam = cv2.VideoCapture(0)
+ for vis in tqdm.tqdm(demo.run_on_video(cam)):
+ cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
+ cv2.imshow(WINDOW_NAME, vis)
+ if cv2.waitKey(1) == 27:
+ break # esc to quit
+ cam.release()
+ cv2.destroyAllWindows()
+ elif args.video_input:
+ video = cv2.VideoCapture(args.video_input)
+ width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ frames_per_second = video.get(cv2.CAP_PROP_FPS)
+ num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+ basename = os.path.basename(args.video_input)
+ codec, file_ext = (
+ ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
+ )
+ if codec == ".mp4v":
+ warnings.warn("x264 codec not available, switching to mp4v")
+ if args.output:
+ if os.path.isdir(args.output):
+ output_fname = os.path.join(args.output, basename)
+ output_fname = os.path.splitext(output_fname)[0] + file_ext
+ else:
+ output_fname = args.output
+ assert not os.path.isfile(output_fname), output_fname
+ output_file = cv2.VideoWriter(
+ filename=output_fname,
+ # some installation of opencv may not support x264 (due to its license),
+ # you can try other format (e.g. MPEG)
+ fourcc=cv2.VideoWriter_fourcc(*codec),
+ fps=float(frames_per_second),
+ frameSize=(width, height),
+ isColor=True,
+ )
+ assert os.path.isfile(args.video_input)
+ for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
+ if args.output:
+ output_file.write(vis_frame)
+ else:
+ cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
+ cv2.imshow(basename, vis_frame)
+ if cv2.waitKey(1) == 27:
+ break # esc to quit
+ video.release()
+ if args.output:
+ output_file.release()
+ else:
+ cv2.destroyAllWindows()
diff --git a/demo/demo_visual_gt.py b/demo/demo_visual_gt.py
new file mode 100644
index 0000000000000000000000000000000000000000..e828860ae5982fba62319b41c42a7311880539a9
--- /dev/null
+++ b/demo/demo_visual_gt.py
@@ -0,0 +1,210 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
+import argparse
+import glob
+import multiprocessing as mp
+import os
+
+# fmt: off
+import sys
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+# fmt: on
+
+import tempfile
+import time
+import warnings
+
+import cv2
+import numpy as np
+import tqdm
+
+from detectron2.config import get_cfg
+from detectron2.data.detection_utils import read_image
+from detectron2.projects.deeplab import add_deeplab_config
+from detectron2.utils.logger import setup_logger
+
+from mask_former import add_mask_former_config
+# from predictor import VisualizationDemo
+from visualizer import VisualizationGt
+from PIL import Image
+
+# constants
+WINDOW_NAME = "MaskFormer demo"
+
+
+def setup_cfg(args):
+ # load config from file and command-line arguments
+ cfg = get_cfg()
+ add_deeplab_config(cfg)
+ add_mask_former_config(cfg)
+ cfg.merge_from_file(args.config_file)
+ cfg.merge_from_list(args.opts)
+ cfg.freeze()
+ return cfg
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
+ parser.add_argument(
+ "--config-file",
+ default="configs/ade20k-150/maskformer_R50_bs16_160k.yaml",
+ metavar="FILE",
+ help="path to config file",
+ )
+ parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
+ parser.add_argument("--video-input", help="Path to video file.")
+ parser.add_argument(
+ "--input",
+ nargs="+",
+ help="A list of space separated input images; "
+ "or a single glob pattern such as 'directory/*.jpg'",
+ )
+ # parser.add_argument(
+ # "--gt",
+ # nargs="+",
+ # help="A list of space seperated ground truth images;"
+ # "or a single glob pattern such as 'directory/*.png'"
+ # )
+ parser.add_argument(
+ "--gt",
+ # type="str",
+ help="ground truth path of segmentation"
+ )
+ parser.add_argument(
+ "--output",
+ help="A file or directory to save output visualizations. "
+ "If not given, will show output in an OpenCV window.",
+ )
+
+ parser.add_argument(
+ "--confidence-threshold",
+ type=float,
+ default=0.5,
+ help="Minimum score for instance predictions to be shown",
+ )
+ parser.add_argument(
+ "--opts",
+ help="Modify config options using the command-line 'KEY VALUE' pairs",
+ default=[],
+ nargs=argparse.REMAINDER,
+ )
+ return parser
+
+
+def test_opencv_video_format(codec, file_ext):
+ with tempfile.TemporaryDirectory(prefix="video_format_test") as dir:
+ filename = os.path.join(dir, "test_file" + file_ext)
+ writer = cv2.VideoWriter(
+ filename=filename,
+ fourcc=cv2.VideoWriter_fourcc(*codec),
+ fps=float(30),
+ frameSize=(10, 10),
+ isColor=True,
+ )
+ [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)]
+ writer.release()
+ if os.path.isfile(filename):
+ return True
+ return False
+
+
+if __name__ == "__main__":
+ mp.set_start_method("spawn", force=True)
+ args = get_parser().parse_args()
+ setup_logger(name="fvcore")
+ logger = setup_logger()
+ logger.info("Arguments: " + str(args))
+
+ cfg = setup_cfg(args)
+
+ demo = VisualizationGt(cfg)
+ gt_path = args.gt
+ if args.input:
+ if len(args.input) == 1:
+ args.input = glob.glob(os.path.expanduser(args.input[0]))
+ assert args.input, "The input path(s) was not found"
+ for path in tqdm.tqdm(args.input, disable=not args.output):
+ # use PIL, to be consistent with evaluation
+ img = read_image(path, format="BGR")
+ start_time = time.time()
+ predictions = {}
+ gt_file = os.path.join(gt_path, os.path.splitext(os.path.basename(path))[0] + '.png')
+ # import pdb; pdb.set_trace()
+ predictions['sem_seg'] = np.asarray(Image.open(gt_file))
+ predictions, visualized_output = demo.run_on_image(img, predictions)
+ logger.info(
+ "{}: {} in {:.2f}s".format(
+ path,
+ "detected {} instances".format(len(predictions["instances"]))
+ if "instances" in predictions
+ else "finished",
+ time.time() - start_time,
+ )
+ )
+
+ if args.output:
+ if os.path.isdir(args.output):
+ assert os.path.isdir(args.output), args.output
+ out_filename = os.path.join(args.output, os.path.basename(path))
+ else:
+ assert len(args.input) == 1, "Please specify a directory with args.output"
+ out_filename = args.output
+ visualized_output.save(out_filename)
+ else:
+ cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
+ cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
+ if cv2.waitKey(0) == 27:
+ break # esc to quit
+ elif args.webcam:
+ assert args.input is None, "Cannot have both --input and --webcam!"
+ assert args.output is None, "output not yet supported with --webcam!"
+ cam = cv2.VideoCapture(0)
+ for vis in tqdm.tqdm(demo.run_on_video(cam)):
+ cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
+ cv2.imshow(WINDOW_NAME, vis)
+ if cv2.waitKey(1) == 27:
+ break # esc to quit
+ cam.release()
+ cv2.destroyAllWindows()
+ elif args.video_input:
+ video = cv2.VideoCapture(args.video_input)
+ width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ frames_per_second = video.get(cv2.CAP_PROP_FPS)
+ num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+ basename = os.path.basename(args.video_input)
+ codec, file_ext = (
+ ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
+ )
+ if codec == ".mp4v":
+ warnings.warn("x264 codec not available, switching to mp4v")
+ if args.output:
+ if os.path.isdir(args.output):
+ output_fname = os.path.join(args.output, basename)
+ output_fname = os.path.splitext(output_fname)[0] + file_ext
+ else:
+ output_fname = args.output
+ assert not os.path.isfile(output_fname), output_fname
+ output_file = cv2.VideoWriter(
+ filename=output_fname,
+ # some installation of opencv may not support x264 (due to its license),
+ # you can try other format (e.g. MPEG)
+ fourcc=cv2.VideoWriter_fourcc(*codec),
+ fps=float(frames_per_second),
+ frameSize=(width, height),
+ isColor=True,
+ )
+ assert os.path.isfile(args.video_input)
+ for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
+ if args.output:
+ output_file.write(vis_frame)
+ else:
+ cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
+ cv2.imshow(basename, vis_frame)
+ if cv2.waitKey(1) == 27:
+ break # esc to quit
+ video.release()
+ if args.output:
+ output_file.release()
+ else:
+ cv2.destroyAllWindows()
diff --git a/demo/predictor.py b/demo/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd9e2eb5e12e195fbbe838a8d169f8dbbeffca16
--- /dev/null
+++ b/demo/predictor.py
@@ -0,0 +1,261 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py
+import atexit
+import bisect
+import multiprocessing as mp
+from collections import deque
+
+import cv2
+import torch
+
+from detectron2.data import MetadataCatalog
+from detectron2.engine.defaults import DefaultPredictor
+from detectron2.utils.video_visualizer import VideoVisualizer
+from detectron2.utils.visualizer import ColorMode, Visualizer
+
+from cat_seg.third_party import imagenet_templates
+from types import SimpleNamespace as ns
+
+class VisualizationDemo(object):
+ def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False, text=None):
+ """
+ Args:
+ cfg (CfgNode):
+ instance_mode (ColorMode):
+ parallel (bool): whether to run the model in different processes from visualization.
+ Useful since the visualization logic can be slow.
+ """
+ self.metadata = MetadataCatalog.get(
+ cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
+ )
+ self.cpu_device = torch.device("cpu")
+ self.instance_mode = instance_mode
+
+ self.parallel = parallel
+ if parallel:
+ num_gpu = torch.cuda.device_count()
+ self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
+ else:
+ self.predictor = DefaultPredictor(cfg)
+
+ # set classes
+ templates = ['A photo of a {} in the scene',]
+ #templates = imagenet_templates.IMAGENET_TEMPLATES
+ if text is not None:
+ pred = self.predictor.model.sem_seg_head.predictor
+ pred.test_class_texts = [t.strip() for t in text.split(',')]
+ pred.text_features_test = pred.class_embeddings(pred.test_class_texts,
+ templates,
+ pred.clip_model).permute(1, 0, 2).float()
+ if len(templates) == 1:
+ pred.text_features_test = pred.text_features_test.repeat(1, 80, 1)
+ self.metadata = ns()
+ self.metadata.stuff_classes = pred.test_class_texts
+
+ self.filter_background = False
+
+ def run_on_image(self, image, text=None, use_sam=False):
+ """
+ Args:
+ image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+ This is the format used by OpenCV.
+ Returns:
+ predictions (dict): the output of the model.
+ vis_output (VisImage): the visualized image output.
+ """
+ vis_output = None
+
+ if text is not None:
+ pred = self.predictor.model.sem_seg_head.predictor
+ pred.test_class_texts = text.split(',')
+ pred.text_features_test = pred.class_embeddings(pred.test_class_texts,
+ #imagenet_templates.IMAGENET_TEMPLATES,
+ ['A photo of a {} in the scene',],
+ pred.clip_model).permute(1, 0, 2).float().repeat(1, 80, 1)
+ self.metadata = ns()
+ self.metadata.stuff_classes = pred.test_class_texts
+ self.metadata.thing_classes = pred.test_class_texts
+
+ self.predictor.model.use_sam = use_sam
+
+ predictions = self.predictor(image)
+ # Convert image from OpenCV BGR format to Matplotlib RGB format.
+ image = image[:, :, ::-1]
+ visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
+ #import pdb; pdb.set_trace()
+ if "panoptic_seg" in predictions:
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
+ vis_output = visualizer.draw_panoptic_seg_predictions(
+ panoptic_seg.to(self.cpu_device), segments_info,
+ alpha=0.5,
+ )
+ else:
+ if "sem_seg" in predictions:
+ vis_output = visualizer.draw_sem_seg(
+ self.filter_bg(predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)),
+ alpha=0.5,
+ )
+ if "instances" in predictions:
+ instances = predictions["instances"].to(self.cpu_device)
+ vis_output = visualizer.draw_instance_predictions(predictions=instances)
+
+ return predictions, vis_output
+
+ def _frame_from_video(self, video):
+ while video.isOpened():
+ success, frame = video.read()
+ if success:
+ yield frame
+ else:
+ break
+
+ def run_on_video(self, video):
+ """
+ Visualizes predictions on frames of the input video.
+ Args:
+ video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
+ either a webcam or a video file.
+ Yields:
+ ndarray: BGR visualizations of each video frame.
+ """
+ video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)
+
+ def process_predictions(frame, predictions):
+ import pdb; pdb.set_trace()
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ if "panoptic_seg" in predictions:
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
+ vis_frame = video_visualizer.draw_panoptic_seg_predictions(
+ frame, panoptic_seg.to(self.cpu_device), segments_info
+ )
+ elif "instances" in predictions:
+ predictions = predictions["instances"].to(self.cpu_device)
+ vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
+ elif "sem_seg" in predictions:
+ vis_frame = video_visualizer.draw_sem_seg(
+ frame,
+ predictions["sem_seg"].argmax(dim=0).to(self.cpu_device),
+ )
+
+ # Converts Matplotlib RGB format to OpenCV BGR format
+ vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
+ return vis_frame
+
+ frame_gen = self._frame_from_video(video)
+ if self.parallel:
+ buffer_size = self.predictor.default_buffer_size
+
+ frame_data = deque()
+
+ for cnt, frame in enumerate(frame_gen):
+ frame_data.append(frame)
+ self.predictor.put(frame)
+
+ if cnt >= buffer_size:
+ frame = frame_data.popleft()
+ predictions = self.predictor.get()
+ yield process_predictions(frame, predictions)
+
+ while len(frame_data):
+ frame = frame_data.popleft()
+ predictions = self.predictor.get()
+ yield process_predictions(frame, predictions)
+ else:
+ for frame in frame_gen:
+ yield process_predictions(frame, self.predictor(frame))
+
+ def filter_bg(self, pred):
+ if self.filter_background:
+ pred[pred == 0] = 255
+ return pred
+
+
+class AsyncPredictor:
+ """
+ A predictor that runs the model asynchronously, possibly on >1 GPUs.
+ Because rendering the visualization takes considerably amount of time,
+ this helps improve throughput a little bit when rendering videos.
+ """
+
+ class _StopToken:
+ pass
+
+ class _PredictWorker(mp.Process):
+ def __init__(self, cfg, task_queue, result_queue):
+ self.cfg = cfg
+ self.task_queue = task_queue
+ self.result_queue = result_queue
+ super().__init__()
+
+ def run(self):
+ predictor = DefaultPredictor(self.cfg)
+
+ while True:
+ task = self.task_queue.get()
+ if isinstance(task, AsyncPredictor._StopToken):
+ break
+ idx, data = task
+ result = predictor(data)
+ self.result_queue.put((idx, result))
+
+ def __init__(self, cfg, num_gpus: int = 1):
+ """
+ Args:
+ cfg (CfgNode):
+ num_gpus (int): if 0, will run on CPU
+ """
+ num_workers = max(num_gpus, 1)
+ self.task_queue = mp.Queue(maxsize=num_workers * 3)
+ self.result_queue = mp.Queue(maxsize=num_workers * 3)
+ self.procs = []
+ for gpuid in range(max(num_gpus, 1)):
+ cfg = cfg.clone()
+ cfg.defrost()
+ cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
+ self.procs.append(
+ AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
+ )
+
+ self.put_idx = 0
+ self.get_idx = 0
+ self.result_rank = []
+ self.result_data = []
+
+ for p in self.procs:
+ p.start()
+ atexit.register(self.shutdown)
+
+ def put(self, image):
+ self.put_idx += 1
+ self.task_queue.put((self.put_idx, image))
+
+ def get(self):
+ self.get_idx += 1 # the index needed for this request
+ if len(self.result_rank) and self.result_rank[0] == self.get_idx:
+ res = self.result_data[0]
+ del self.result_data[0], self.result_rank[0]
+ return res
+
+ while True:
+ # make sure the results are returned in the correct order
+ idx, res = self.result_queue.get()
+ if idx == self.get_idx:
+ return res
+ insert = bisect.bisect(self.result_rank, idx)
+ self.result_rank.insert(insert, idx)
+ self.result_data.insert(insert, res)
+
+ def __len__(self):
+ return self.put_idx - self.get_idx
+
+ def __call__(self, image):
+ self.put(image)
+ return self.get()
+
+ def shutdown(self):
+ for _ in self.procs:
+ self.task_queue.put(AsyncPredictor._StopToken())
+
+ @property
+ def default_buffer_size(self):
+ return len(self.procs) * 5
diff --git a/demo/visualizer.py b/demo/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9903b7e49715d5df0ea861aec164e047abb09a80
--- /dev/null
+++ b/demo/visualizer.py
@@ -0,0 +1,219 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py
+import atexit
+import bisect
+import multiprocessing as mp
+from collections import deque
+
+import cv2
+import torch
+
+from detectron2.data import MetadataCatalog
+from detectron2.engine.defaults import DefaultPredictor
+from detectron2.utils.video_visualizer import VideoVisualizer
+from detectron2.utils.visualizer import ColorMode, Visualizer
+
+
+class VisualizationGt(object):
+ def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
+ """
+ Args:
+ cfg (CfgNode):
+ instance_mode (ColorMode):
+ parallel (bool): whether to run the model in different processes from visualization.
+ Useful since the visualization logic can be slow.
+ """
+ self.metadata = MetadataCatalog.get(
+ cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
+ )
+ self.cpu_device = torch.device("cpu")
+ self.instance_mode = instance_mode
+
+ self.parallel = parallel
+ if parallel:
+ num_gpu = torch.cuda.device_count()
+ self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
+ else:
+ self.predictor = DefaultPredictor(cfg)
+
+ def run_on_image(self, image, predictions):
+ """
+ Args:
+ image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+ This is the format used by OpenCV.
+ Returns:
+ predictions (dict): the output of the model.
+ vis_output (VisImage): the visualized image output.
+ """
+ vis_output = None
+ # predictions = self.predictor(image)
+ # Convert image from OpenCV BGR format to Matplotlib RGB format.
+ image = image[:, :, ::-1]
+ visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
+ if "panoptic_seg" in predictions:
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
+ vis_output = visualizer.draw_panoptic_seg_predictions(
+ panoptic_seg.to(self.cpu_device), segments_info
+ )
+ else:
+ if "sem_seg" in predictions:
+ vis_output = visualizer.draw_sem_seg(
+ predictions["sem_seg"]
+ )
+ if "instances" in predictions:
+ instances = predictions["instances"].to(self.cpu_device)
+ vis_output = visualizer.draw_instance_predictions(predictions=instances)
+
+ return predictions, vis_output
+
+ def _frame_from_video(self, video):
+ while video.isOpened():
+ success, frame = video.read()
+ if success:
+ yield frame
+ else:
+ break
+
+ def run_on_video(self, video):
+ """
+ Visualizes predictions on frames of the input video.
+ Args:
+ video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
+ either a webcam or a video file.
+ Yields:
+ ndarray: BGR visualizations of each video frame.
+ """
+ video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)
+
+ def process_predictions(frame, predictions):
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ if "panoptic_seg" in predictions:
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
+ vis_frame = video_visualizer.draw_panoptic_seg_predictions(
+ frame, panoptic_seg.to(self.cpu_device), segments_info
+ )
+ elif "instances" in predictions:
+ predictions = predictions["instances"].to(self.cpu_device)
+ vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
+ elif "sem_seg" in predictions:
+ vis_frame = video_visualizer.draw_sem_seg(
+ frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
+ )
+
+ # Converts Matplotlib RGB format to OpenCV BGR format
+ vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
+ return vis_frame
+
+ frame_gen = self._frame_from_video(video)
+ if self.parallel:
+ buffer_size = self.predictor.default_buffer_size
+
+ frame_data = deque()
+
+ for cnt, frame in enumerate(frame_gen):
+ frame_data.append(frame)
+ self.predictor.put(frame)
+
+ if cnt >= buffer_size:
+ frame = frame_data.popleft()
+ predictions = self.predictor.get()
+ yield process_predictions(frame, predictions)
+
+ while len(frame_data):
+ frame = frame_data.popleft()
+ predictions = self.predictor.get()
+ yield process_predictions(frame, predictions)
+ else:
+ for frame in frame_gen:
+ yield process_predictions(frame, self.predictor(frame))
+
+
+class AsyncPredictor:
+ """
+ A predictor that runs the model asynchronously, possibly on >1 GPUs.
+ Because rendering the visualization takes considerably amount of time,
+ this helps improve throughput a little bit when rendering videos.
+ """
+
+ class _StopToken:
+ pass
+
+ class _PredictWorker(mp.Process):
+ def __init__(self, cfg, task_queue, result_queue):
+ self.cfg = cfg
+ self.task_queue = task_queue
+ self.result_queue = result_queue
+ super().__init__()
+
+ def run(self):
+ predictor = DefaultPredictor(self.cfg)
+
+ while True:
+ task = self.task_queue.get()
+ if isinstance(task, AsyncPredictor._StopToken):
+ break
+ idx, data = task
+ result = predictor(data)
+ self.result_queue.put((idx, result))
+
+ def __init__(self, cfg, num_gpus: int = 1):
+ """
+ Args:
+ cfg (CfgNode):
+ num_gpus (int): if 0, will run on CPU
+ """
+ num_workers = max(num_gpus, 1)
+ self.task_queue = mp.Queue(maxsize=num_workers * 3)
+ self.result_queue = mp.Queue(maxsize=num_workers * 3)
+ self.procs = []
+ for gpuid in range(max(num_gpus, 1)):
+ cfg = cfg.clone()
+ cfg.defrost()
+ cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
+ self.procs.append(
+ AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
+ )
+
+ self.put_idx = 0
+ self.get_idx = 0
+ self.result_rank = []
+ self.result_data = []
+
+ for p in self.procs:
+ p.start()
+ atexit.register(self.shutdown)
+
+ def put(self, image):
+ self.put_idx += 1
+ self.task_queue.put((self.put_idx, image))
+
+ def get(self):
+ self.get_idx += 1 # the index needed for this request
+ if len(self.result_rank) and self.result_rank[0] == self.get_idx:
+ res = self.result_data[0]
+ del self.result_data[0], self.result_rank[0]
+ return res
+
+ while True:
+ # make sure the results are returned in the correct order
+ idx, res = self.result_queue.get()
+ if idx == self.get_idx:
+ return res
+ insert = bisect.bisect(self.result_rank, idx)
+ self.result_rank.insert(insert, idx)
+ self.result_data.insert(insert, res)
+
+ def __len__(self):
+ return self.put_idx - self.get_idx
+
+ def __call__(self, image):
+ self.put(image)
+ return self.get()
+
+ def shutdown(self):
+ for _ in self.procs:
+ self.task_queue.put(AsyncPredictor._StopToken())
+
+ @property
+ def default_buffer_size(self):
+ return len(self.procs) * 5
diff --git a/eval.sh b/eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..450a72857d3e7eff0e81d933cc4e95378f90e086
--- /dev/null
+++ b/eval.sh
@@ -0,0 +1,100 @@
+#!/bin/sh
+
+gpus=4
+config=$1
+output=$2
+
+if [ -z $config ]
+then
+ echo "No config file found! Run with "sh run.sh [CONFIG_FILE] [OUTPUT_DIR] [OPTS]""
+ exit 0
+fi
+
+if [ -z $output ]
+then
+ echo "No output directory found! Run with "sh run.sh [CONFIG_FILE] [OUTPUT_DIR] [OPTS]""
+ exit 0
+fi
+
+shift 2
+opts=${@}
+
+#ADE20k-150
+python train_net.py --config $config \
+ --num-gpus $gpus \
+ --dist-url "auto" \
+ --eval-only \
+ OUTPUT_DIR $output/eval \
+ MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON "datasets/ADE_20k/ADE20K_150_class.json" \
+ DATASETS.TEST \(\"ade20k_150_test_sem_seg\"\,\) \
+ TEST.SLIDING_WINDOW "True" \
+ MODEL.SEM_SEG_HEAD.POOLING_SIZES "[1,1]" \
+ MODEL.WEIGHTS $output/model_final.pth \
+ $opts
+
+#ADE20k-847
+python train_net.py --config $config \
+ --num-gpus $gpus \
+ --dist-url "auto" \
+ --eval-only \
+ OUTPUT_DIR $output/eval \
+ MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON "datasets/ADE_20k/ADE20K_847_pure_class.json" \
+ DATASETS.TEST \(\"ade20k_full_sem_seg_freq_val_all\"\,\) \
+ TEST.SLIDING_WINDOW "True" \
+ MODEL.SEM_SEG_HEAD.POOLING_SIZES "[1,1]" \
+ MODEL.WEIGHTS $output/model_final.pth \
+ $opts
+
+#Pascal VOC
+python train_net.py --config $config \
+ --num-gpus $gpus \
+ --dist-url "auto" \
+ --eval-only \
+ OUTPUT_DIR $output/eval \
+ MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON "datasets/pascal-voc20/VOC_20_class.json" \
+ DATASETS.TEST \(\"voc_2012_test_sem_seg\"\,\) \
+ TEST.SLIDING_WINDOW "True" \
+ MODEL.SEM_SEG_HEAD.POOLING_SIZES "[1,1]" \
+ MODEL.WEIGHTS $output/model_final.pth \
+ $opts
+
+#Pascal VOC-b
+python train_net.py --config $config \
+ --num-gpus $gpus \
+ --dist-url "auto" \
+ --eval-only \
+ OUTPUT_DIR $output/eval \
+ MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON "datasets/pascal-voc20/VOC_20_class_59.json" \
+ DATASETS.TEST \(\"voc_2012_test_openseg_sem_seg\"\,\) \
+ TEST.SLIDING_WINDOW "True" \
+ MODEL.SEM_SEG_HEAD.POOLING_SIZES "[1,1]" \
+ MODEL.WEIGHTS $output/model_final.pth \
+ $opts
+
+#Pascal Context 59
+python train_net.py --config $config \
+ --num-gpus $gpus \
+ --dist-url "auto" \
+ --eval-only \
+ OUTPUT_DIR $output/eval \
+ MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON "datasets/pascal-context/pas59.json" \
+ DATASETS.TEST \(\"context_59_test_sem_seg\"\,\) \
+ TEST.SLIDING_WINDOW "True" \
+ MODEL.SEM_SEG_HEAD.POOLING_SIZES "[1,1]" \
+ MODEL.WEIGHTS $output/model_final.pth \
+ $opts
+
+#Pascal Context 459
+python train_net.py --config $config \
+ --num-gpus $gpus \
+ --dist-url "auto" \
+ --eval-only \
+ OUTPUT_DIR $output/eval \
+ MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON "datasets/pascal-context/pas459.json" \
+ DATASETS.TEST \(\"context_459_test_sem_seg\"\,\) \
+ TEST.SLIDING_WINDOW "True" \
+ MODEL.SEM_SEG_HEAD.POOLING_SIZES "[1,1]" \
+ MODEL.WEIGHTS $output/model_final.pth \
+ $opts
+
+cat $output/eval/log.txt | grep copypaste
\ No newline at end of file
diff --git a/model_final_cls.pth b/model_final_cls.pth
new file mode 100644
index 0000000000000000000000000000000000000000..6d9120a7d5e6a237f9b4fabacb3250133451f3f7
--- /dev/null
+++ b/model_final_cls.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78db6c99b54b356cec8e7a661e3c4b75a49b6a89630a9e6aba387c4e47fc36c7
+size 2799748994
diff --git a/open_clip/.gitignore b/open_clip/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..960651015af97b2245b5620fa282931948e3cff3
--- /dev/null
+++ b/open_clip/.gitignore
@@ -0,0 +1,153 @@
+logs/
+wandb/
+models/
+features/
+results/
+
+tests/data/
+*.pt
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+sync.sh
+gpu1sync.sh
+.idea
+*.pdf
+**/._*
+**/*DS_*
+**.jsonl
+src/sbatch
+src/misc
+.vscode
+src/debug
+core.*
+
+# Allow
+!src/evaluation/misc/results_dbs/*
\ No newline at end of file
diff --git a/open_clip/CITATION.cff b/open_clip/CITATION.cff
new file mode 100644
index 0000000000000000000000000000000000000000..1072ddd3a6065bbf88346c2c1d6ce7681363fab8
--- /dev/null
+++ b/open_clip/CITATION.cff
@@ -0,0 +1,33 @@
+cff-version: 1.1.0
+message: If you use this software, please cite it as below.
+authors:
+ - family-names: Ilharco
+ given-names: Gabriel
+ - family-names: Wortsman
+ given-names: Mitchell
+ - family-names: Wightman
+ given-names: Ross
+ - family-names: Gordon
+ given-names: Cade
+ - family-names: Carlini
+ given-names: Nicholas
+ - family-names: Taori
+ given-names: Rohan
+ - family-names: Dave
+ given-names: Achal
+ - family-names: Shankar
+ given-names: Vaishaal
+ - family-names: Namkoong
+ given-names: Hongseok
+ - family-names: Miller
+ given-names: John
+ - family-names: Hajishirzi
+ given-names: Hannaneh
+ - family-names: Farhadi
+ given-names: Ali
+ - family-names: Schmidt
+ given-names: Ludwig
+title: OpenCLIP
+version: v0.1
+doi: 10.5281/zenodo.5143773
+date-released: 2021-07-28
diff --git a/open_clip/HISTORY.md b/open_clip/HISTORY.md
new file mode 100644
index 0000000000000000000000000000000000000000..485bd346d0b55e876f637cc7359b401f54a90dbf
--- /dev/null
+++ b/open_clip/HISTORY.md
@@ -0,0 +1,110 @@
+## 2.10.1
+
+* `hf-hub:org/model_id` support for loading models w/ config and weights in Hugging Face Hub
+
+## 2.10.0
+
+* Added a ViT-bigG-14 model.
+* Added an up-to-date example slurm script for large training jobs.
+* Added a option to sync logs and checkpoints to S3 during training.
+* New options for LR schedulers, constant and constant with cooldown
+* Fix wandb autoresuming when resume is not set
+* ConvNeXt `base` & `base_w` pretrained models added
+* `timm-` model prefix removed from configs
+* `timm` augmentation + regularization (dropout / drop-path) supported
+
+## 2.9.3
+
+* Fix wandb collapsing multiple parallel runs into a single one
+
+## 2.9.2
+
+* Fix braceexpand memory explosion for complex webdataset urls
+
+## 2.9.1
+
+* Fix release
+
+## 2.9.0
+
+* Add training feature to auto-resume from the latest checkpoint on restart via `--resume latest`
+* Allow webp in webdataset
+* Fix logging for number of samples when using gradient accumulation
+* Add model configs for convnext xxlarge
+
+## 2.8.2
+
+* wrapped patchdropout in a torch.nn.Module
+
+## 2.8.1
+
+* relax protobuf dependency
+* override the default patch dropout value in 'vision_cfg'
+
+## 2.8.0
+
+* better support for HF models
+* add support for gradient accumulation
+* CI fixes
+* add support for patch dropout
+* add convnext configs
+
+
+## 2.7.0
+
+* add multilingual H/14 xlm roberta large
+
+## 2.6.1
+
+* fix setup.py _read_reqs
+
+## 2.6.0
+
+* Make openclip training usable from pypi.
+* Add xlm roberta large vit h 14 config.
+
+## 2.5.0
+
+* pretrained B/32 xlm roberta base: first multilingual clip trained on laion5B
+* pretrained B/32 roberta base: first clip trained using an HF text encoder
+
+## 2.4.1
+
+* Add missing hf_tokenizer_name in CLIPTextCfg.
+
+## 2.4.0
+
+* Fix #211, missing RN50x64 config. Fix type of dropout param for ResNet models
+* Bring back LayerNorm impl that casts to input for non bf16/fp16
+* zero_shot.py: set correct tokenizer based on args
+* training/params.py: remove hf params and get them from model config
+
+## 2.3.1
+
+* Implement grad checkpointing for hf model.
+* custom_text: True if hf_model_name is set
+* Disable hf tokenizer parallelism
+
+## 2.3.0
+
+* Generalizable Text Transformer with HuggingFace Models (@iejMac)
+
+## 2.2.0
+
+* Support for custom text tower
+* Add checksum verification for pretrained model weights
+
+## 2.1.0
+
+* lot including sota models, bfloat16 option, better loading, better metrics
+
+## 1.2.0
+
+* ViT-B/32 trained on Laion2B-en
+* add missing openai RN50x64 model
+
+## 1.1.1
+
+* ViT-B/16+
+* Add grad checkpointing support
+* more robust data loader
diff --git a/open_clip/LICENSE b/open_clip/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..5bfbf6c09daad743dbf9a98d303c0402e4099a27
--- /dev/null
+++ b/open_clip/LICENSE
@@ -0,0 +1,23 @@
+Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman,
+Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar,
+John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi,
+Ludwig Schmidt
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/open_clip/MANIFEST.in b/open_clip/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..c74de18e62cf8fe3b8fa777195f7d38c90b13380
--- /dev/null
+++ b/open_clip/MANIFEST.in
@@ -0,0 +1,3 @@
+include src/open_clip/bpe_simple_vocab_16e6.txt.gz
+include src/open_clip/model_configs/*.json
+
diff --git a/open_clip/Makefile b/open_clip/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..ff07eccefed3d959c77d007d2571e226a07ace60
--- /dev/null
+++ b/open_clip/Makefile
@@ -0,0 +1,12 @@
+install: ## [Local development] Upgrade pip, install requirements, install package.
+ python -m pip install -U pip
+ python -m pip install -e .
+
+install-training:
+ python -m pip install -r requirements-training.txt
+
+install-test: ## [Local development] Install test requirements
+ python -m pip install -r requirements-test.txt
+
+test: ## [Local development] Run unit tests
+ python -m pytest -x -s -v tests
diff --git a/open_clip/README.md b/open_clip/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c69d288d57c670c27374cfbdbb0c34ef1a79df3f
--- /dev/null
+++ b/open_clip/README.md
@@ -0,0 +1,635 @@
+# OpenCLIP
+
+[[Paper]](https://arxiv.org/abs/2212.07143) [[Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_clip.ipynb)
+[![pypi](https://img.shields.io/pypi/v/open_clip_torch.svg)](https://pypi.python.org/pypi/open_clip_torch)
+
+Welcome to an open source implementation of OpenAI's [CLIP](https://arxiv.org/abs/2103.00020) (Contrastive Language-Image Pre-training).
+
+The goal of this repository is to enable training models with contrastive image-text supervision, and to investigate their properties such as robustness to distribution shift. Our starting point is an implementation of CLIP that matches the accuracy of the original CLIP models when trained on the same dataset.
+Specifically, a ResNet-50 model trained with our codebase on OpenAI's [15 million image subset of YFCC](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md) achieves **32.7%** top-1 accuracy on ImageNet. OpenAI's CLIP model reaches **31.3%** when trained on the same subset of YFCC. For ease of experimentation, we also provide code for training on the 3 million images in the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/download) dataset, where a ResNet-50x4 trained with our codebase reaches 22.2% top-1 ImageNet accuracy.
+
+We further this with a replication study on a dataset of comparable size to OpenAI's, [LAION-400M](https://arxiv.org/abs/2111.02114), and with the larger [LAION-2B](https://laion.ai/blog/laion-5b/) superset. In addition, we study scaling behavior in a paper on [reproducible scaling laws for contrastive language-image learning](https://arxiv.org/abs/2212.07143).
+
+We have trained:
+ * ViT-B/32 on LAION-400M with a accuracy of **62.9%**, comparable to OpenAI's **63.2%**, zero-shot top-1 on ImageNet1k
+ * ViT-B/32 on LAION-2B with a accuracy of **66.6%**.
+ * ViT-B/16 on LAION-400M achieving an accuracy of **67.1%**, lower than OpenAI's **68.3%** (as measured here, 68.6% in paper)
+ * ViT-B/16+ 240x240 (~50% more FLOPS than B/16 224x224) on LAION-400M achieving an accuracy of **69.2%**
+ * ViT-L/14 on LAION-400M with an accuracy of **72.77%**, vs OpenAI's **75.5%** (as measured here, 75.3% in paper)
+ * ViT-L/14 on LAION-2B with an accuracy of **75.3%**, vs OpenAI's **75.5%** (as measured here, 75.3% in paper)
+ * ViT-H/14 on LAION-2B with an accuracy of **78.0**. The second best in1k zero-shot for released, open-source weights thus far.
+ * ViT-g/14 on LAION-2B with an accuracy of **76.6**. This was trained on reduced schedule, same samples seen as 400M models.
+ * ViT-G/14 on LAION-2B with an accuracy of **80.1**. The best in1k zero-shot for released, open-source weights thus far.
+
+As we describe in more detail [below](#why-are-low-accuracy-clip-models-interesting), CLIP models in a medium accuracy regime already allow us to draw conclusions about the robustness of larger CLIP models since the models follow [reliable scaling laws](https://arxiv.org/abs/2107.04649).
+
+This codebase is work in progress, and we invite all to contribute in making it more accessible and useful. In the future, we plan to add support for TPU training and release larger models. We hope this codebase facilitates and promotes further research in contrastive image-text learning. Please submit an issue or send an email if you have any other requests or suggestions.
+
+Note that portions of `src/open_clip/` modelling and tokenizer code are adaptations of OpenAI's official [repository](https://github.com/openai/CLIP).
+
+## Approach
+
+| ![CLIP](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/CLIP.png) |
+|:--:|
+| Image Credit: https://github.com/openai/CLIP |
+
+## Usage
+
+```
+pip install open_clip_torch
+```
+
+```python
+import torch
+from PIL import Image
+import open_clip
+
+model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
+tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu')
+
+image = preprocess(Image.open("CLIP.png")).unsqueeze(0)
+text = tokenizer(["a diagram", "a dog", "a cat"])
+
+with torch.no_grad(), torch.cuda.amp.autocast():
+ image_features = model.encode_image(image)
+ text_features = model.encode_text(text)
+ image_features /= image_features.norm(dim=-1, keepdim=True)
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+
+ text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+
+print("Label probs:", text_probs) # prints: [[1., 0., 0.]]
+```
+
+To compute billions of embeddings efficiently, you can use [clip-retrieval](https://github.com/rom1504/clip-retrieval) which has openclip support.
+
+## Fine-tuning on classification tasks
+
+This repository is focused on training CLIP models. To fine-tune a *trained* zero-shot model on a downstream classification task such as ImageNet, please see [our other repository: WiSE-FT](https://github.com/mlfoundations/wise-ft). The [WiSE-FT repository](https://github.com/mlfoundations/wise-ft) contains code for our paper on [Robust Fine-tuning of Zero-shot Models](https://arxiv.org/abs/2109.01903), in which we introduce a technique for fine-tuning zero-shot models while preserving robustness under distribution shift.
+
+## Data
+
+To download datasets as webdataset, we recommend [img2dataset](https://github.com/rom1504/img2dataset)
+
+### Conceptual Captions
+
+See [cc3m img2dataset example](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md)
+
+### YFCC and other datasets
+
+In addition to specifying the training data via CSV files as mentioned above, our codebase also supports [webdataset](https://github.com/webdataset/webdataset), which is recommended for larger scale datasets. The expected format is a series of `.tar` files. Each of these `.tar` files should contain two files for each training example, one for the image and one for the corresponding text. Both files should have the same name but different extensions. For instance, `shard_001.tar` could contain files such as `abc.jpg` and `abc.txt`. You can learn more about `webdataset` at [https://github.com/webdataset/webdataset](https://github.com/webdataset/webdataset). We use `.tar` files with 1,000 data points each, which we create using [tarp](https://github.com/webdataset/tarp).
+
+You can download the YFCC dataset from [Multimedia Commons](http://mmcommons.org/).
+Similar to OpenAI, we used a subset of YFCC to reach the aforementioned accuracy numbers.
+The indices of images in this subset are in [OpenAI's CLIP repository](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md).
+
+
+## Training CLIP
+
+### Install
+
+We advise you first create a virtual environment with:
+
+```
+python3 -m venv .env
+source .env/bin/activate
+pip install -U pip
+```
+
+You can then install openclip for training with `pip install 'open_clip_torch[training]'`.
+
+#### Development
+
+If you want to make changes to contribute code, you can close openclip then run `make install` in openclip folder (after creating a virtualenv)
+
+Install pip PyTorch as per https://pytorch.org/get-started/locally/
+
+You may run `make install-training` to install training deps
+
+#### Testing
+
+Test can be run with `make install-test` then `make test`
+
+`python -m pytest -x -s -v tests -k "training"` to run a specific test
+
+Running regression tests against a specific git revision or tag:
+1. Generate testing data
+ ```sh
+ python tests/util_test.py --model RN50 RN101 --save_model_list models.txt --git_revision 9d31b2ec4df6d8228f370ff20c8267ec6ba39383
+ ```
+ **_WARNING_: This will invoke git and modify your working tree, but will reset it to the current state after data has been generated! \
+ Don't modify your working tree while test data is being generated this way.**
+
+2. Run regression tests
+ ```sh
+ OPEN_CLIP_TEST_REG_MODELS=models.txt python -m pytest -x -s -v -m regression_test
+ ```
+
+### Sample single-process running code:
+
+```bash
+python -m training.main \
+ --save-frequency 1 \
+ --zeroshot-frequency 1 \
+ --report-to tensorboard \
+ --train-data="/path/to/train_data.csv" \
+ --val-data="/path/to/validation_data.csv" \
+ --csv-img-key filepath \
+ --csv-caption-key title \
+ --imagenet-val=/path/to/imagenet/root/val/ \
+ --warmup 10000 \
+ --batch-size=128 \
+ --lr=1e-3 \
+ --wd=0.1 \
+ --epochs=30 \
+ --workers=8 \
+ --model RN50
+```
+
+Note: `imagenet-val` is the path to the *validation* set of ImageNet for zero-shot evaluation, not the training set!
+You can remove this argument if you do not want to perform zero-shot evaluation on ImageNet throughout training. Note that the `val` folder should contain subfolders. If it doest not, please use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
+
+### Multi-GPU and Beyond
+
+This code has been battle tested up to 1024 A100s and offers a variety of solutions
+for distributed training. We include native support for SLURM clusters.
+
+As the number of devices used to train increases, so does the space complexity of
+the the logit matrix. Using a naïve all-gather scheme, space complexity will be
+`O(n^2)`. Instead, complexity may become effectively linear if the flags
+`--gather-with-grad` and `--local-loss` are used. This alteration results in one-to-one
+numerical results as the naïve method.
+
+#### Epochs
+
+For larger datasets (eg Laion2B), we recommend setting --train-num-samples to a lower value than the full epoch, for example `--train-num-samples 135646078` to 1/16 of an epoch in conjunction with --dataset-resampled to do sampling with replacement. This allows having frequent checkpoints to evaluate more often.
+
+#### Patch Dropout
+
+Recent research has shown that one can dropout half to three-quarters of the visual tokens, leading to up to 2-3x training speeds without loss of accuracy.
+
+You can set this on your visual transformer config with the key `patch_dropout`.
+
+In the paper, they also finetuned without the patch dropout at the end. You can do this with the command-line argument `--force-patch-dropout 0.`
+
+#### Single-Node
+
+We make use of `torchrun` to launch distributed jobs. The following launches a
+a job on a node of 4 GPUs:
+
+```bash
+cd open_clip/src
+torchrun --nproc_per_node 4 -m training.main \
+ --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
+ --train-num-samples 10968539 \
+ --dataset-type webdataset \
+ --batch-size 320 \
+ --precision amp \
+ --workers 4 \
+ --imagenet-val /data/imagenet/validation/
+```
+
+#### Multi-Node
+
+The same script above works, so long as users include information about the number
+of nodes and host node.
+
+```bash
+cd open_clip/src
+torchrun --nproc_per_node=4 \
+ --rdzv_endpoint=$HOSTE_NODE_ADDR \
+ -m training.main \
+ --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
+ --train-num-samples 10968539 \
+ --dataset-type webdataset \
+ --batch-size 320 \
+ --precision amp \
+ --workers 4 \
+ --imagenet-val /data/imagenet/validation/
+```
+
+#### SLURM
+
+This is likely the easiest solution to utilize. The following script was used to
+train our largest models:
+
+```bash
+#!/bin/bash -x
+#SBATCH --nodes=32
+#SBATCH --gres=gpu:4
+#SBATCH --ntasks-per-node=4
+#SBATCH --cpus-per-task=6
+#SBATCH --wait-all-nodes=1
+#SBATCH --job-name=open_clip
+#SBATCH --account=ACCOUNT_NAME
+#SBATCH --partition PARTITION_NAME
+
+eval "$(/path/to/conda/bin/conda shell.bash hook)" # init conda
+conda activate open_clip
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+export MASTER_PORT=12802
+
+master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=$master_addr
+
+cd /shared/open_clip
+export PYTHONPATH="$PYTHONPATH:$PWD/src"
+srun --cpu_bind=v --accel-bind=gn python -u src/training/main.py \
+ --save-frequency 1 \
+ --report-to tensorboard \
+ --train-data="/data/LAION-400M/{00000..41455}.tar" \
+ --warmup 2000 \
+ --batch-size=256 \
+ --epochs=32 \
+ --workers=8 \
+ --model ViT-B-32 \
+ --name "ViT-B-32-Vanilla" \
+ --seed 0 \
+ --local-loss \
+ --gather-with-grad
+```
+
+### Resuming from a checkpoint:
+
+```bash
+python -m training.main \
+ --train-data="/path/to/train_data.csv" \
+ --val-data="/path/to/validation_data.csv" \
+ --resume /path/to/checkpoints/epoch_K.pt
+```
+
+### Training with pre-trained language models as text encoder:
+
+If you wish to use different language models as the text encoder for CLIP you can do so by using one of the Hugging Face model configs in ```src/open_clip/model_configs``` and passing in it's tokenizer as the ```--model``` and ```--hf-tokenizer-name``` parameters respectively. Currently we only support RoBERTa ("test-roberta" config), however adding new models should be trivial. You can also determine how many layers, from the end, to leave unfrozen with the ```--lock-text-unlocked-layers``` parameter. Here's an example command to train CLIP with the RoBERTa LM that has it's last 10 layers unfrozen:
+```bash
+python -m training.main \
+ --train-data="pipe:aws s3 cp s3://s-mas/cc3m/{00000..00329}.tar -" \
+ --train-num-samples 3000000 \
+ --val-data="pipe:aws s3 cp s3://s-mas/cc3m/{00330..00331}.tar -" \
+ --val-num-samples 10000 \
+ --dataset-type webdataset \
+ --batch-size 256 \
+ --warmup 2000 \
+ --epochs 10 \
+ --lr 5e-4 \
+ --precision amp \
+ --workers 6 \
+ --model "roberta-ViT-B-32" \
+ --lock-text \
+ --lock-text-unlocked-layers 10 \
+ --name "10_unfrozen" \
+ --report-to "tensorboard" \
+```
+
+### Loss Curves
+
+When run on a machine with 8 GPUs the command should produce the following training curve for Conceptual Captions:
+
+![CLIP zero shot training curve](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/clip_zeroshot.png)
+
+More detailed curves for Conceptual Captions are given at [/docs/clip_conceptual_captions.md](/docs/clip_conceptual_captions.md).
+
+When training a RN50 on YFCC the same hyperparameters as above are used, with the exception of `lr=5e-4` and `epochs=32`.
+
+Note that to use another model, like `ViT-B/32` or `RN50x4` or `RN50x16` or `ViT-B/16`, specify with `--model RN50x4`.
+
+### Launch tensorboard:
+```bash
+tensorboard --logdir=logs/tensorboard/ --port=7777
+```
+
+## Evaluation / Zero-Shot
+
+### Evaluating local checkpoint:
+
+```bash
+python -m training.main \
+ --val-data="/path/to/validation_data.csv" \
+ --model RN101 \
+ --pretrained /path/to/checkpoints/epoch_K.pt
+```
+
+### Evaluating hosted pretrained checkpoint on ImageNet zero-shot prediction:
+
+```bash
+python -m training.main \
+ --imagenet-val /path/to/imagenet/validation \
+ --model ViT-B-32-quickgelu \
+ --pretrained laion400m_e32
+```
+
+## Pretrained model details
+
+### LAION-400M - https://laion.ai/laion-400-open-dataset
+
+We are working on reproducing OpenAI's ViT results with the comparably sized (and open) LAION-400M dataset. Trained
+weights may be found in release [v0.2](https://github.com/mlfoundations/open_clip/releases/tag/v0.2-weights).
+
+The LAION400M weights have been trained on the JUWELS supercomputer (see acknowledgements section below).
+
+#### ViT-B/32 224x224
+
+We replicate OpenAI's results on ViT-B/32, reaching a top-1 ImageNet-1k zero-shot accuracy of 62.96%.
+
+
+
+__Zero-shot comparison (courtesy of Andreas Fürst)__
+
+
+ViT-B/32 was trained with 128 A100 (40 GB) GPUs for ~36 hours, 4600 GPU-hours. The per-GPU batch size was 256 for a global batch size of 32768. 256 is much lower than it could have been (~320-384) due to being sized initially before moving to 'local' contrastive loss.
+
+#### ViT-B/16 224x224
+
+The B/16 LAION400M training reached a top-1 ImageNet-1k zero-shot validation score of 67.07.
+
+
+
+This was the first major train session using the updated webdataset 0.2.x code. A bug was found that prevented shards from being shuffled properly between nodes/workers each epoch. This was fixed part way through training (epoch 26) but likely had an impact.
+
+ViT-B/16 was trained with 176 A100 (40 GB) GPUS for ~61 hours, 10700 GPU-hours. Batch size per GPU was 192 for a global batch size of 33792.
+
+#### ViT-B/16+ 240x240
+
+The B/16+ 240x240 LAION400M training reached a top-1 ImageNet-1k zero-shot validation score of 69.21.
+
+This model is the same depth as the B/16, but increases the
+ * vision width from 768 -> 896
+ * text width from 512 -> 640
+ * the resolution 224x224 -> 240x240 (196 -> 225 tokens)
+
+
+
+Unlike the B/16 run above, this model was a clean run with no dataset shuffling issues.
+
+ViT-B/16+ was trained with 224 A100 (40 GB) GPUS for ~61 hours, 13620 GPU-hours. Batch size per GPU was 160 for a global batch size of 35840.
+
+#### ViT-L/14 224x224
+
+The L/14 LAION-400M training reached a top-1 ImageNet-1k zero-shot validation score of 72.77.
+
+
+
+ViT-L/14 was trained with 400 A100 (40 GB) GPUS for ~127 hours, 50800 GPU-hours. Batch size per GPU was 96 for a global batch size of 38400. Grad checkpointing was enabled.
+
+### LAION-2B (en) - https://laion.ai/laion-5b-a-new-era-of-open-large-scale-multi-modal-datasets/
+
+A ~2B sample subset of LAION-5B with english captions (https://huggingface.co/datasets/laion/laion2B-en)
+
+#### ViT-B/32 224x224
+A ViT-B/32 trained on LAION-2B, reaching a top-1 ImageNet-1k zero-shot accuracy of 65.62%.
+
+
+
+ViT-B/32 was trained with 112 A100 (40 GB) GPUs. The per-GPU batch size was 416 for a global batch size of 46592. Compute generously provided by [stability.ai](https://stability.ai/).
+
+A second iteration of B/32 was trained on stability.ai cluster with a larger global batch size and learning rate, hitting 66.6% top-1. See https://huggingface.co/laion/CLIP-ViT-B-32-laion2B-s34B-b79K
+
+#### ViT-L/14 224x224
+
+A ViT-L/14 with a 75.3% top-1 ImageNet-1k zero-shot was trained on JUWELS Booster. See model details here https://huggingface.co/laion/CLIP-ViT-L-14-laion2B-s32B-b82K
+
+These weights use a different dataset mean and std than others. Instead of using the OpenAI mean & std, inception style normalization `[-1, 1]` is used via a mean and std of `[0.5, 0.5, 0.5]`. This is handled automatically if using `open_clip.create_model_and_transforms` from pretrained weights.
+
+#### ViT-H/14 224x224
+
+A ViT-H/14 with a 78.0% top-1 ImageNet-1k zero-shot was trained on JUWELS Booster. See model details here https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K
+
+#### ViT-g/14 224x224
+
+A ViT-g/14 with a 76.6% top-1 ImageNet-1k zero-shot was trained on JUWELS Booster. See model details here https://huggingface.co/laion/CLIP-ViT-g-14-laion2B-s12B-b42K
+
+This model was trained with a shorted schedule than other LAION-2B models with 12B samples seen instead of 32+B. It matches LAION-400M training in samples seen. Many zero-shot results are lower as a result, but despite this it performs very well in some OOD zero-shot and retrieval tasks.
+
+
+#### ViT-B/32 roberta base
+
+A ViT-B/32 with roberta base encoder with a 61.7% top-1 ImageNet-1k zero-shot was trained on stability. See model details here https://huggingface.co/laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k
+This is the first openclip model using a HF text tower. It has better performance on a range of tasks compared to the standard text encoder, see [metrics](https://huggingface.co/laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/blob/main/unknown.png)
+
+#### ViT-B/32 xlm roberta base
+
+A ViT-B/32 with xlm roberta base encoder with a 62.33% top-1 ImageNet-1k zero-shot was trained on stability. See model details here https://huggingface.co/laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k
+This is the first openclip model trained on the full laion5B dataset; hence the first multilingual clip trained with openclip. It has better performance on a range of tasks compared to the standard text encoder, see [metrics](https://huggingface.co/laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/blob/main/metrics.png)
+A preliminary multilingual evaluation was run: 43% on imagenet1k italian (vs 21% for english B/32), 37% for imagenet1k japanese (vs 1% for english B/32 and 50% for B/16 clip japanese). It shows the multilingual property is indeed there as expected. Larger models will get even better performance.
+
+#### ViT-H/14 xlm roberta large
+
+A ViT-H/14 with xlm roberta large encoder with a 77.0% (vs 78% for the english equivalent) top-1 ImageNet-1k zero-shot was trained on stability. See model details here https://huggingface.co/laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k
+
+This model was trained following the [LiT](https://arxiv.org/abs/2111.07991) methodology: the image tower was frozen (initialized from english openclip ViT-H/14), the text tower was initialized from [xlm roberta large](https://huggingface.co/xlm-roberta-large) and unfrozen. This reduced training cost by a 3x factor.
+
+See full english [metrics](https://huggingface.co/laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/resolve/main/results_xlm_roberta_large.png)
+
+On zero shot classification on imagenet with translated prompts this model reaches:
+
+* 56% in italian (vs 21% for https://github.com/clip-italian/clip-italian)
+* 53% in japanese (vs 54.6% for https://github.com/rinnakk/japanese-clip)
+* 55.7% in chinese (to be compared with https://github.com/OFA-Sys/Chinese-CLIP)
+
+
+#### YFCC-15M
+
+Below are checkpoints of models trained on YFCC-15M, along with their zero-shot top-1 accuracies on ImageNet and ImageNetV2. These models were trained using 8 GPUs and the same hyperparameters described in the "Sample running code" section, with the exception of `lr=5e-4` and `epochs=32`.
+
+* [ResNet-50](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt) (32.7% / 27.9%)
+* [ResNet-101](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt) (34.8% / 30.0%)
+
+#### CC12M - https://github.com/google-research-datasets/conceptual-12m
+
+* [ResNet-50](https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt) (36.45%)
+
+### Pretrained Model Interface
+
+We offer a simple model interface to instantiate both pre-trained and untrained models.
+
+NOTE: Many existing checkpoints use the QuickGELU activation from the original OpenAI models. This activation is actually less efficient than native torch.nn.GELU in recent versions of PyTorch. The model defaults are now nn.GELU, so one should use model definitions with `-quickgelu` postfix for the OpenCLIP pretrained weights. All OpenAI pretrained weights will always default to QuickGELU. One can also use the non `-quickgelu` model definitions with pretrained weights using QuickGELU but there will be an accuracy drop, for fine-tune that will likely vanish for longer runs.
+
+Future trained models will use nn.GELU.
+
+```python
+>>> import open_clip
+>>> open_clip.list_pretrained()
+[('RN50', 'openai'),
+ ('RN50', 'yfcc15m'),
+ ('RN50', 'cc12m'),
+ ('RN50-quickgelu', 'openai'),
+ ('RN50-quickgelu', 'yfcc15m'),
+ ('RN50-quickgelu', 'cc12m'),
+ ('RN101', 'openai'),
+ ('RN101', 'yfcc15m'),
+ ('RN101-quickgelu', 'openai'),
+ ('RN101-quickgelu', 'yfcc15m'),
+ ('RN50x4', 'openai'),
+ ('RN50x16', 'openai'),
+ ('RN50x64', 'openai'),
+ ('ViT-B-32', 'openai'),
+ ('ViT-B-32', 'laion400m_e31'),
+ ('ViT-B-32', 'laion400m_e32'),
+ ('ViT-B-32', 'laion2b_e16'),
+ ('ViT-B-32', 'laion2b_s34b_b79k'),
+ ('ViT-B-32-quickgelu', 'openai'),
+ ('ViT-B-32-quickgelu', 'laion400m_e31'),
+ ('ViT-B-32-quickgelu', 'laion400m_e32'),
+ ('ViT-B-16', 'openai'),
+ ('ViT-B-16', 'laion400m_e31'),
+ ('ViT-B-16', 'laion400m_e32'),
+ ('ViT-B-16-plus-240', 'laion400m_e31'),
+ ('ViT-B-16-plus-240', 'laion400m_e32'),
+ ('ViT-L-14', 'openai'),
+ ('ViT-L-14', 'laion400m_e31'),
+ ('ViT-L-14', 'laion400m_e32'),
+ ('ViT-L-14', 'laion2b_s32b_b82k'),
+ ('ViT-L-14-336', 'openai'),
+ ('ViT-H-14', 'laion2b_s32b_b79k'),
+ ('ViT-g-14', 'laion2b_s12b_b42k'),
+ ('ViT-bigG-14', 'laion2b_s39b_b160k'),
+ ('roberta-ViT-B-32', 'laion2b_s12b_b32k'),
+ ('xlm-roberta-base-ViT-B-32', 'laion5b_s13b_b90k'),
+ ('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'),]
+
+>>> model, train_transform, eval_transform = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
+```
+### Gradient accumulation
+
+To simulate larger batches use `--accum-freq k`. If per gpu batch size, `--batch-size`, is `m`, then the effective batch size will be `k * m * num_gpus`.
+
+When increasing `--accum-freq` from its default of 1, samples/s will remain approximately constant (batch size will double, as will time-per-batch). It is recommended to use other features to reduce batch size such as `--grad-checkpointing --local-loss --gather-with-grad` before increasing `--accum-freq`. `--accum-freq` can be used in addition to these features.
+
+Instead of 1 forward pass per example, there are now 2 forward passes per-example. However, the first is done with `torch.no_grad`.
+
+There is some additional GPU memory required --- the features and data from all `m` batches are stored in memory.
+
+There are also `m` loss computations instead of the usual 1.
+
+For more information see Cui et al. (https://arxiv.org/abs/2112.09331) or Pham et al. (https://arxiv.org/abs/2111.10050).
+
+### Support for remote loading/training
+
+It is always possible to resume directly from a remote file, e.g., a file in an s3 bucket. Just set `--resume s3:// `.
+This will work with any filesystem supported by `fsspec`.
+
+It is also possible to train `open_clip` models while continuously backing up to s3. This can help to avoid slow local file systems.
+
+Say that your node has a local ssd `/scratch`, an s3 bucket `s3://`.
+
+In that case, set `--logs /scratch` and `--remote-sync s3://`. Then, a background process will sync `/scratch/` to `s3:///`. After syncing, the background process will sleep for `--remote-sync-frequency` seconds, which defaults to 5 minutes.
+
+There is also experimental support for syncing to other remote file systems, not just s3. To do so, specify `--remote-sync-protocol fsspec`. However, this is currently very slow and not recommended.
+
+Also, to optionally avoid saving too many checkpoints locally when using these features, you can use `--delete-previous-checkpoint` which deletes the previous checkpoint after saving a new one.
+
+Note: if you are using this feature with `--resume latest`, there are a few warnings. First, use with `--save-most-recent` is not supported. Second, only `s3` is supported. Finally, since the sync happens in the background, it is possible that the most recent checkpoint may not be finished syncing to the remote.
+
+## Scaling trends
+
+The plot below shows how zero-shot performance of CLIP models varies as we scale the number of samples used for training. Zero-shot performance increases steadily for both ImageNet and [ImageNetV2](https://arxiv.org/abs/1902.10811), and is far from saturated at ~15M samples.
+
+
+
+## Why are low-accuracy CLIP models interesting?
+
+**TL;DR:** CLIP models have high effective robustness, even at small scales.
+
+CLIP models are particularly intriguing because they are more robust to natural distribution shifts (see Section 3.3 in the [CLIP paper](https://arxiv.org/abs/2103.00020)).
+This phenomena is illustrated by the figure below, with ImageNet accuracy on the x-axis
+and [ImageNetV2](https://arxiv.org/abs/1902.10811) (a reproduction of the ImageNet validation set with distribution shift) accuracy on the y-axis.
+Standard training denotes training on the ImageNet train set and the CLIP zero-shot models
+are shown as stars.
+
+![CLIP scatter plot](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/effective_robustness.png)
+
+As observed by [Taori et al., 2020](https://arxiv.org/abs/2007.00644) and [Miller et al., 2021](https://arxiv.org/abs/2107.04649), the in-distribution
+and out-of-distribution accuracies of models trained on ImageNet follow a predictable linear trend (the red line in the above plot). *Effective robustness*
+quantifies robustness as accuracy beyond this baseline, i.e., how far a model lies above the red line. Ideally a model would not suffer from distribution shift and fall on the y = x line ([trained human labelers are within a percentage point of the y = x line](http://proceedings.mlr.press/v119/shankar20c.html)).
+
+Even though the CLIP models trained with
+this codebase achieve much lower accuracy than those trained by OpenAI, our models still lie on the same
+trend of improved effective robustness (the purple line). Therefore, we can study what makes
+CLIP robust without requiring industrial-scale compute.
+
+For more information on effective robustness, please see:
+
+- [Recht et al., 2019](https://arxiv.org/abs/1902.10811).
+- [Taori et al., 2020](https://arxiv.org/abs/2007.00644).
+- [Miller et al., 2021](https://arxiv.org/abs/2107.04649).
+
+To know more about the factors that contribute to CLIP's robustness refer to [Fang et al., 2022](https://arxiv.org/abs/2205.01397).
+
+## Acknowledgments
+
+We gratefully acknowledge the Gauss Centre for Supercomputing e.V. (www.gauss-centre.eu) for funding this part of work by providing computing time through the John von Neumann Institute for Computing (NIC) on the GCS Supercomputer JUWELS Booster at Jülich Supercomputing Centre (JSC).
+
+## The Team
+
+Current development of this repository is led by [Ross Wightman](https://rwightman.com/), [Cade Gordon](http://cadegordon.io/), and [Vaishaal Shankar](http://vaishaal.com/).
+
+The original version of this repository is from a group of researchers at UW, Google, Stanford, Amazon, Columbia, and Berkeley.
+
+[Gabriel Ilharco*](http://gabrielilharco.com/), [Mitchell Wortsman*](https://mitchellnw.github.io/), [Nicholas Carlini](https://nicholas.carlini.com/), [Rohan Taori](https://www.rohantaori.com/), [Achal Dave](http://www.achaldave.com/), [Vaishaal Shankar](http://vaishaal.com/), [John Miller](https://people.eecs.berkeley.edu/~miller_john/), [Hongseok Namkoong](https://hsnamkoong.github.io/), [Hannaneh Hajishirzi](https://homes.cs.washington.edu/~hannaneh/), [Ali Farhadi](https://homes.cs.washington.edu/~ali/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/)
+
+Special thanks to [Jong Wook Kim](https://jongwook.kim/) and [Alec Radford](https://github.com/Newmu) for help with reproducing CLIP!
+
+## Citing
+
+If you found this repository useful, please consider citing:
+```bibtex
+@software{ilharco_gabriel_2021_5143773,
+ author = {Ilharco, Gabriel and
+ Wortsman, Mitchell and
+ Wightman, Ross and
+ Gordon, Cade and
+ Carlini, Nicholas and
+ Taori, Rohan and
+ Dave, Achal and
+ Shankar, Vaishaal and
+ Namkoong, Hongseok and
+ Miller, John and
+ Hajishirzi, Hannaneh and
+ Farhadi, Ali and
+ Schmidt, Ludwig},
+ title = {OpenCLIP},
+ month = jul,
+ year = 2021,
+ note = {If you use this software, please cite it as below.},
+ publisher = {Zenodo},
+ version = {0.1},
+ doi = {10.5281/zenodo.5143773},
+ url = {https://doi.org/10.5281/zenodo.5143773}
+}
+```
+
+```bibtex
+@inproceedings{Radford2021LearningTV,
+ title={Learning Transferable Visual Models From Natural Language Supervision},
+ author={Alec Radford and Jong Wook Kim and Chris Hallacy and A. Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
+ booktitle={ICML},
+ year={2021}
+}
+```
+
+```bibtex
+@inproceedings{schuhmann2022laionb,
+ title={{LAION}-5B: An open large-scale dataset for training next generation image-text models},
+ author={Christoph Schuhmann and
+ Romain Beaumont and
+ Richard Vencu and
+ Cade W Gordon and
+ Ross Wightman and
+ Mehdi Cherti and
+ Theo Coombes and
+ Aarush Katta and
+ Clayton Mullis and
+ Mitchell Wortsman and
+ Patrick Schramowski and
+ Srivatsa R Kundurthy and
+ Katherine Crowson and
+ Ludwig Schmidt and
+ Robert Kaczmarczyk and
+ Jenia Jitsev},
+ booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
+ year={2022},
+ url={https://openreview.net/forum?id=M3Y74vmsMcY}
+}
+```
+
+[![DOI](https://zenodo.org/badge/390536799.svg)](https://zenodo.org/badge/latestdoi/390536799)
diff --git a/open_clip/pytest.ini b/open_clip/pytest.ini
new file mode 100644
index 0000000000000000000000000000000000000000..9546b10ce86328ef21697b8d134a6d5865632f35
--- /dev/null
+++ b/open_clip/pytest.ini
@@ -0,0 +1,3 @@
+[pytest]
+markers =
+ regression_test
diff --git a/open_clip/requirements-test.txt b/open_clip/requirements-test.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5d2e7e147ead413905b38e96d6e1f20e05a5816d
--- /dev/null
+++ b/open_clip/requirements-test.txt
@@ -0,0 +1,4 @@
+pytest-split==0.8.0
+pytest==7.2.0
+transformers
+timm==0.6.11
diff --git a/open_clip/requirements-training.txt b/open_clip/requirements-training.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c44eb61d7c2276ccfa7594494663a38079447c83
--- /dev/null
+++ b/open_clip/requirements-training.txt
@@ -0,0 +1,12 @@
+torch>=1.9.0
+torchvision
+webdataset>=0.2.5
+regex
+ftfy
+tqdm
+pandas
+braceexpand
+huggingface_hub
+transformers
+timm
+fsspec
diff --git a/open_clip/requirements.txt b/open_clip/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c4324e1f91632667f2655a5a2465c6062deec24e
--- /dev/null
+++ b/open_clip/requirements.txt
@@ -0,0 +1,9 @@
+torch>=1.9.0
+torchvision
+regex
+ftfy
+tqdm
+huggingface_hub
+sentencepiece
+protobuf==3.20.*
+timm
diff --git a/open_clip/setup.py b/open_clip/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..00ab400a6679904cc5009ee595738f2e21dfaa14
--- /dev/null
+++ b/open_clip/setup.py
@@ -0,0 +1,61 @@
+""" Setup
+"""
+from setuptools import setup, find_packages
+from codecs import open
+from os import path
+
+here = path.abspath(path.dirname(__file__))
+
+# Get the long description from the README file
+with open(path.join(here, 'README.md'), encoding='utf-8') as f:
+ long_description = f.read()
+
+def _read_reqs(relpath):
+ fullpath = path.join(path.dirname(__file__), relpath)
+ with open(fullpath) as f:
+ return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))]
+
+REQUIREMENTS = _read_reqs("requirements.txt")
+TRAINING_REQUIREMENTS = _read_reqs("requirements-training.txt")
+
+exec(open('src/open_clip/version.py').read())
+setup(
+ name='open_clip_torch',
+ version=__version__,
+ description='OpenCLIP',
+ long_description=long_description,
+ long_description_content_type='text/markdown',
+ url='https://github.com/mlfoundations/open_clip',
+ author='',
+ author_email='',
+ classifiers=[
+ # How mature is this project? Common values are
+ # 3 - Alpha
+ # 4 - Beta
+ # 5 - Production/Stable
+ 'Development Status :: 3 - Alpha',
+ 'Intended Audience :: Education',
+ 'Intended Audience :: Science/Research',
+ 'License :: OSI Approved :: Apache Software License',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ 'Programming Language :: Python :: 3.9',
+ 'Programming Language :: Python :: 3.10',
+ 'Topic :: Scientific/Engineering',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'Topic :: Software Development',
+ 'Topic :: Software Development :: Libraries',
+ 'Topic :: Software Development :: Libraries :: Python Modules',
+ ],
+
+ # Note that this is a string of words separated by whitespace, not a list.
+ keywords='CLIP pretrained',
+ package_dir={'': 'src'},
+ packages=find_packages(where='src'),
+ include_package_data=True,
+ install_requires=REQUIREMENTS,
+ extras_require={
+ "training": TRAINING_REQUIREMENTS,
+ },
+ python_requires='>=3.7',
+)
diff --git a/open_clip/src/open_clip/__init__.py b/open_clip/src/open_clip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf72e9280c90bdfeaced30750650ef0f9021c3d
--- /dev/null
+++ b/open_clip/src/open_clip/__init__.py
@@ -0,0 +1,11 @@
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
+from .factory import list_models, add_model_config, get_model_config, load_checkpoint
+from .loss import ClipLoss
+from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg,\
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
+from .openai import load_openai_model, list_openai_models
+from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
+from .tokenizer import SimpleTokenizer, tokenize
+from .transform import image_transform, AugmentationCfg
diff --git a/open_clip/src/open_clip/bpe_simple_vocab_16e6.txt.gz b/open_clip/src/open_clip/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/open_clip/src/open_clip/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/open_clip/src/open_clip/constants.py b/open_clip/src/open_clip/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..a670bb3fab442baeb9af53b91c312e6982af57ee
--- /dev/null
+++ b/open_clip/src/open_clip/constants.py
@@ -0,0 +1,2 @@
+OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
diff --git a/open_clip/src/open_clip/factory.py b/open_clip/src/open_clip/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ac41f877ad9a6b969c2efd981a7453d1dd5142b
--- /dev/null
+++ b/open_clip/src/open_clip/factory.py
@@ -0,0 +1,313 @@
+import json
+import logging
+import os
+import pathlib
+import re
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
+ resize_pos_embed, get_cast_dtype
+from .openai import load_openai_model
+from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
+from .transform import image_transform, AugmentationCfg
+from .tokenizer import HFTokenizer, tokenize
+
+
+HF_HUB_PREFIX = 'hf-hub:'
+_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
+_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
+
+
+def _natural_key(string_):
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
+def _rescan_model_configs():
+ global _MODEL_CONFIGS
+
+ config_ext = ('.json',)
+ config_files = []
+ for config_path in _MODEL_CONFIG_PATHS:
+ if config_path.is_file() and config_path.suffix in config_ext:
+ config_files.append(config_path)
+ elif config_path.is_dir():
+ for ext in config_ext:
+ config_files.extend(config_path.glob(f'*{ext}'))
+
+ for cf in config_files:
+ with open(cf, 'r') as f:
+ model_cfg = json.load(f)
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
+ _MODEL_CONFIGS[cf.stem] = model_cfg
+
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
+
+
+_rescan_model_configs() # initial populate of model config registry
+
+
+def list_models():
+ """ enumerate available model architectures based on config files """
+ return list(_MODEL_CONFIGS.keys())
+
+
+def add_model_config(path):
+ """ add model config path or file and update registry """
+ if not isinstance(path, Path):
+ path = Path(path)
+ _MODEL_CONFIG_PATHS.append(path)
+ _rescan_model_configs()
+
+
+def get_model_config(model_name):
+ if model_name in _MODEL_CONFIGS:
+ return deepcopy(_MODEL_CONFIGS[model_name])
+ else:
+ return None
+
+
+def get_tokenizer(model_name):
+ if model_name.startswith(HF_HUB_PREFIX):
+ tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
+ else:
+ config = get_model_config(model_name)
+ tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
+ return tokenizer
+
+
+def load_state_dict(checkpoint_path: str, map_location='cpu'):
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+ if next(iter(state_dict.items()))[0].startswith('module'):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+ return state_dict
+
+
+def load_checkpoint(model, checkpoint_path, strict=True):
+ state_dict = load_state_dict(checkpoint_path)
+ # detect old format and make compatible with new format
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
+ state_dict = convert_to_custom_text_state_dict(state_dict)
+ resize_pos_embed(state_dict, model)
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
+ return incompatible_keys
+
+
+def create_model(
+ model_name: str,
+ pretrained: Optional[str] = None,
+ precision: str = 'fp32',
+ device: Union[str, torch.device] = 'cpu',
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_text: bool = False,
+ force_patch_dropout: Optional[float] = None,
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
+ pretrained_image: bool = False,
+ pretrained_hf: bool = True,
+ cache_dir: Optional[str] = None,
+):
+ has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
+ if has_hf_hub_prefix:
+ model_id = model_name[len(HF_HUB_PREFIX):]
+ checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
+ config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
+
+ with open(config_path, 'r', encoding='utf-8') as f:
+ config = json.load(f)
+ pretrained_cfg = config['preprocess_cfg']
+ model_cfg = config['model_cfg']
+ else:
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
+ checkpoint_path = None
+ pretrained_cfg = {}
+ model_cfg = None
+
+ if isinstance(device, str):
+ device = torch.device(device)
+
+ if pretrained and pretrained.lower() == 'openai':
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
+ model = load_openai_model(
+ model_name,
+ precision=precision,
+ device=device,
+ jit=jit,
+ cache_dir=cache_dir,
+ )
+ else:
+ model_cfg = model_cfg or get_model_config(model_name)
+ if model_cfg is not None:
+ logging.info(f'Loaded {model_name} model config.')
+ else:
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
+ raise RuntimeError(f'Model config for {model_name} not found.')
+
+ if force_quick_gelu:
+ # override for use of QuickGELU on non-OpenAI transformer models
+ model_cfg["quick_gelu"] = True
+
+ if force_patch_dropout is not None:
+ # override the default patch dropout value
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
+
+ if force_image_size is not None:
+ # override model config's image size
+ model_cfg["vision_cfg"]["image_size"] = force_image_size
+
+ if pretrained_image:
+ if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
+ # pretrained weight loading for timm models set via vision_cfg
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
+ else:
+ assert False, 'pretrained image towers currently only supported for timm models'
+
+ cast_dtype = get_cast_dtype(precision)
+ is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
+ custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
+
+ if custom_text:
+ if is_hf_model:
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
+ model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
+ else:
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
+
+ if pretrained:
+ checkpoint_path = ''
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
+ if pretrained_cfg:
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
+ elif os.path.exists(pretrained):
+ checkpoint_path = pretrained
+
+ if checkpoint_path:
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
+ load_checkpoint(model, checkpoint_path)
+ else:
+ error_str = (
+ f'Pretrained weights ({pretrained}) not found for model {model_name}.'
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
+ logging.warning(error_str)
+ raise RuntimeError(error_str)
+ elif has_hf_hub_prefix:
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
+ load_checkpoint(model, checkpoint_path)
+
+ model.to(device=device)
+ if precision in ("fp16", "bf16"):
+ convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
+
+ # set image / mean metadata from pretrained_cfg if available, or use default
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
+
+ if jit:
+ model = torch.jit.script(model)
+
+ return model
+
+
+def create_model_and_transforms(
+ model_name: str,
+ pretrained: Optional[str] = None,
+ precision: str = 'fp32',
+ device: Union[str, torch.device] = 'cpu',
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_text: bool = False,
+ force_patch_dropout: Optional[float] = None,
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
+ pretrained_image: bool = False,
+ pretrained_hf: bool = True,
+ image_mean: Optional[Tuple[float, ...]] = None,
+ image_std: Optional[Tuple[float, ...]] = None,
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+ cache_dir: Optional[str] = None,
+):
+ model = create_model(
+ model_name,
+ pretrained,
+ precision=precision,
+ device=device,
+ jit=jit,
+ force_quick_gelu=force_quick_gelu,
+ force_custom_text=force_custom_text,
+ force_patch_dropout=force_patch_dropout,
+ force_image_size=force_image_size,
+ pretrained_image=pretrained_image,
+ pretrained_hf=pretrained_hf,
+ cache_dir=cache_dir,
+ )
+
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
+ image_std = image_std or getattr(model.visual, 'image_std', None)
+ preprocess_train = image_transform(
+ model.visual.image_size,
+ is_train=True,
+ mean=image_mean,
+ std=image_std,
+ aug_cfg=aug_cfg,
+ )
+ preprocess_val = image_transform(
+ model.visual.image_size,
+ is_train=False,
+ mean=image_mean,
+ std=image_std,
+ )
+
+ return model, preprocess_train, preprocess_val
+
+
+def create_model_from_pretrained(
+ model_name: str,
+ pretrained: str,
+ precision: str = 'fp32',
+ device: Union[str, torch.device] = 'cpu',
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ force_custom_text: bool = False,
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
+ return_transform: bool = True,
+ image_mean: Optional[Tuple[float, ...]] = None,
+ image_std: Optional[Tuple[float, ...]] = None,
+ cache_dir: Optional[str] = None,
+):
+ if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
+ raise RuntimeError(
+ f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
+ f' Use open_clip.list_pretrained() to find one.')
+
+ model = create_model(
+ model_name,
+ pretrained,
+ precision=precision,
+ device=device,
+ jit=jit,
+ force_quick_gelu=force_quick_gelu,
+ force_custom_text=force_custom_text,
+ force_image_size=force_image_size,
+ cache_dir=cache_dir,
+ )
+
+ if not return_transform:
+ return model
+
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
+ image_std = image_std or getattr(model.visual, 'image_std', None)
+ preprocess = image_transform(
+ model.visual.image_size,
+ is_train=False,
+ mean=image_mean,
+ std=image_std,
+ )
+
+ return model, preprocess
diff --git a/open_clip/src/open_clip/hf_configs.py b/open_clip/src/open_clip/hf_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..e236222bafce0358445ea16953ca0b2d5a84758a
--- /dev/null
+++ b/open_clip/src/open_clip/hf_configs.py
@@ -0,0 +1,45 @@
+# HF architecture dict:
+arch_dict = {
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
+ "roberta": {
+ "config_names": {
+ "context_length": "max_position_embeddings",
+ "vocab_size": "vocab_size",
+ "width": "hidden_size",
+ "heads": "num_attention_heads",
+ "layers": "num_hidden_layers",
+ "layer_attr": "layer",
+ "token_embeddings_attr": "embeddings"
+ },
+ "pooler": "mean_pooler",
+ },
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
+ "xlm-roberta": {
+ "config_names": {
+ "context_length": "max_position_embeddings",
+ "vocab_size": "vocab_size",
+ "width": "hidden_size",
+ "heads": "num_attention_heads",
+ "layers": "num_hidden_layers",
+ "layer_attr": "layer",
+ "token_embeddings_attr": "embeddings"
+ },
+ "pooler": "mean_pooler",
+ },
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
+ "mt5": {
+ "config_names": {
+ # unlimited seqlen
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
+ "context_length": "",
+ "vocab_size": "vocab_size",
+ "width": "d_model",
+ "heads": "num_heads",
+ "layers": "num_layers",
+ "layer_attr": "block",
+ "token_embeddings_attr": "embed_tokens"
+ },
+ "pooler": "mean_pooler",
+ },
+}
diff --git a/open_clip/src/open_clip/hf_model.py b/open_clip/src/open_clip/hf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9f1103d6e4543c0951eefde0b41783495b3ed35
--- /dev/null
+++ b/open_clip/src/open_clip/hf_model.py
@@ -0,0 +1,164 @@
+""" huggingface model adapter
+
+Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
+"""
+
+import re
+
+import torch
+import torch.nn as nn
+from torch import TensorType
+
+try:
+ import transformers
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
+ BaseModelOutputWithPoolingAndCrossAttentions
+except ImportError as e:
+ transformers = None
+
+
+ class BaseModelOutput:
+ pass
+
+
+ class PretrainedConfig:
+ pass
+
+from .hf_configs import arch_dict
+
+
+# utils
+def _camel2snake(s):
+ return re.sub(r'(? TensorType:
+ attn_mask = (x != self.config.pad_token_id).long()
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
+ pooled_out = self.pooler(out, attn_mask)
+
+ return self.proj(pooled_out)
+
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+ if not unlocked_layers: # full freezing
+ for n, p in self.transformer.named_parameters():
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
+ return
+
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
+ embeddings = getattr(
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
+ modules = [embeddings, *layer_list][:-unlocked_layers]
+ # freeze layers
+ for module in modules:
+ for n, p in module.named_parameters():
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.transformer.gradient_checkpointing_enable()
+
+ def init_parameters(self):
+ pass
diff --git a/open_clip/src/open_clip/loss.py b/open_clip/src/open_clip/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..de31426dfa7ed40369b5461d6498008392d507e5
--- /dev/null
+++ b/open_clip/src/open_clip/loss.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+try:
+ import torch.distributed.nn
+ from torch import distributed as dist
+ has_distributed = True
+except ImportError:
+ has_distributed = False
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+
+def gather_features(
+ image_features,
+ text_features,
+ local_loss=False,
+ gather_with_grad=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False
+):
+ assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
+ if use_horovod:
+ assert hvd is not None, 'Please install horovod'
+ if gather_with_grad:
+ all_image_features = hvd.allgather(image_features)
+ all_text_features = hvd.allgather(text_features)
+ else:
+ with torch.no_grad():
+ all_image_features = hvd.allgather(image_features)
+ all_text_features = hvd.allgather(text_features)
+ if not local_loss:
+ # ensure grads for local rank when all_* features don't have a gradient
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
+ gathered_image_features[rank] = image_features
+ gathered_text_features[rank] = text_features
+ all_image_features = torch.cat(gathered_image_features, dim=0)
+ all_text_features = torch.cat(gathered_text_features, dim=0)
+ else:
+ # We gather tensors from all gpus
+ if gather_with_grad:
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+ else:
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
+ dist.all_gather(gathered_image_features, image_features)
+ dist.all_gather(gathered_text_features, text_features)
+ if not local_loss:
+ # ensure grads for local rank when all_* features don't have a gradient
+ gathered_image_features[rank] = image_features
+ gathered_text_features[rank] = text_features
+ all_image_features = torch.cat(gathered_image_features, dim=0)
+ all_text_features = torch.cat(gathered_text_features, dim=0)
+
+ return all_image_features, all_text_features
+
+
+class ClipLoss(nn.Module):
+
+ def __init__(
+ self,
+ local_loss=False,
+ gather_with_grad=False,
+ cache_labels=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False,
+ ):
+ super().__init__()
+ self.local_loss = local_loss
+ self.gather_with_grad = gather_with_grad
+ self.cache_labels = cache_labels
+ self.rank = rank
+ self.world_size = world_size
+ self.use_horovod = use_horovod
+
+ # cache state
+ self.prev_num_logits = 0
+ self.labels = {}
+
+ def forward(self, image_features, text_features, logit_scale):
+ device = image_features.device
+ if self.world_size > 1:
+ all_image_features, all_text_features = gather_features(
+ image_features, text_features,
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+
+ if self.local_loss:
+ logits_per_image = logit_scale * image_features @ all_text_features.T
+ logits_per_text = logit_scale * text_features @ all_image_features.T
+ else:
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
+ logits_per_text = logits_per_image.T
+ else:
+ logits_per_image = logit_scale * image_features @ text_features.T
+ logits_per_text = logit_scale * text_features @ image_features.T
+
+ # calculated ground-truth and cache if enabled
+ num_logits = logits_per_image.shape[0]
+ if self.prev_num_logits != num_logits or device not in self.labels:
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
+ if self.world_size > 1 and self.local_loss:
+ labels = labels + num_logits * self.rank
+ if self.cache_labels:
+ self.labels[device] = labels
+ self.prev_num_logits = num_logits
+ else:
+ labels = self.labels[device]
+
+ total_loss = (
+ F.cross_entropy(logits_per_image, labels) +
+ F.cross_entropy(logits_per_text, labels)
+ ) / 2
+ return total_loss
diff --git a/open_clip/src/open_clip/model.py b/open_clip/src/open_clip/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..038c4567d9bdb3850b1c2120d7b7f82ee22f89ae
--- /dev/null
+++ b/open_clip/src/open_clip/model.py
@@ -0,0 +1,408 @@
+""" CLIP Model
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+from dataclasses import dataclass
+import logging
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+from .hf_model import HFTextEncoder
+from .modified_resnet import ModifiedResNet
+from .timm_model import TimmModel
+from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
+from .utils import to_2tuple
+
+
+@dataclass
+class CLIPVisionCfg:
+ layers: Union[Tuple[int, int, int, int], int] = 12
+ width: int = 768
+ head_width: int = 64
+ mlp_ratio: float = 4.0
+ patch_size: int = 16
+ image_size: Union[Tuple[int, int], int] = 224
+ ls_init_value: Optional[float] = None # layer scale initial value
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
+ timm_proj_bias: bool = False # enable bias final projection
+ timm_drop: float = 0. # head dropout
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
+
+
+@dataclass
+class CLIPTextCfg:
+ context_length: int = 77
+ vocab_size: int = 49408
+ width: int = 512
+ heads: int = 8
+ layers: int = 12
+ ls_init_value: Optional[float] = None # layer scale initial value
+ hf_model_name: str = None
+ hf_tokenizer_name: str = None
+ hf_model_pretrained: bool = True
+ proj: str = 'mlp'
+ pooler_type: str = 'mean_pooler'
+
+
+def get_cast_dtype(precision: str):
+ cast_dtype = None
+ if precision == 'bf16':
+ cast_dtype = torch.bfloat16
+ elif precision == 'fp16':
+ cast_dtype = torch.float16
+ return cast_dtype
+
+
+def _build_vision_tower(
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None
+):
+ if isinstance(vision_cfg, dict):
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
+
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
+ # memory efficient in recent PyTorch releases (>= 1.10).
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+
+ if vision_cfg.timm_model_name:
+ visual = TimmModel(
+ vision_cfg.timm_model_name,
+ pretrained=vision_cfg.timm_model_pretrained,
+ pool=vision_cfg.timm_pool,
+ proj=vision_cfg.timm_proj,
+ proj_bias=vision_cfg.timm_proj_bias,
+ drop=vision_cfg.timm_drop,
+ drop_path=vision_cfg.timm_drop_path,
+ embed_dim=embed_dim,
+ image_size=vision_cfg.image_size
+ )
+ act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
+ elif isinstance(vision_cfg.layers, (tuple, list)):
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
+ visual = ModifiedResNet(
+ layers=vision_cfg.layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ image_size=vision_cfg.image_size,
+ width=vision_cfg.width
+ )
+ else:
+ vision_heads = vision_cfg.width // vision_cfg.head_width
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+ visual = VisionTransformer(
+ image_size=vision_cfg.image_size,
+ patch_size=vision_cfg.patch_size,
+ width=vision_cfg.width,
+ layers=vision_cfg.layers,
+ heads=vision_heads,
+ mlp_ratio=vision_cfg.mlp_ratio,
+ ls_init_value=vision_cfg.ls_init_value,
+ patch_dropout=vision_cfg.patch_dropout,
+ global_average_pool=vision_cfg.global_average_pool,
+ output_dim=embed_dim,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+
+ return visual
+
+
+def _build_text_tower(
+ embed_dim: int,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+):
+ if isinstance(text_cfg, dict):
+ text_cfg = CLIPTextCfg(**text_cfg)
+
+ if text_cfg.hf_model_name:
+ text = HFTextEncoder(
+ text_cfg.hf_model_name,
+ output_dim=embed_dim,
+ proj=text_cfg.proj,
+ pooler_type=text_cfg.pooler_type,
+ pretrained=text_cfg.hf_model_pretrained
+ )
+ else:
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
+
+ text = TextTransformer(
+ context_length=text_cfg.context_length,
+ vocab_size=text_cfg.vocab_size,
+ width=text_cfg.width,
+ heads=text_cfg.heads,
+ layers=text_cfg.layers,
+ ls_init_value=text_cfg.ls_init_value,
+ output_dim=embed_dim,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+ return text
+
+
+class CLIP(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+ ):
+ super().__init__()
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+
+ text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+ self.transformer = text.transformer
+ self.vocab_size = text.vocab_size
+ self.token_embedding = text.token_embedding
+ self.positional_embedding = text.positional_embedding
+ self.ln_final = text.ln_final
+ self.text_projection = text.text_projection
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
+
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.visual.set_grad_checkpointing(enable)
+ self.transformer.grad_checkpointing = enable
+
+ def encode_image(self, image, normalize: bool = False, dense=False):
+ features = self.visual(image, dense=dense)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def encode_text(self, text, normalize: bool = False):
+ cast_dtype = self.transformer.get_cast_dtype()
+
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.to(cast_dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, attn_mask=self.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+ return F.normalize(x, dim=-1) if normalize else x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image, normalize=True)
+ text_features = self.encode_text(text, normalize=True)
+ return image_features, text_features, self.logit_scale.exp()
+
+
+class CustomTextCLIP(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ vision_cfg: CLIPVisionCfg,
+ text_cfg: CLIPTextCfg,
+ quick_gelu: bool = False,
+ cast_dtype: Optional[torch.dtype] = None,
+ ):
+ super().__init__()
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+ self.text.lock(unlocked_layers, freeze_layer_norm)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.visual.set_grad_checkpointing(enable)
+ self.text.set_grad_checkpointing(enable)
+
+ def encode_image(self, image, normalize: bool = False):
+ features = self.visual(image)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def encode_text(self, text, normalize: bool = False):
+ features = self.text(text)
+ return F.normalize(features, dim=-1) if normalize else features
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image, normalize=True)
+ text_features = self.encode_text(text, normalize=True)
+ return image_features, text_features, self.logit_scale.exp()
+
+
+def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
+
+ def _convert_weights(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.to(dtype)
+ if l.bias is not None:
+ l.bias.data = l.bias.data.to(dtype)
+
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.to(dtype)
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.to(dtype)
+
+ model.apply(_convert_weights)
+
+
+convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
+
+
+# used to maintain checkpoint compatibility
+def convert_to_custom_text_state_dict(state_dict: dict):
+ if 'text_projection' in state_dict:
+ # old format state_dict, move text tower -> .text
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if any(k.startswith(p) for p in (
+ 'text_projection',
+ 'positional_embedding',
+ 'token_embedding',
+ 'transformer',
+ 'ln_final',
+ )):
+ k = 'text.' + k
+ new_state_dict[k] = v
+ return new_state_dict
+ return state_dict
+
+
+def build_model_from_openai_state_dict(
+ state_dict: dict,
+ quick_gelu=True,
+ cast_dtype=torch.float16,
+):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len(
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_size = vision_patch_size * grid_size
+ else:
+ counts: list = [
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_size = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ vision_cfg = CLIPVisionCfg(
+ layers=vision_layers,
+ width=vision_width,
+ patch_size=vision_patch_size,
+ image_size=image_size,
+ )
+ text_cfg = CLIPTextCfg(
+ context_length=context_length,
+ vocab_size=vocab_size,
+ width=transformer_width,
+ heads=transformer_heads,
+ layers=transformer_layers
+ )
+ model = CLIP(
+ embed_dim,
+ vision_cfg=vision_cfg,
+ text_cfg=text_cfg,
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
+ cast_dtype=cast_dtype,
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ state_dict.pop(key, None)
+
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
+ model.load_state_dict(state_dict)
+ return model.eval()
+
+
+def trace_model(model, batch_size=256, device=torch.device('cpu')):
+ model.eval()
+ image_size = model.visual.image_size
+ example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
+ model = torch.jit.trace_module(
+ model,
+ inputs=dict(
+ forward=(example_images, example_text),
+ encode_text=(example_text,),
+ encode_image=(example_images,)
+ ))
+ model.visual.image_size = image_size
+ return model
+
+
+def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
+ # Rescale the grid of position embeddings when loading from state_dict
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
+ return
+ grid_size = to_2tuple(model.visual.grid_size)
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+ if new_seq_len == old_pos_embed.shape[0]:
+ return
+
+ if extra_tokens:
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+ else:
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+ pos_emb_img = F.interpolate(
+ pos_emb_img,
+ size=grid_size,
+ mode=interpolation,
+ antialias=antialias,
+ align_corners=False,
+ )
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+ if pos_emb_tok is not None:
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+ else:
+ new_pos_embed = pos_emb_img
+ state_dict['visual.positional_embedding'] = new_pos_embed
diff --git a/open_clip/src/open_clip/model_configs/RN101-quickgelu.json b/open_clip/src/open_clip/model_configs/RN101-quickgelu.json
new file mode 100644
index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/RN101-quickgelu.json
@@ -0,0 +1,22 @@
+{
+ "embed_dim": 512,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 23,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/RN101.json b/open_clip/src/open_clip/model_configs/RN101.json
new file mode 100644
index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/RN101.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 23,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/RN50-quickgelu.json b/open_clip/src/open_clip/model_configs/RN50-quickgelu.json
new file mode 100644
index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/RN50-quickgelu.json
@@ -0,0 +1,22 @@
+{
+ "embed_dim": 1024,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 6,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
diff --git a/open_clip/src/open_clip/model_configs/RN50.json b/open_clip/src/open_clip/model_configs/RN50.json
new file mode 100644
index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/RN50.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 6,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/RN50x16.json b/open_clip/src/open_clip/model_configs/RN50x16.json
new file mode 100644
index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/RN50x16.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 384,
+ "layers": [
+ 6,
+ 8,
+ 18,
+ 8
+ ],
+ "width": 96,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/RN50x4.json b/open_clip/src/open_clip/model_configs/RN50x4.json
new file mode 100644
index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/RN50x4.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 640,
+ "vision_cfg": {
+ "image_size": 288,
+ "layers": [
+ 4,
+ 6,
+ 10,
+ 6
+ ],
+ "width": 80,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 640,
+ "heads": 10,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/RN50x64.json b/open_clip/src/open_clip/model_configs/RN50x64.json
new file mode 100644
index 0000000000000000000000000000000000000000..f5aaa2ee3de21ddb03cbd12766a3419bf34898c7
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/RN50x64.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 448,
+ "layers": [
+ 3,
+ 15,
+ 36,
+ 10
+ ],
+ "width": 128,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1024,
+ "heads": 16,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-B-16-plus-240.json b/open_clip/src/open_clip/model_configs/ViT-B-16-plus-240.json
new file mode 100644
index 0000000000000000000000000000000000000000..5bbd12bcd01f64d6d0a0aa8316b129327a0d169a
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-B-16-plus-240.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 640,
+ "vision_cfg": {
+ "image_size": 240,
+ "layers": 12,
+ "width": 896,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 640,
+ "heads": 10,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-B-16-plus.json b/open_clip/src/open_clip/model_configs/ViT-B-16-plus.json
new file mode 100644
index 0000000000000000000000000000000000000000..5dc1e09baccef2b15055c1bffeb9903e760101c6
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-B-16-plus.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 640,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 896,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 640,
+ "heads": 10,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-B-16.json b/open_clip/src/open_clip/model_configs/ViT-B-16.json
new file mode 100644
index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-B-16.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-B-32-plus-256.json b/open_clip/src/open_clip/model_configs/ViT-B-32-plus-256.json
new file mode 100644
index 0000000000000000000000000000000000000000..2f09c857de9a4c01ae51297a7e2451984879f9de
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-B-32-plus-256.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 640,
+ "vision_cfg": {
+ "image_size": 256,
+ "layers": 12,
+ "width": 896,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 640,
+ "heads": 10,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-B-32-quickgelu.json b/open_clip/src/open_clip/model_configs/ViT-B-32-quickgelu.json
new file mode 100644
index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-B-32-quickgelu.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 512,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-B-32.json b/open_clip/src/open_clip/model_configs/ViT-B-32.json
new file mode 100644
index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-B-32.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-H-14.json b/open_clip/src/open_clip/model_configs/ViT-H-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..3e3a7e934e7f02e41f4829996c4950e05f015a74
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-H-14.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 32,
+ "width": 1280,
+ "head_width": 80,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1024,
+ "heads": 16,
+ "layers": 24
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-H-16.json b/open_clip/src/open_clip/model_configs/ViT-H-16.json
new file mode 100644
index 0000000000000000000000000000000000000000..588485455fdf8193ec16474450b94e31c91ea93c
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-H-16.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 32,
+ "width": 1280,
+ "head_width": 80,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1024,
+ "heads": 16,
+ "layers": 24
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-L-14-280.json b/open_clip/src/open_clip/model_configs/ViT-L-14-280.json
new file mode 100644
index 0000000000000000000000000000000000000000..2262deaefa82792d35d73c0d7c8e620525092581
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-L-14-280.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 280,
+ "layers": 24,
+ "width": 1024,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-L-14-336.json b/open_clip/src/open_clip/model_configs/ViT-L-14-336.json
new file mode 100644
index 0000000000000000000000000000000000000000..8d1f74c2639c3a3705df9865b9c08215675ddc97
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-L-14-336.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 336,
+ "layers": 24,
+ "width": 1024,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-L-14.json b/open_clip/src/open_clip/model_configs/ViT-L-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-L-14.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 24,
+ "width": 1024,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-L-16-320.json b/open_clip/src/open_clip/model_configs/ViT-L-16-320.json
new file mode 100644
index 0000000000000000000000000000000000000000..fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-L-16-320.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 320,
+ "layers": 24,
+ "width": 1024,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-L-16.json b/open_clip/src/open_clip/model_configs/ViT-L-16.json
new file mode 100644
index 0000000000000000000000000000000000000000..82a1cedfa290adacbbdc02bc5d589734c22d41d3
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-L-16.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 24,
+ "width": 1024,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-M-16-alt.json b/open_clip/src/open_clip/model_configs/ViT-M-16-alt.json
new file mode 100644
index 0000000000000000000000000000000000000000..1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-M-16-alt.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 384,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 512,
+ "patch_size": 16,
+ "ls_init_value": 1e-4
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 384,
+ "heads": 6,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-M-16.json b/open_clip/src/open_clip/model_configs/ViT-M-16.json
new file mode 100644
index 0000000000000000000000000000000000000000..f2f3225a46e09237730a151d161f70c86b985172
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-M-16.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 512,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-M-32-alt.json b/open_clip/src/open_clip/model_configs/ViT-M-32-alt.json
new file mode 100644
index 0000000000000000000000000000000000000000..fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-M-32-alt.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 384,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 512,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 384,
+ "heads": 6,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-M-32.json b/open_clip/src/open_clip/model_configs/ViT-M-32.json
new file mode 100644
index 0000000000000000000000000000000000000000..4f718642821035d9776d1e006817d65ede074366
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-M-32.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 512,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-S-16-alt.json b/open_clip/src/open_clip/model_configs/ViT-S-16-alt.json
new file mode 100644
index 0000000000000000000000000000000000000000..a8c056555e4da3ba0d1475a61fc316362ecce76f
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-S-16-alt.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 256,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 384,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 256,
+ "heads": 4,
+ "layers": 10
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-S-16.json b/open_clip/src/open_clip/model_configs/ViT-S-16.json
new file mode 100644
index 0000000000000000000000000000000000000000..1d8504e59658803f3093e5b05de45f30a09b8185
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-S-16.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 384,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 384,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 384,
+ "heads": 6,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-S-32-alt.json b/open_clip/src/open_clip/model_configs/ViT-S-32-alt.json
new file mode 100644
index 0000000000000000000000000000000000000000..e1dfdec9824df09a2010e991ccfa1d9ee2f45807
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-S-32-alt.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 256,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 384,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 256,
+ "heads": 4,
+ "layers": 10
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-S-32.json b/open_clip/src/open_clip/model_configs/ViT-S-32.json
new file mode 100644
index 0000000000000000000000000000000000000000..9b8b4191b268de267268cfcb90fc01c6b9df07d8
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-S-32.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 384,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 384,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 384,
+ "heads": 6,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-bigG-14.json b/open_clip/src/open_clip/model_configs/ViT-bigG-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..2cfba479a2e8f3737e71ce240732bf3bc743d8b7
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-bigG-14.json
@@ -0,0 +1,18 @@
+{
+ "embed_dim": 1280,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 48,
+ "width": 1664,
+ "head_width": 104,
+ "mlp_ratio": 4.9231,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1280,
+ "heads": 20,
+ "layers": 32
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-e-14.json b/open_clip/src/open_clip/model_configs/ViT-e-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..91a0fe14d25a107fb8ec48dd7faae313fd26ed7b
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-e-14.json
@@ -0,0 +1,18 @@
+{
+ "embed_dim": 1280,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 56,
+ "width": 1792,
+ "head_width": 112,
+ "mlp_ratio": 8.5715,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1280,
+ "heads": 20,
+ "layers": 36
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/ViT-g-14.json b/open_clip/src/open_clip/model_configs/ViT-g-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..8c4b7325cc75b6112be7107d36ae2cb5762d9091
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/ViT-g-14.json
@@ -0,0 +1,18 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 40,
+ "width": 1408,
+ "head_width": 88,
+ "mlp_ratio": 4.3637,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1024,
+ "heads": 16,
+ "layers": 24
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/convnext_base.json b/open_clip/src/open_clip/model_configs/convnext_base.json
new file mode 100644
index 0000000000000000000000000000000000000000..4de9aa8a320f426ebc6e1b24edcf61b2b6a318c9
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/convnext_base.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "timm_model_name": "convnext_base",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 224
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/convnext_base_w.json b/open_clip/src/open_clip/model_configs/convnext_base_w.json
new file mode 100644
index 0000000000000000000000000000000000000000..68e74e783d4cf82e8bfd9eb04cd423498ced92fd
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/convnext_base_w.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 640,
+ "vision_cfg": {
+ "timm_model_name": "convnext_base",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 256
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 640,
+ "heads": 10,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/convnext_base_w_320.json b/open_clip/src/open_clip/model_configs/convnext_base_w_320.json
new file mode 100644
index 0000000000000000000000000000000000000000..3b1f7f0c1e3168cf43f496b2a7c2ddba68b6832c
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/convnext_base_w_320.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 640,
+ "vision_cfg": {
+ "timm_model_name": "convnext_base",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 320
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 640,
+ "heads": 10,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/convnext_large.json b/open_clip/src/open_clip/model_configs/convnext_large.json
new file mode 100644
index 0000000000000000000000000000000000000000..72341b9a719114dca5f19d7ac8ce874ea5ba273e
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/convnext_large.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "timm_model_name": "convnext_large",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 224
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/convnext_large_d.json b/open_clip/src/open_clip/model_configs/convnext_large_d.json
new file mode 100644
index 0000000000000000000000000000000000000000..8d171f7d996e1bbbad6f2f987f7087c5c122fd23
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/convnext_large_d.json
@@ -0,0 +1,19 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "timm_model_name": "convnext_large",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "mlp",
+ "timm_drop": 0.1,
+ "timm_drop_path": 0.1,
+ "image_size": 256
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 16
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/convnext_small.json b/open_clip/src/open_clip/model_configs/convnext_small.json
new file mode 100644
index 0000000000000000000000000000000000000000..d158c569370fa60add8b4a9031c26aae79ee105b
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/convnext_small.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "timm_model_name": "convnext_small",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 224
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/convnext_tiny.json b/open_clip/src/open_clip/model_configs/convnext_tiny.json
new file mode 100644
index 0000000000000000000000000000000000000000..48d89584f98d068baec724263d900c47f1621d55
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/convnext_tiny.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "timm_model_name": "convnext_tiny",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 224
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/convnext_xlarge.json b/open_clip/src/open_clip/model_configs/convnext_xlarge.json
new file mode 100644
index 0000000000000000000000000000000000000000..5186dca08b170d17d8bd19f6c0572fc5923b1391
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/convnext_xlarge.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "timm_model_name": "convnext_xlarge",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 224
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1024,
+ "heads": 16,
+ "layers": 16
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/convnext_xxlarge.json b/open_clip/src/open_clip/model_configs/convnext_xxlarge.json
new file mode 100644
index 0000000000000000000000000000000000000000..5943a654b3f54badb1b077ea083ea7eec9d6f7a3
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/convnext_xxlarge.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "timm_model_name": "convnext_xxlarge",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 256
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1024,
+ "heads": 16,
+ "layers": 24
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/convnext_xxlarge_320.json b/open_clip/src/open_clip/model_configs/convnext_xxlarge_320.json
new file mode 100644
index 0000000000000000000000000000000000000000..bdc2287dcf414f9626f0841419e233af7e18d71b
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/convnext_xxlarge_320.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "timm_model_name": "convnext_xxlarge",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 320
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 1024,
+ "heads": 16,
+ "layers": 24
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/mt5-base-ViT-B-32.json b/open_clip/src/open_clip/model_configs/mt5-base-ViT-B-32.json
new file mode 100644
index 0000000000000000000000000000000000000000..58cad89cf0f446bbe15e4e25b1ac43424a828017
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/mt5-base-ViT-B-32.json
@@ -0,0 +1,15 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "hf_model_name": "google/mt5-base",
+ "hf_tokenizer_name": "google/mt5-base",
+ "proj": "mlp",
+ "pooler_type": "mean_pooler"
+ }
+}
diff --git a/open_clip/src/open_clip/model_configs/mt5-xl-ViT-H-14.json b/open_clip/src/open_clip/model_configs/mt5-xl-ViT-H-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..b432810777ba7269dbb0e89edfe65cdd27e7d255
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/mt5-xl-ViT-H-14.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 32,
+ "width": 1280,
+ "head_width": 80,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "hf_model_name": "google/mt5-xl",
+ "hf_tokenizer_name": "google/mt5-xl",
+ "proj": "mlp",
+ "pooler_type": "mean_pooler"
+ }
+}
diff --git a/open_clip/src/open_clip/model_configs/roberta-ViT-B-32.json b/open_clip/src/open_clip/model_configs/roberta-ViT-B-32.json
new file mode 100644
index 0000000000000000000000000000000000000000..ed687d472a73bb2ac96025f355f80437ab14c260
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/roberta-ViT-B-32.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "hf_model_name": "roberta-base",
+ "hf_tokenizer_name": "roberta-base",
+ "proj": "mlp",
+ "pooler_type": "mean_pooler"
+ }
+}
diff --git a/open_clip/src/open_clip/model_configs/swin_base_patch4_window7_224.json b/open_clip/src/open_clip/model_configs/swin_base_patch4_window7_224.json
new file mode 100644
index 0000000000000000000000000000000000000000..bd6820f0cf2aa655e0a2723287f4b78895a58e6a
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/swin_base_patch4_window7_224.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 640,
+ "vision_cfg": {
+ "timm_model_name": "swin_base_patch4_window7_224",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 224
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 640,
+ "heads": 10,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/vit_medium_patch16_gap_256.json b/open_clip/src/open_clip/model_configs/vit_medium_patch16_gap_256.json
new file mode 100644
index 0000000000000000000000000000000000000000..8843eaf08cad16c3e7b5f496fd650715c9573f65
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/vit_medium_patch16_gap_256.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "timm_model_name": "vit_medium_patch16_gap_256",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 256
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json b/open_clip/src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json
new file mode 100644
index 0000000000000000000000000000000000000000..ed217b202d5e6071c5307f4547c97ff4cfe2abd1
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "timm_model_name": "vit_relpos_medium_patch16_cls_224",
+ "timm_model_pretrained": false,
+ "timm_pool": "",
+ "timm_proj": "linear",
+ "image_size": 224
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/open_clip/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json b/open_clip/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json
new file mode 100644
index 0000000000000000000000000000000000000000..751bccc2c6fc41bc4ff20182de88d86739d518d9
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json
@@ -0,0 +1,15 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "hf_model_name": "xlm-roberta-base",
+ "hf_tokenizer_name": "xlm-roberta-base",
+ "proj": "mlp",
+ "pooler_type": "mean_pooler"
+ }
+}
diff --git a/open_clip/src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json b/open_clip/src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..31f271faa9bbb7a9da53900b483a4c00a16f3c4a
--- /dev/null
+++ b/open_clip/src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 32,
+ "width": 1280,
+ "head_width": 80,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "hf_model_name": "xlm-roberta-large",
+ "hf_tokenizer_name": "xlm-roberta-large",
+ "proj": "mlp",
+ "pooler_type": "mean_pooler"
+ }
+}
diff --git a/open_clip/src/open_clip/modified_resnet.py b/open_clip/src/open_clip/modified_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7c0b033a80e7d08a20a367050c5b1bc5d5292e7
--- /dev/null
+++ b/open_clip/src/open_clip/modified_resnet.py
@@ -0,0 +1,181 @@
+from collections import OrderedDict
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from open_clip.utils import freeze_batch_norm_2d
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.act1 = nn.ReLU(inplace=True)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.act2 = nn.ReLU(inplace=True)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.act3 = nn.ReLU(inplace=True)
+
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.act1(self.bn1(self.conv1(x)))
+ out = self.act2(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.act3(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0.,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.image_size = image_size
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.act1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.act2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.act3 = nn.ReLU(inplace=True)
+ self.avgpool = nn.AvgPool2d(2)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
+
+ self.init_parameters()
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def init_parameters(self):
+ if self.attnpool is not None:
+ std = self.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
+ for param in self.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ # FIXME support for non-transformer
+ pass
+
+ def stem(self, x):
+ x = self.act1(self.bn1(self.conv1(x)))
+ x = self.act2(self.bn2(self.conv2(x)))
+ x = self.act3(self.bn3(self.conv3(x)))
+ x = self.avgpool(x)
+ return x
+
+ def forward(self, x):
+ x = self.stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
diff --git a/open_clip/src/open_clip/openai.py b/open_clip/src/open_clip/openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc4e13e876d6a7a3463b457e62c517cb063b1356
--- /dev/null
+++ b/open_clip/src/open_clip/openai.py
@@ -0,0 +1,144 @@
+""" OpenAI pretrained model functions
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+
+import os
+import warnings
+from typing import List, Optional, Union
+
+import torch
+
+from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
+from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
+
+__all__ = ["list_openai_models", "load_openai_model"]
+
+
+def list_openai_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list_pretrained_models_by_tag('openai')
+
+
+def load_openai_model(
+ name: str,
+ precision: Optional[str] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ jit: bool = True,
+ cache_dir: Optional[str] = None,
+):
+ """Load a CLIP model
+
+ Parameters
+ ----------
+ name : str
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+ precision: str
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
+ device : Union[str, torch.device]
+ The device to put the loaded model
+ jit : bool
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+ cache_dir : Optional[str]
+ The directory to cache the downloaded model weights
+
+ Returns
+ -------
+ model : torch.nn.Module
+ The CLIP model
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if precision is None:
+ precision = 'fp32' if device == 'cpu' else 'fp16'
+
+ if get_pretrained_url(name, 'openai'):
+ model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ # Build a non-jit model from the OpenAI jitted model state dict
+ cast_dtype = get_cast_dtype(precision)
+ try:
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
+ except KeyError:
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
+
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
+ model = model.to(device)
+ if precision.startswith('amp') or precision == 'fp32':
+ model.float()
+ elif precision == 'bf16':
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
+
+ return model
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 (typically for CPU)
+ if precision == 'fp32':
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+ model.float()
+
+ # ensure image_size attr available at consistent location for both jit and non-jit
+ model.visual.image_size = model.input_resolution.item()
+ return model
diff --git a/open_clip/src/open_clip/pretrained.py b/open_clip/src/open_clip/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..73643f95dced25c0a3c82d439bbea47f495aafd1
--- /dev/null
+++ b/open_clip/src/open_clip/pretrained.py
@@ -0,0 +1,345 @@
+import hashlib
+import os
+import urllib
+import warnings
+from functools import partial
+from typing import Dict, Union
+
+from tqdm import tqdm
+
+from .version import __version__
+
+try:
+ from huggingface_hub import hf_hub_download
+ hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
+ _has_hf_hub = True
+except ImportError:
+ hf_hub_download = None
+ _has_hf_hub = False
+
+
+def _pcfg(url='', hf_hub='', mean=None, std=None):
+ return dict(
+ url=url,
+ hf_hub=hf_hub,
+ mean=mean,
+ std=std,
+ )
+
+
+_RN50 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
+ cc12m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
+)
+
+_RN50_quickgelu = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
+ cc12m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
+)
+
+_RN101 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
+)
+
+_RN101_quickgelu = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
+ yfcc15m=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
+)
+
+_RN50x4 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
+)
+
+_RN50x16 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
+)
+
+_RN50x64 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
+)
+
+_VITB32 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+ laion2b_e16=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
+ laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
+)
+
+_VITB32_quickgelu = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+)
+
+_VITB16 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
+ # laion400m_32k=_pcfg(
+ # url="",
+ # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ # laion400m_64k=_pcfg(
+ # url="",
+ # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
+)
+
+_VITB16_PLUS_240 = dict(
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
+)
+
+_VITL14 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
+ laion400m_e31=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
+ laion400m_e32=_pcfg(
+ "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
+ laion2b_s32b_b82k=_pcfg(
+ hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+)
+
+_VITL14_336 = dict(
+ openai=_pcfg(
+ "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
+)
+
+_VITH14 = dict(
+ laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
+)
+
+_VITg14 = dict(
+ laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
+)
+
+_VITbigG14 = dict(
+ laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
+)
+
+_robertaViTB32 = dict(
+ laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
+)
+
+_xlmRobertaBaseViTB32 = dict(
+ laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
+)
+
+_xlmRobertaLargeFrozenViTH14 = dict(
+ frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
+)
+
+_convnext_base = dict(
+ laion400m_s13b_b51k=_pcfg(hf_hub='convnext_base-laion400M-s13B-b51K'),
+)
+
+_convnext_base_w = dict(
+ laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
+ laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
+ laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
+)
+
+_convnext_base_w_320 = dict(
+ laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
+ laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
+)
+
+
+_PRETRAINED = {
+ "RN50": _RN50,
+ "RN50-quickgelu": _RN50_quickgelu,
+ "RN101": _RN101,
+ "RN101-quickgelu": _RN101_quickgelu,
+ "RN50x4": _RN50x4,
+ "RN50x16": _RN50x16,
+ "RN50x64": _RN50x64,
+ "ViT-B-32": _VITB32,
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
+ "ViT-B-16": _VITB16,
+ "ViT-B-16-plus-240": _VITB16_PLUS_240,
+ "ViT-L-14": _VITL14,
+ "ViT-L-14-336": _VITL14_336,
+ "ViT-H-14": _VITH14,
+ "ViT-g-14": _VITg14,
+ "ViT-bigG-14": _VITbigG14,
+ "roberta-ViT-B-32": _robertaViTB32,
+ "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
+ "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
+ "convnext_base": _convnext_base,
+ "convnext_base_w": _convnext_base_w,
+ "convnext_base_w_320": _convnext_base_w_320,
+}
+
+
+def _clean_tag(tag: str):
+ # normalize pretrained tags
+ return tag.lower().replace('-', '_')
+
+
+def list_pretrained(as_str: bool = False):
+ """ returns list of pretrained models
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
+ """
+ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
+
+
+def list_pretrained_models_by_tag(tag: str):
+ """ return all models having the specified pretrain tag """
+ models = []
+ tag = _clean_tag(tag)
+ for k in _PRETRAINED.keys():
+ if tag in _PRETRAINED[k]:
+ models.append(k)
+ return models
+
+
+def list_pretrained_tags_by_model(model: str):
+ """ return all pretrain tags for the specified model architecture """
+ tags = []
+ if model in _PRETRAINED:
+ tags.extend(_PRETRAINED[model].keys())
+ return tags
+
+
+def is_pretrained_cfg(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return False
+ return _clean_tag(tag) in _PRETRAINED[model]
+
+
+def get_pretrained_cfg(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return {}
+ model_pretrained = _PRETRAINED[model]
+ return model_pretrained.get(_clean_tag(tag), {})
+
+
+def get_pretrained_url(model: str, tag: str):
+ cfg = get_pretrained_cfg(model, _clean_tag(tag))
+ return cfg.get('url', '')
+
+
+def download_pretrained_from_url(
+ url: str,
+ cache_dir: Union[str, None] = None,
+):
+ if not cache_dir:
+ cache_dir = os.path.expanduser("~/.cache/clip")
+ os.makedirs(cache_dir, exist_ok=True)
+ filename = os.path.basename(url)
+
+ if 'openaipublic' in url:
+ expected_sha256 = url.split("/")[-2]
+ elif 'mlfoundations' in url:
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
+ else:
+ expected_sha256 = ''
+
+ download_target = os.path.join(cache_dir, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if expected_sha256:
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+ else:
+ return download_target
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def has_hf_hub(necessary=False):
+ if not _has_hf_hub and necessary:
+ # if no HF Hub module installed, and it is necessary to continue, raise error
+ raise RuntimeError(
+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
+ return _has_hf_hub
+
+
+def download_pretrained_from_hf(
+ model_id: str,
+ filename: str = 'open_clip_pytorch_model.bin',
+ revision=None,
+ cache_dir: Union[str, None] = None,
+):
+ has_hf_hub(True)
+ cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
+ return cached_file
+
+
+def download_pretrained(
+ cfg: Dict,
+ force_hf_hub: bool = False,
+ cache_dir: Union[str, None] = None,
+):
+ target = ''
+ if not cfg:
+ return target
+
+ download_url = cfg.get('url', '')
+ download_hf_hub = cfg.get('hf_hub', '')
+ if download_hf_hub and force_hf_hub:
+ # use HF hub even if url exists
+ download_url = ''
+
+ if download_url:
+ target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
+ elif download_hf_hub:
+ has_hf_hub(True)
+ # we assume the hf_hub entries in pretrained config combine model_id + filename in
+ # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
+ # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
+ model_id, filename = os.path.split(download_hf_hub)
+ if filename:
+ target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
+ else:
+ target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
+
+ return target
diff --git a/open_clip/src/open_clip/timm_model.py b/open_clip/src/open_clip/timm_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc71a693f9a42ec01fd88d307661bc382b4d05bc
--- /dev/null
+++ b/open_clip/src/open_clip/timm_model.py
@@ -0,0 +1,127 @@
+""" timm model adapter
+
+Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
+"""
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+try:
+ import timm
+ from timm.models.layers import Mlp, to_2tuple
+ try:
+ # old timm imports < 0.8.1
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
+ except ImportError:
+ # new timm imports >= 0.8.1
+ from timm.layers import RotAttentionPool2d
+ from timm.layers import AttentionPool2d as AbsAttentionPool2d
+except ImportError:
+ timm = None
+
+from .utils import freeze_batch_norm_2d
+
+
+class TimmModel(nn.Module):
+ """ timm model adapter
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
+ """
+
+ def __init__(
+ self,
+ model_name,
+ embed_dim,
+ image_size=224,
+ pool='avg',
+ proj='linear',
+ proj_bias=False,
+ drop=0.,
+ drop_path=None,
+ pretrained=False,
+ ):
+ super().__init__()
+ if timm is None:
+ raise RuntimeError("Please `pip install timm` to use timm models.")
+
+ self.image_size = to_2tuple(image_size)
+ timm_kwargs = {}
+ if drop_path is not None:
+ timm_kwargs['drop_path_rate'] = drop_path
+ self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
+ feat_size = self.trunk.default_cfg.get('pool_size', None)
+ feature_ndim = 1 if not feat_size else 2
+ if pool in ('abs_attn', 'rot_attn'):
+ assert feature_ndim == 2
+ # if attn pooling used, remove both classifier and default pool
+ self.trunk.reset_classifier(0, global_pool='')
+ else:
+ # reset global pool if pool config set, otherwise leave as network default
+ reset_kwargs = dict(global_pool=pool) if pool else {}
+ self.trunk.reset_classifier(0, **reset_kwargs)
+ prev_chs = self.trunk.num_features
+
+ head_layers = OrderedDict()
+ if pool == 'abs_attn':
+ head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
+ prev_chs = embed_dim
+ elif pool == 'rot_attn':
+ head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
+ prev_chs = embed_dim
+ else:
+ assert proj, 'projection layer needed if non-attention pooling is used.'
+
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
+ if proj == 'linear':
+ head_layers['drop'] = nn.Dropout(drop)
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
+ elif proj == 'mlp':
+ head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
+
+ self.head = nn.Sequential(head_layers)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ """ lock modules
+ Args:
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
+ """
+ if not unlocked_groups:
+ # lock full model
+ for param in self.trunk.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self.trunk)
+ else:
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
+ try:
+ # FIXME import here until API stable and in an official release
+ from timm.models.helpers import group_parameters, group_modules
+ except ImportError:
+ raise RuntimeError(
+ 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
+ matcher = self.trunk.group_matcher()
+ gparams = group_parameters(self.trunk, matcher)
+ max_layer_id = max(gparams.keys())
+ max_layer_id = max_layer_id - unlocked_groups
+ for group_idx in range(max_layer_id + 1):
+ group = gparams[group_idx]
+ for param in group:
+ self.trunk.get_parameter(param).requires_grad = False
+ if freeze_bn_stats:
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
+ freeze_batch_norm_2d(self.trunk, gmodules)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ try:
+ self.trunk.set_grad_checkpointing(enable)
+ except Exception as e:
+ logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
+
+ def forward(self, x):
+ x = self.trunk(x)
+ x = self.head(x)
+ return x
diff --git a/open_clip/src/open_clip/tokenizer.py b/open_clip/src/open_clip/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..01e9f9d25574cfe757bc43a0ff0d982f5a4efad3
--- /dev/null
+++ b/open_clip/src/open_clip/tokenizer.py
@@ -0,0 +1,201 @@
+""" CLIP tokenizer
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import Union, List
+
+import ftfy
+import regex as re
+import torch
+
+# https://stackoverflow.com/q/62691279
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ if not special_tokens:
+ special_tokens = ['', '']
+ else:
+ special_tokens = ['', ''] + special_tokens
+ vocab.extend(special_tokens)
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {t:t for t in special_tokens}
+ special = "|".join(special_tokens)
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ self.vocab_size = len(self.encoder)
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+''
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
+ return text
+
+
+_tokenizer = SimpleTokenizer()
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder[""]
+ eot_token = _tokenizer.encoder[""]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length] # Truncate
+ tokens[-1] = eot_token
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
+
+
+class HFTokenizer:
+ "HuggingFace tokenizer wrapper"
+ def __init__(self, tokenizer_name:str):
+ from transformers import AutoTokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+ def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
+ # same cleaning as for default tokenizer, except lowercasing
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
+ if isinstance(texts, str):
+ texts = [texts]
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
+ input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
+ return input_ids
diff --git a/open_clip/src/open_clip/transform.py b/open_clip/src/open_clip/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..0224a0dae4a89fd4bf46c5d27a2bb9377dfa06d4
--- /dev/null
+++ b/open_clip/src/open_clip/transform.py
@@ -0,0 +1,133 @@
+import warnings
+from dataclasses import dataclass, asdict
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as F
+
+from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
+ CenterCrop
+
+from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+
+
+@dataclass
+class AugmentationCfg:
+ scale: Tuple[float, float] = (0.9, 1.0)
+ ratio: Optional[Tuple[float, float]] = None
+ color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
+ interpolation: Optional[str] = None
+ re_prob: Optional[float] = None
+ re_count: Optional[int] = None
+ use_timm: bool = False
+
+
+class ResizeMaxSize(nn.Module):
+
+ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
+ super().__init__()
+ if not isinstance(max_size, int):
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
+ self.max_size = max_size
+ self.interpolation = interpolation
+ self.fn = min if fn == 'min' else min
+ self.fill = fill
+
+ def forward(self, img):
+ if isinstance(img, torch.Tensor):
+ height, width = img.shape[:2]
+ else:
+ width, height = img.size
+ scale = self.max_size / float(max(height, width))
+ if scale != 1.0:
+ new_size = tuple(round(dim * scale) for dim in (height, width))
+ img = F.resize(img, new_size, self.interpolation)
+ pad_h = self.max_size - new_size[0]
+ pad_w = self.max_size - new_size[1]
+ img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
+ return img
+
+
+def _convert_to_rgb(image):
+ return image.convert('RGB')
+
+
+def image_transform(
+ image_size: int,
+ is_train: bool,
+ mean: Optional[Tuple[float, ...]] = None,
+ std: Optional[Tuple[float, ...]] = None,
+ resize_longest_max: bool = False,
+ fill_color: int = 0,
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+):
+ mean = mean or OPENAI_DATASET_MEAN
+ if not isinstance(mean, (list, tuple)):
+ mean = (mean,) * 3
+
+ std = std or OPENAI_DATASET_STD
+ if not isinstance(std, (list, tuple)):
+ std = (std,) * 3
+
+ if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
+ # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
+ image_size = image_size[0]
+
+ if isinstance(aug_cfg, dict):
+ aug_cfg = AugmentationCfg(**aug_cfg)
+ else:
+ aug_cfg = aug_cfg or AugmentationCfg()
+ normalize = Normalize(mean=mean, std=std)
+ if is_train:
+ aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
+ use_timm = aug_cfg_dict.pop('use_timm', False)
+ if use_timm:
+ from timm.data import create_transform # timm can still be optional
+ if isinstance(model.visual.image_size, (tuple, list)):
+ assert len(model.visual.image_size) >= 2
+ input_size = (3,) + model.visual.image_size[-2:]
+ else:
+ input_size = (3, model.visual.image_size, model.visual.image_size)
+ # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
+ aug_cfg_dict.setdefault('interpolation', 'random')
+ aug_cfg_dict.setdefault('color_jitter', None) # disable by default
+ train_transform = create_transform(
+ input_size=input_size,
+ is_training=True,
+ hflip=0.,
+ mean=image_mean,
+ std=image_std,
+ re_mode='pixel',
+ **aug_cfg_dict,
+ )
+ else:
+ train_transform = Compose([
+ RandomResizedCrop(
+ image_size,
+ scale=aug_cfg_dict.pop('scale'),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ _convert_to_rgb,
+ ToTensor(),
+ normalize,
+ ])
+ if aug_cfg_dict:
+ warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
+ return train_transform
+ else:
+ if resize_longest_max:
+ transforms = [
+ ResizeMaxSize(image_size, fill=fill_color)
+ ]
+ else:
+ transforms = [
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+ CenterCrop(image_size),
+ ]
+ transforms.extend([
+ _convert_to_rgb,
+ ToTensor(),
+ normalize,
+ ])
+ return Compose(transforms)
diff --git a/open_clip/src/open_clip/transformer.py b/open_clip/src/open_clip/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb5416ecf3e8297f7d890832cea5b10293bd86d2
--- /dev/null
+++ b/open_clip/src/open_clip/transformer.py
@@ -0,0 +1,508 @@
+from collections import OrderedDict
+import math
+from typing import Callable, Optional, Sequence
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint
+
+from .utils import to_2tuple
+
+
+class LayerNormFp32(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x):
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchDropout(nn.Module):
+ """
+ https://arxiv.org/abs/2212.00794
+ """
+
+ def __init__(self, prob, exclude_first_token=True):
+ super().__init__()
+ assert 0 <= prob < 1.
+ self.prob = prob
+ self.exclude_first_token = exclude_first_token # exclude CLS token
+
+ def forward(self, x):
+ if not self.training or self.prob == 0.:
+ return x
+
+ if self.exclude_first_token:
+ cls_tokens, x = x[:, :1], x[:, 1:]
+ else:
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
+
+ batch = x.size()[0]
+ num_tokens = x.size()[1]
+
+ batch_indices = torch.arange(batch)
+ batch_indices = batch_indices[..., None]
+
+ keep_prob = 1 - self.prob
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
+
+ rand = torch.randn(batch, num_tokens)
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
+
+ x = x[batch_indices, patch_indices_keep]
+
+ if self.exclude_first_token:
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=True,
+ scaled_cosine=False,
+ scale_heads=False,
+ logit_scale_max=math.log(1. / 0.01),
+ attn_drop=0.,
+ proj_drop=0.
+ ):
+ super().__init__()
+ self.scaled_cosine = scaled_cosine
+ self.scale_heads = scale_heads
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+ self.logit_scale_max = logit_scale_max
+
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
+ self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
+ if qkv_bias:
+ self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
+ else:
+ self.in_proj_bias = None
+
+ if self.scaled_cosine:
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
+ else:
+ self.logit_scale = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ if self.scale_heads:
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
+ else:
+ self.head_scale = None
+ self.out_proj = nn.Linear(dim, dim)
+ self.out_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
+ L, N, C = x.shape
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
+
+ if self.logit_scale is not None:
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
+ attn = attn.view(-1, L, L)
+ else:
+ q = q * self.scale
+ attn = torch.bmm(q, k.transpose(-1, -2))
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
+ attn_mask = new_attn_mask
+ attn += attn_mask
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = torch.bmm(attn, v)
+ if self.head_scale is not None:
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
+ x = x.view(-1, L, C)
+ x = x.transpose(0, 1).reshape(L, N, C)
+ x = self.out_proj(x)
+ x = self.out_drop(x)
+ return x
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model))
+ ]))
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
+ return x
+
+ def forward_dense(self, x):
+ y = self.ln_1(x)
+ y = F.linear(y, self.attn.in_proj_weight, self.attn.in_proj_bias)
+ L, N, D = y.shape # L N 3D
+
+ y = y.reshape(L, N, 3, D // 3).permute(2, 1, 0, 3).reshape(3 * N, L, D // 3)
+ y = F.linear(y, self.attn.out_proj.weight, self.attn.out_proj.bias)
+
+ q, k, v = y.tensor_split(3, dim=0)
+ #v = v.transpose(1, 0) + x # L N D
+ v = v.transpose(1, 0) + x[:1] # L N D
+
+ v = v + self.mlp(self.ln_2(v))
+
+ return v
+
+
+class CustomResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ n_head: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ scale_cosine_attn: bool = False,
+ scale_heads: bool = False,
+ scale_attn: bool = False,
+ scale_fc: bool = False,
+ ):
+ super().__init__()
+
+ self.ln_1 = norm_layer(d_model)
+ self.attn = Attention(
+ d_model, n_head,
+ scaled_cosine=scale_cosine_attn,
+ scale_heads=scale_heads,
+ )
+ self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ self.ln_2 = norm_layer(d_model)
+ mlp_width = int(d_model * mlp_ratio)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, mlp_width)),
+ ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(mlp_width, d_model))
+ ]))
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float = 4.0,
+ ls_init_value: float = None,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.grad_checkpointing = False
+
+ self.resblocks = nn.ModuleList([
+ ResidualAttentionBlock(
+ width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
+ for _ in range(layers)
+ ])
+
+ def get_cast_dtype(self) -> torch.dtype:
+ return self.resblocks[0].mlp.c_fc.weight.dtype
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dense=False):
+ for i, r in enumerate(self.resblocks):
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ if dense and i == self.layers - 1:
+ x = r.forward_dense(x)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+
+class VisionTransformer(nn.Module):
+ def __init__(
+ self,
+ image_size: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ mlp_ratio: float,
+ ls_init_value: float = None,
+ global_average_pool: bool = False,
+ output_dim: int = 512,
+ patch_dropout: float = 0.,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ ):
+ super().__init__()
+ self.image_size = to_2tuple(image_size)
+ self.patch_size = to_2tuple(patch_size)
+ self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
+
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
+
+ self.ln_pre = norm_layer(width)
+ self.transformer = Transformer(
+ width,
+ layers,
+ heads,
+ mlp_ratio,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+
+ self.global_average_pool = global_average_pool
+ self.ln_post = norm_layer(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ self.init_parameters()
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ for param in self.parameters():
+ param.requires_grad = False
+
+ if unlocked_groups != 0:
+ groups = [
+ [
+ self.conv1,
+ self.class_embedding,
+ self.positional_embedding,
+ self.ln_pre,
+ ],
+ *self.transformer.resblocks[:-1],
+ [
+ self.transformer.resblocks[-1],
+ self.ln_post,
+ ],
+ self.proj,
+ ]
+
+ def _unlock(x):
+ if isinstance(x, Sequence):
+ for g in x:
+ _unlock(g)
+ else:
+ if isinstance(x, torch.nn.Parameter):
+ x.requires_grad = True
+ else:
+ for p in x.parameters():
+ p.requires_grad = True
+
+ _unlock(groups[-unlocked_groups:])
+
+ def init_parameters(self):
+ # FIXME OpenAI CLIP did not define an init for the VisualTransformer
+ # TODO experiment if default PyTorch init, below, or alternate init is best.
+
+ # nn.init.normal_(self.class_embedding, std=self.scale)
+ # nn.init.normal_(self.positional_embedding, std=self.scale)
+ #
+ # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ # attn_std = self.transformer.width ** -0.5
+ # fc_std = (2 * self.transformer.width) ** -0.5
+ # for block in self.transformer.resblocks:
+ # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+ #
+ # if self.text_projection is not None:
+ # nn.init.normal_(self.text_projection, std=self.scale)
+ pass
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.transformer.grad_checkpointing = enable
+
+ def forward(self, x: torch.Tensor, dense=False):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat(
+ [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+ x = self.patch_dropout(x)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, dense=dense)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ if self.global_average_pool:
+ x = x.mean(dim=1)
+ elif dense:
+ x = x
+ else:
+ x = x[:, 0]
+
+ x = self.ln_post(x)
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+class TextTransformer(nn.Module):
+
+ def __init__(
+ self,
+ context_length: int = 77,
+ vocab_size: int = 49408,
+ width: int = 512,
+ heads: int = 8,
+ layers: int = 12,
+ ls_init_value: float = None,
+ output_dim: int = 512,
+ act_layer: Callable = nn.GELU,
+ norm_layer: Callable = LayerNorm,
+ ):
+ super().__init__()
+ self.context_length = context_length
+ self.vocab_size = vocab_size
+ self.width = width
+ self.output_dim = output_dim
+
+ self.token_embedding = nn.Embedding(vocab_size, width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
+ self.transformer = Transformer(
+ width=width,
+ layers=layers,
+ heads=heads,
+ ls_init_value=ls_init_value,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ )
+ self.ln_final = norm_layer(width)
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
+
+ self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
+
+ self.init_parameters()
+
+ def init_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width ** -0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.transformer.grad_checkpointing = enable
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def forward(self, text):
+ cast_dtype = self.transformer.get_cast_dtype()
+
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.to(cast_dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x, attn_mask=self.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x
diff --git a/open_clip/src/open_clip/utils.py b/open_clip/src/open_clip/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e80c5e296b24cae130ab0459baf268e0db7673
--- /dev/null
+++ b/open_clip/src/open_clip/utils.py
@@ -0,0 +1,60 @@
+from itertools import repeat
+import collections.abc
+
+from torch import nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
+
+
+def freeze_batch_norm_2d(module, module_match={}, name=''):
+ """
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+ Args:
+ module (torch.nn.Module): Any PyTorch module.
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
+ name (str): Full module name (prefix)
+
+ Returns:
+ torch.nn.Module: Resulting module
+
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+ """
+ res = module
+ is_match = True
+ if module_match:
+ is_match = name in module_match
+ if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
+ res = FrozenBatchNorm2d(module.num_features)
+ res.num_features = module.num_features
+ res.affine = module.affine
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for child_name, child in module.named_children():
+ full_child_name = '.'.join([name, child_name]) if name else child_name
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
+ if new_child is not child:
+ res.add_module(child_name, new_child)
+ return res
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = lambda n, x: _ntuple(n)(x)
diff --git a/open_clip/src/open_clip/version.py b/open_clip/src/open_clip/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1c6124423a7be38d4625a6989acf5b0dd9dbf07
--- /dev/null
+++ b/open_clip/src/open_clip/version.py
@@ -0,0 +1 @@
+__version__ = '2.10.1'
diff --git a/open_clip/src/training/.gitignore b/open_clip/src/training/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..333c1e910a3e2bef1b9d0d4587392627d8388974
--- /dev/null
+++ b/open_clip/src/training/.gitignore
@@ -0,0 +1 @@
+logs/
diff --git a/open_clip/src/training/__init__.py b/open_clip/src/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/open_clip/src/training/data.py b/open_clip/src/training/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..863528a12879f85f2bba0e41d3a68da4e16a90fb
--- /dev/null
+++ b/open_clip/src/training/data.py
@@ -0,0 +1,514 @@
+import ast
+import json
+import logging
+import math
+import os
+import random
+import sys
+import time
+from dataclasses import dataclass
+from multiprocessing import Value
+
+import numpy as np
+import pandas as pd
+import torch
+import torchvision.datasets as datasets
+import webdataset as wds
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
+from torch.utils.data.distributed import DistributedSampler
+from webdataset.filters import _shuffle
+from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+
+class CsvDataset(Dataset):
+ def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None):
+ logging.debug(f'Loading csv data from {input_filename}.')
+ df = pd.read_csv(input_filename, sep=sep)
+
+ self.images = df[img_key].tolist()
+ self.captions = df[caption_key].tolist()
+ self.transforms = transforms
+ logging.debug('Done loading data.')
+
+ self.tokenize = tokenizer
+
+ def __len__(self):
+ return len(self.captions)
+
+ def __getitem__(self, idx):
+ images = self.transforms(Image.open(str(self.images[idx])))
+ texts = self.tokenize([str(self.captions[idx])])[0]
+ return images, texts
+
+
+class SharedEpoch:
+ def __init__(self, epoch: int = 0):
+ self.shared_epoch = Value('i', epoch)
+
+ def set_value(self, epoch):
+ self.shared_epoch.value = epoch
+
+ def get_value(self):
+ return self.shared_epoch.value
+
+
+@dataclass
+class DataInfo:
+ dataloader: DataLoader
+ sampler: DistributedSampler = None
+ shared_epoch: SharedEpoch = None
+
+ def set_epoch(self, epoch):
+ if self.shared_epoch is not None:
+ self.shared_epoch.set_value(epoch)
+ if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
+ self.sampler.set_epoch(epoch)
+
+
+def get_dataset_size(shards):
+ shards_list = wds.shardlists.expand_urls(shards)
+ dir_path = os.path.dirname(shards_list[0])
+ sizes_filename = os.path.join(dir_path, 'sizes.json')
+ len_filename = os.path.join(dir_path, '__len__')
+ if os.path.exists(sizes_filename):
+ sizes = json.load(open(sizes_filename, 'r'))
+ total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
+ elif os.path.exists(len_filename):
+ # FIXME this used to be eval(open(...)) but that seemed rather unsafe
+ total_size = ast.literal_eval(open(len_filename, 'r').read())
+ else:
+ total_size = None # num samples undefined
+ # some common dataset sizes (at time of authors last download)
+ # CC3M (train): 2905954
+ # CC12M: 10968539
+ # LAION-400M: 407332084
+ # LAION-2B (english): 2170337258
+ num_shards = len(shards_list)
+ return total_size, num_shards
+
+
+def get_imagenet(args, preprocess_fns, split):
+ assert split in ["train", "val", "v2"]
+ is_train = split == "train"
+ preprocess_train, preprocess_val = preprocess_fns
+
+ if split == "v2":
+ from imagenetv2_pytorch import ImageNetV2Dataset
+ dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
+ else:
+ if is_train:
+ data_path = args.imagenet_train
+ preprocess_fn = preprocess_train
+ else:
+ data_path = args.imagenet_val
+ preprocess_fn = preprocess_val
+ assert data_path
+
+ dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
+
+ if is_train:
+ idxs = np.zeros(len(dataset.targets))
+ target_array = np.array(dataset.targets)
+ k = 50
+ for c in range(1000):
+ m = target_array == c
+ n = len(idxs[m])
+ arr = np.zeros(n)
+ arr[:k] = 1
+ np.random.shuffle(arr)
+ idxs[m] = arr
+
+ idxs = idxs.astype('int')
+ sampler = SubsetRandomSampler(np.where(idxs)[0])
+ else:
+ sampler = None
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ num_workers=args.workers,
+ sampler=sampler,
+ )
+
+ return DataInfo(dataloader=dataloader, sampler=sampler)
+
+
+def count_samples(dataloader):
+ os.environ["WDS_EPOCH"] = "0"
+ n_elements, n_batches = 0, 0
+ for images, texts in dataloader:
+ n_batches += 1
+ n_elements += len(images)
+ assert len(images) == len(texts)
+ return n_elements, n_batches
+
+
+def filter_no_caption_or_no_image(sample):
+ has_caption = ('txt' in sample)
+ has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)
+ return has_caption and has_image
+
+
+def log_and_continue(exn):
+ """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+ logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+ return True
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+ # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=log_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+def pytorch_worker_seed(increment=0):
+ """get dataloader worker seed from pytorch"""
+ worker_info = get_worker_info()
+ if worker_info is not None:
+ # favour using the seed already created for pytorch dataloader workers if it exists
+ seed = worker_info.seed
+ if increment:
+ # space out seed increments so they can't overlap across workers in different iterations
+ seed += increment * max(1, worker_info.num_workers)
+ return seed
+ # fallback to wds rank based seed
+ return wds.utils.pytorch_worker_seed()
+
+
+_SHARD_SHUFFLE_SIZE = 2000
+_SHARD_SHUFFLE_INITIAL = 500
+_SAMPLE_SHUFFLE_SIZE = 5000
+_SAMPLE_SHUFFLE_INITIAL = 1000
+
+
+class detshuffle2(wds.PipelineStage):
+ def __init__(
+ self,
+ bufsize=1000,
+ initial=100,
+ seed=0,
+ epoch=-1,
+ ):
+ self.bufsize = bufsize
+ self.initial = initial
+ self.seed = seed
+ self.epoch = epoch
+
+ def run(self, src):
+ if isinstance(self.epoch, SharedEpoch):
+ epoch = self.epoch.get_value()
+ else:
+ # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
+ # situation as different workers may wrap at different times (or not at all).
+ self.epoch += 1
+ epoch = self.epoch
+ rng = random.Random()
+ if self.seed < 0:
+ # If seed is negative, we use the worker's seed, this will be different across all nodes/workers
+ seed = pytorch_worker_seed(epoch)
+ else:
+ # This seed to be deterministic AND the same across all nodes/workers in each epoch
+ seed = self.seed + epoch
+ rng.seed(seed)
+ return _shuffle(src, self.bufsize, self.initial, rng)
+
+
+class ResampledShards2(IterableDataset):
+ """An iterable dataset yielding a list of urls."""
+
+ def __init__(
+ self,
+ urls,
+ nshards=sys.maxsize,
+ worker_seed=None,
+ deterministic=False,
+ epoch=-1,
+ ):
+ """Sample shards from the shard list with replacement.
+
+ :param urls: a list of URLs as a Python list or brace notation string
+ """
+ super().__init__()
+ urls = wds.shardlists.expand_urls(urls)
+ self.urls = urls
+ assert isinstance(self.urls[0], str)
+ self.nshards = nshards
+ self.rng = random.Random()
+ self.worker_seed = worker_seed
+ self.deterministic = deterministic
+ self.epoch = epoch
+
+ def __iter__(self):
+ """Return an iterator over the shards."""
+ if isinstance(self.epoch, SharedEpoch):
+ epoch = self.epoch.get_value()
+ else:
+ # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
+ # situation as different workers may wrap at different times (or not at all).
+ self.epoch += 1
+ epoch = self.epoch
+ if self.deterministic:
+ # reset seed w/ epoch if deterministic
+ if self.worker_seed is None:
+ # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
+ seed = pytorch_worker_seed(epoch)
+ else:
+ seed = self.worker_seed() + epoch
+ self.rng.seed(seed)
+ for _ in range(self.nshards):
+ yield dict(url=self.rng.choice(self.urls))
+
+
+def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
+ input_shards = args.train_data if is_train else args.val_data
+ assert input_shards is not None
+ resampled = getattr(args, 'dataset_resampled', False) and is_train
+
+ num_samples, num_shards = get_dataset_size(input_shards)
+ if not num_samples:
+ if is_train:
+ num_samples = args.train_num_samples
+ if not num_samples:
+ raise RuntimeError(
+ 'Currently, number of dataset samples must be specified for training dataset. '
+ 'Please specify via `--train-num-samples` if no dataset length info present.')
+ else:
+ num_samples = args.val_num_samples or 0 # eval will just exhaust the iterator if not specified
+
+ shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc
+
+ if resampled:
+ pipeline = [ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)]
+ else:
+ pipeline = [wds.SimpleShardList(input_shards)]
+
+ # at this point we have an iterator over all the shards
+ if is_train:
+ if not resampled:
+ pipeline.extend([
+ detshuffle2(
+ bufsize=_SHARD_SHUFFLE_SIZE,
+ initial=_SHARD_SHUFFLE_INITIAL,
+ seed=args.seed,
+ epoch=shared_epoch,
+ ),
+ wds.split_by_node,
+ wds.split_by_worker,
+ ])
+ pipeline.extend([
+ # at this point, we have an iterator over the shards assigned to each worker at each node
+ tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue),
+ wds.shuffle(
+ bufsize=_SAMPLE_SHUFFLE_SIZE,
+ initial=_SAMPLE_SHUFFLE_INITIAL,
+ ),
+ ])
+ else:
+ pipeline.extend([
+ wds.split_by_worker,
+ # at this point, we have an iterator over the shards assigned to each worker
+ wds.tarfile_to_samples(handler=log_and_continue),
+ ])
+ pipeline.extend([
+ wds.select(filter_no_caption_or_no_image),
+ wds.decode("pilrgb", handler=log_and_continue),
+ wds.rename(image="jpg;png;jpeg;webp", text="txt"),
+ wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),
+ wds.to_tuple("image", "text"),
+ wds.batched(args.batch_size, partial=not is_train),
+ ])
+
+ dataset = wds.DataPipeline(*pipeline)
+ if is_train:
+ if not resampled:
+ assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
+ # roll over and repeat a few samples to get same number of full batches on each node
+ round_fn = math.floor if floor else math.ceil
+ global_batch_size = args.batch_size * args.world_size
+ num_batches = round_fn(num_samples / global_batch_size)
+ num_workers = max(1, args.workers)
+ num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+ dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this
+ else:
+ # last batches are partial, eval is done on single (master) node
+ num_batches = math.ceil(num_samples / args.batch_size)
+
+ dataloader = wds.WebLoader(
+ dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=args.workers,
+ persistent_workers=True,
+ )
+
+ # FIXME not clear which approach is better, with_epoch before vs after dataloader?
+ # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
+ # if is_train:
+ # # roll over and repeat a few samples to get same number of full batches on each node
+ # global_batch_size = args.batch_size * args.world_size
+ # num_batches = math.ceil(num_samples / global_batch_size)
+ # num_workers = max(1, args.workers)
+ # num_batches = math.ceil(num_batches / num_workers) * num_workers
+ # num_samples = num_batches * global_batch_size
+ # dataloader = dataloader.with_epoch(num_batches)
+ # else:
+ # # last batches are partial, eval is done on single (master) node
+ # num_batches = math.ceil(num_samples / args.batch_size)
+
+ # add meta-data to dataloader instance for convenience
+ dataloader.num_batches = num_batches
+ dataloader.num_samples = num_samples
+
+ return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
+
+
+def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
+ input_filename = args.train_data if is_train else args.val_data
+ assert input_filename
+ dataset = CsvDataset(
+ input_filename,
+ preprocess_fn,
+ img_key=args.csv_img_key,
+ caption_key=args.csv_caption_key,
+ sep=args.csv_separator,
+ tokenizer=tokenizer
+ )
+ num_samples = len(dataset)
+ sampler = DistributedSampler(dataset) if args.distributed and is_train else None
+ shuffle = is_train and sampler is None
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ shuffle=shuffle,
+ num_workers=args.workers,
+ pin_memory=True,
+ sampler=sampler,
+ drop_last=is_train,
+ )
+ dataloader.num_samples = num_samples
+ dataloader.num_batches = len(dataloader)
+
+ return DataInfo(dataloader, sampler)
+
+
+class SyntheticDataset(Dataset):
+
+ def __init__(self, transform=None, image_size=(224, 224), caption="Dummy caption", dataset_size=100, tokenizer=None):
+ self.transform = transform
+ self.image_size = image_size
+ self.caption = caption
+ self.image = Image.new('RGB', image_size)
+ self.dataset_size = dataset_size
+
+ self.preprocess_txt = lambda text: tokenizer(text)[0]
+
+ def __len__(self):
+ return self.dataset_size
+
+ def __getitem__(self, idx):
+ if self.transform is not None:
+ image = self.transform(self.image)
+ return image, self.preprocess_txt(self.caption)
+
+
+def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
+ image_size = preprocess_fn.transforms[0].size
+ dataset = SyntheticDataset(
+ transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer)
+ num_samples = len(dataset)
+ sampler = DistributedSampler(dataset) if args.distributed and is_train else None
+ shuffle = is_train and sampler is None
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ shuffle=shuffle,
+ num_workers=args.workers,
+ pin_memory=True,
+ sampler=sampler,
+ drop_last=is_train,
+ )
+ dataloader.num_samples = num_samples
+ dataloader.num_batches = len(dataloader)
+
+ return DataInfo(dataloader, sampler)
+
+
+def get_dataset_fn(data_path, dataset_type):
+ if dataset_type == "webdataset":
+ return get_wds_dataset
+ elif dataset_type == "csv":
+ return get_csv_dataset
+ elif dataset_type == "synthetic":
+ return get_synthetic_dataset
+ elif dataset_type == "auto":
+ ext = data_path.split('.')[-1]
+ if ext in ['csv', 'tsv']:
+ return get_csv_dataset
+ elif ext in ['tar']:
+ return get_wds_dataset
+ else:
+ raise ValueError(
+ f"Tried to figure out dataset type, but failed for extension {ext}.")
+ else:
+ raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
+
+def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
+ preprocess_train, preprocess_val = preprocess_fns
+ data = {}
+
+ if args.train_data or args.dataset_type == "synthetic":
+ data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
+ args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
+
+ if args.val_data:
+ data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
+ args, preprocess_val, is_train=False, tokenizer=tokenizer)
+
+ if args.imagenet_val is not None:
+ data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")
+
+ if args.imagenet_v2 is not None:
+ data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")
+
+ return data
diff --git a/open_clip/src/training/distributed.py b/open_clip/src/training/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..268a6c7ad75a9ef29c72801dbf59d606f3318a59
--- /dev/null
+++ b/open_clip/src/training/distributed.py
@@ -0,0 +1,137 @@
+import os
+
+import torch
+import torch.distributed as dist
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+
+def is_global_master(args):
+ return args.rank == 0
+
+
+def is_local_master(args):
+ return args.local_rank == 0
+
+
+def is_master(args, local=False):
+ return is_local_master(args) if local else is_global_master(args)
+
+
+def is_using_horovod():
+ # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
+ # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
+ ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
+ pmi_vars = ["PMI_RANK", "PMI_SIZE"]
+ if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
+ return True
+ else:
+ return False
+
+
+def is_using_distributed():
+ if 'WORLD_SIZE' in os.environ:
+ return int(os.environ['WORLD_SIZE']) > 1
+ if 'SLURM_NTASKS' in os.environ:
+ return int(os.environ['SLURM_NTASKS']) > 1
+ return False
+
+
+def world_info_from_env():
+ local_rank = 0
+ for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
+ if v in os.environ:
+ local_rank = int(os.environ[v])
+ break
+ global_rank = 0
+ for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
+ if v in os.environ:
+ global_rank = int(os.environ[v])
+ break
+ world_size = 1
+ for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
+ if v in os.environ:
+ world_size = int(os.environ[v])
+ break
+
+ return local_rank, global_rank, world_size
+
+
+def init_distributed_device(args):
+ # Distributed training = training on more than one GPU.
+ # Works in both single and multi-node scenarios.
+ args.distributed = False
+ args.world_size = 1
+ args.rank = 0 # global rank
+ args.local_rank = 0
+ if args.horovod:
+ assert hvd is not None, "Horovod is not installed"
+ hvd.init()
+ args.local_rank = int(hvd.local_rank())
+ args.rank = hvd.rank()
+ args.world_size = hvd.size()
+ args.distributed = True
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ elif is_using_distributed():
+ if 'SLURM_PROCID' in os.environ:
+ # DDP via SLURM
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
+ # SLURM var -> torch.distributed vars in case needed
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ else:
+ # DDP via torchrun, torch.distributed.launch
+ args.local_rank, _, _ = world_info_from_env()
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url)
+ args.world_size = torch.distributed.get_world_size()
+ args.rank = torch.distributed.get_rank()
+ args.distributed = True
+
+ if torch.cuda.is_available():
+ if args.distributed and not args.no_set_device_rank:
+ device = 'cuda:%d' % args.local_rank
+ else:
+ device = 'cuda:0'
+ torch.cuda.set_device(device)
+ else:
+ device = 'cpu'
+ args.device = device
+ device = torch.device(device)
+ return device
+
+
+def broadcast_object(args, obj, src=0):
+ # broadcast a pickle-able python object from rank-0 to all ranks
+ if args.horovod:
+ return hvd.broadcast_object(obj, root_rank=src)
+ else:
+ if args.rank == src:
+ objects = [obj]
+ else:
+ objects = [None]
+ dist.broadcast_object_list(objects, src=src)
+ return objects[0]
+
+
+def all_gather_object(args, obj, dst=0):
+ # gather a pickle-able python object across all ranks
+ if args.horovod:
+ return hvd.allgather_object(obj)
+ else:
+ objects = [None for _ in range(args.world_size)]
+ dist.all_gather_object(objects, obj)
+ return objects
diff --git a/open_clip/src/training/file_utils.py b/open_clip/src/training/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ec933b7d718bf4efb59962fdd97e18357303865
--- /dev/null
+++ b/open_clip/src/training/file_utils.py
@@ -0,0 +1,83 @@
+import logging
+import os
+import multiprocessing
+import subprocess
+import time
+import fsspec
+import torch
+from tqdm import tqdm
+
+def remote_sync_s3(local_dir, remote_dir):
+ # skip epoch_latest which can change during sync.
+ result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ if result.returncode != 0:
+ logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
+ return False
+
+ logging.info(f"Successfully synced with S3 bucket")
+ return True
+
+def remote_sync_fsspec(local_dir, remote_dir):
+ # FIXME currently this is slow and not recommended. Look into speeding up.
+ a = fsspec.get_mapper(local_dir)
+ b = fsspec.get_mapper(remote_dir)
+
+ for k in a:
+ # skip epoch_latest which can change during sync.
+ if 'epoch_latest.pt' in k:
+ continue
+
+ logging.info(f'Attempting to sync {k}')
+ if k in b and len(a[k]) == len(b[k]):
+ logging.debug(f'Skipping remote sync for {k}.')
+ continue
+
+ try:
+ logging.info(f'Successful sync for {k}.')
+ b[k] = a[k]
+ except Exception as e:
+ logging.info(f'Error during remote sync for {k}: {e}')
+ return False
+
+ return True
+
+def remote_sync(local_dir, remote_dir, protocol):
+ logging.info('Starting remote sync.')
+ if protocol == 's3':
+ return remote_sync_s3(local_dir, remote_dir)
+ elif protocol == 'fsspec':
+ return remote_sync_fsspec(local_dir, remote_dir)
+ else:
+ logging.error('Remote protocol not known')
+ return False
+
+def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
+ while True:
+ time.sleep(sync_every)
+ remote_sync(local_dir, remote_dir, protocol)
+
+def start_sync_process(sync_every, local_dir, remote_dir, protocol):
+ p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol))
+ return p
+
+# Note: we are not currently using this save function.
+def pt_save(pt_obj, file_path):
+ of = fsspec.open(file_path, "wb")
+ with of as f:
+ torch.save(pt_obj, file_path)
+
+def pt_load(file_path, map_location=None):
+ if not file_path.startswith('/'):
+ logging.info('Loading remote checkpoint, which may take a bit.')
+ of = fsspec.open(file_path, "rb")
+ with of as f:
+ out = torch.load(f, map_location=map_location)
+ return out
+
+def check_exists(file_path):
+ try:
+ with fsspec.open(file_path):
+ pass
+ except FileNotFoundError:
+ return False
+ return True
diff --git a/open_clip/src/training/imagenet_zeroshot_data.py b/open_clip/src/training/imagenet_zeroshot_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..27abd8bf24ebe077a73e8496576d949d8bb16f69
--- /dev/null
+++ b/open_clip/src/training/imagenet_zeroshot_data.py
@@ -0,0 +1,254 @@
+
+
+imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
+ "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
+ "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
+ "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
+ "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
+ "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
+ "box turtle", "banded gecko", "green iguana", "Carolina anole",
+ "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
+ "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
+ "American alligator", "triceratops", "worm snake", "ring-necked snake",
+ "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
+ "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
+ "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
+ "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
+ "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
+ "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
+ "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
+ "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
+ "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
+ "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
+ "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
+ "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
+ "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
+ "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
+ "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
+ "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
+ "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
+ "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
+ "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
+ "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
+ "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
+ "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
+ "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
+ "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
+ "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
+ "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
+ "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
+ "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
+ "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
+ "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
+ "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
+ "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
+ "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
+ "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
+ "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
+ "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
+ "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
+ "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
+ "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
+ "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
+ "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
+ "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
+ "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
+ "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
+ "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
+ "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
+ "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
+ "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
+ "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
+ "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
+ "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
+ "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
+ "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
+ "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
+ "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
+ "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
+ "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
+ "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
+ "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
+ "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
+ "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
+ "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
+ "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
+ "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
+ "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
+ "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
+ "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
+ "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
+ "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
+ "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
+ "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
+ "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
+ "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
+ "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
+ "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
+ "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
+ "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
+ "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
+ "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
+ "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
+ "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
+ "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
+ "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
+ "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
+ "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
+ "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
+ "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
+ "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
+ "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
+ "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
+ "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
+ "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
+ "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
+ "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
+ "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
+ "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
+ "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
+ "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
+ "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
+ "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
+ "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
+ "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
+ "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
+ "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
+ "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
+ "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
+ "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
+ "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
+ "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
+ "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
+ "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
+ "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
+ "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
+ "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
+ "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
+ "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
+ "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
+ "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
+ "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
+ "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
+ "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
+ "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
+ "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
+ "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
+ "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
+ "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
+ "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
+ "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
+ "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
+ "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
+ "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
+ "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
+ "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
+ "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
+ "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
+ "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
+ "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
+ "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
+ "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
+ "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
+ "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
+ "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
+ "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
+ "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
+ "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
+ "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
+ "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
+ "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
+ "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
+ "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
+ "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
+ "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
+ "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
+ "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
+ "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
+
+
+
+
+
+openai_imagenet_template = [
+ lambda c: f'a bad photo of a {c}.',
+ lambda c: f'a photo of many {c}.',
+ lambda c: f'a sculpture of a {c}.',
+ lambda c: f'a photo of the hard to see {c}.',
+ lambda c: f'a low resolution photo of the {c}.',
+ lambda c: f'a rendering of a {c}.',
+ lambda c: f'graffiti of a {c}.',
+ lambda c: f'a bad photo of the {c}.',
+ lambda c: f'a cropped photo of the {c}.',
+ lambda c: f'a tattoo of a {c}.',
+ lambda c: f'the embroidered {c}.',
+ lambda c: f'a photo of a hard to see {c}.',
+ lambda c: f'a bright photo of a {c}.',
+ lambda c: f'a photo of a clean {c}.',
+ lambda c: f'a photo of a dirty {c}.',
+ lambda c: f'a dark photo of the {c}.',
+ lambda c: f'a drawing of a {c}.',
+ lambda c: f'a photo of my {c}.',
+ lambda c: f'the plastic {c}.',
+ lambda c: f'a photo of the cool {c}.',
+ lambda c: f'a close-up photo of a {c}.',
+ lambda c: f'a black and white photo of the {c}.',
+ lambda c: f'a painting of the {c}.',
+ lambda c: f'a painting of a {c}.',
+ lambda c: f'a pixelated photo of the {c}.',
+ lambda c: f'a sculpture of the {c}.',
+ lambda c: f'a bright photo of the {c}.',
+ lambda c: f'a cropped photo of a {c}.',
+ lambda c: f'a plastic {c}.',
+ lambda c: f'a photo of the dirty {c}.',
+ lambda c: f'a jpeg corrupted photo of a {c}.',
+ lambda c: f'a blurry photo of the {c}.',
+ lambda c: f'a photo of the {c}.',
+ lambda c: f'a good photo of the {c}.',
+ lambda c: f'a rendering of the {c}.',
+ lambda c: f'a {c} in a video game.',
+ lambda c: f'a photo of one {c}.',
+ lambda c: f'a doodle of a {c}.',
+ lambda c: f'a close-up photo of the {c}.',
+ lambda c: f'a photo of a {c}.',
+ lambda c: f'the origami {c}.',
+ lambda c: f'the {c} in a video game.',
+ lambda c: f'a sketch of a {c}.',
+ lambda c: f'a doodle of the {c}.',
+ lambda c: f'a origami {c}.',
+ lambda c: f'a low resolution photo of a {c}.',
+ lambda c: f'the toy {c}.',
+ lambda c: f'a rendition of the {c}.',
+ lambda c: f'a photo of the clean {c}.',
+ lambda c: f'a photo of a large {c}.',
+ lambda c: f'a rendition of a {c}.',
+ lambda c: f'a photo of a nice {c}.',
+ lambda c: f'a photo of a weird {c}.',
+ lambda c: f'a blurry photo of a {c}.',
+ lambda c: f'a cartoon {c}.',
+ lambda c: f'art of a {c}.',
+ lambda c: f'a sketch of the {c}.',
+ lambda c: f'a embroidered {c}.',
+ lambda c: f'a pixelated photo of a {c}.',
+ lambda c: f'itap of the {c}.',
+ lambda c: f'a jpeg corrupted photo of the {c}.',
+ lambda c: f'a good photo of a {c}.',
+ lambda c: f'a plushie {c}.',
+ lambda c: f'a photo of the nice {c}.',
+ lambda c: f'a photo of the small {c}.',
+ lambda c: f'a photo of the weird {c}.',
+ lambda c: f'the cartoon {c}.',
+ lambda c: f'art of the {c}.',
+ lambda c: f'a drawing of the {c}.',
+ lambda c: f'a photo of the large {c}.',
+ lambda c: f'a black and white photo of a {c}.',
+ lambda c: f'the plushie {c}.',
+ lambda c: f'a dark photo of a {c}.',
+ lambda c: f'itap of a {c}.',
+ lambda c: f'graffiti of the {c}.',
+ lambda c: f'a toy {c}.',
+ lambda c: f'itap of my {c}.',
+ lambda c: f'a photo of a cool {c}.',
+ lambda c: f'a photo of a small {c}.',
+ lambda c: f'a tattoo of the {c}.',
+]
diff --git a/open_clip/src/training/logger.py b/open_clip/src/training/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d9abed92568d459cbc8d6094ae3901935d89621
--- /dev/null
+++ b/open_clip/src/training/logger.py
@@ -0,0 +1,26 @@
+import logging
+
+
+def setup_logging(log_file, level, include_host=False):
+ if include_host:
+ import socket
+ hostname = socket.gethostname()
+ formatter = logging.Formatter(
+ f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
+ else:
+ formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
+
+ logging.root.setLevel(level)
+ loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
+ for logger in loggers:
+ logger.setLevel(level)
+
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(formatter)
+ logging.root.addHandler(stream_handler)
+
+ if log_file:
+ file_handler = logging.FileHandler(filename=log_file)
+ file_handler.setFormatter(formatter)
+ logging.root.addHandler(file_handler)
+
diff --git a/open_clip/src/training/main.py b/open_clip/src/training/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..e648099c0209bc01451ea8ff3bb110e9a336e357
--- /dev/null
+++ b/open_clip/src/training/main.py
@@ -0,0 +1,446 @@
+import glob
+import logging
+import os
+import re
+import subprocess
+import sys
+import random
+from datetime import datetime
+
+import numpy as np
+import torch
+from torch import optim
+from torch.cuda.amp import GradScaler
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+try:
+ import torch.utils.tensorboard as tensorboard
+except ImportError:
+ tensorboard = None
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+from open_clip import create_model_and_transforms, trace_model, get_tokenizer
+from training.data import get_data
+from training.distributed import is_master, init_distributed_device, broadcast_object
+from training.logger import setup_logging
+from training.params import parse_args
+from training.scheduler import cosine_lr, const_lr, const_lr_cooldown
+from training.train import train_one_epoch, evaluate
+from training.file_utils import pt_load, check_exists, start_sync_process, remote_sync
+
+
+LATEST_CHECKPOINT_NAME = "epoch_latest.pt"
+
+
+def random_seed(seed=42, rank=0):
+ torch.manual_seed(seed + rank)
+ np.random.seed(seed + rank)
+ random.seed(seed + rank)
+
+
+def natural_key(string_):
+ """See http://www.codinghorror.com/blog/archives/001018.html"""
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
+def get_latest_checkpoint(path: str, remote : bool):
+ # as writen, this glob recurses, so can pick up checkpoints across multiple sub-folders
+ if remote:
+ result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ print(result)
+ if result.returncode == 1:
+ return None
+ checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]]
+ else:
+ checkpoints = glob.glob(path + '**/*.pt', recursive=True)
+ if checkpoints:
+ checkpoints = sorted(checkpoints, key=natural_key)
+ return checkpoints[-1]
+ return None
+
+
+def main(args):
+ args = parse_args(args)
+
+ if torch.cuda.is_available():
+ # This enables tf32 on Ampere GPUs which is only 8% slower than
+ # float16 and almost as accurate as float32
+ # This was a default in pytorch until 1.12
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.deterministic = False
+
+ # fully initialize distributed device environment
+ device = init_distributed_device(args)
+
+ # get the name of the experiments
+ if args.name is None:
+ # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
+ model_name_safe = args.model.replace('/', '-')
+ date_str = datetime.now().strftime("%Y_%m_%d`-%H_%M_%S")
+ if args.distributed:
+ # sync date_str from master to all ranks
+ date_str = broadcast_object(args, date_str)
+ args.name = '-'.join([
+ date_str,
+ f"model_{model_name_safe}",
+ f"lr_{args.lr}",
+ f"b_{args.batch_size}",
+ f"j_{args.workers}",
+ f"p_{args.precision}",
+ ])
+
+ resume_latest = args.resume == 'latest'
+ log_base_path = os.path.join(args.logs, args.name)
+ args.log_path = None
+ if is_master(args, local=args.log_local):
+ os.makedirs(log_base_path, exist_ok=True)
+ log_filename = f'out-{args.rank}' if args.log_local else 'out.log'
+ args.log_path = os.path.join(log_base_path, log_filename)
+ if os.path.exists(args.log_path) and not resume_latest:
+ print(
+ "Error. Experiment already exists. Use --name {} to specify a new experiment."
+ )
+ return -1
+
+ # Setup text logger
+ args.log_level = logging.DEBUG if args.debug else logging.INFO
+ setup_logging(args.log_path, args.log_level)
+
+ # Setup wandb, tensorboard, checkpoint logging
+ args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
+ args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
+ args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
+ if is_master(args):
+ args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ''
+ for dirname in [args.tensorboard_path, args.checkpoint_path]:
+ if dirname:
+ os.makedirs(dirname, exist_ok=True)
+ else:
+ args.tensorboard_path = ''
+
+ if resume_latest:
+ resume_from = None
+ checkpoint_path = args.checkpoint_path
+ # If using remote_sync, need to check the remote instead of the local checkpoints folder.
+ if args.remote_sync is not None:
+ checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints")
+ if args.save_most_recent:
+ print('Error. Cannot use save-most-recent with remote_sync and resume latest.')
+ return -1
+ if args.remote_sync_protocol != 's3':
+ print('Error. Sync protocol not supported when using resume latest.')
+ return -1
+ if is_master(args):
+ # Checking for existing checkpoint via master rank only. It is possible for
+ # different rank processes to see different files if a shared file-system is under
+ # stress, however it's very difficult to fully work around such situations.
+ if args.save_most_recent:
+ # if --save-most-recent flag is set, look for latest at a fixed filename
+ resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME)
+ if not os.path.exists(resume_from):
+ # If no latest checkpoint has been saved yet, don't try to resume
+ resume_from = None
+ else:
+ # otherwise, list checkpoint dir contents and pick the newest checkpoint
+ resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None)
+ if resume_from:
+ logging.info(f'Found latest resume checkpoint at {resume_from}.')
+ else:
+ logging.info(f'No latest resume checkpoint found in {checkpoint_path}.')
+ if args.distributed:
+ # sync found checkpoint path to all ranks
+ resume_from = broadcast_object(args, resume_from)
+ args.resume = resume_from
+
+ if args.copy_codebase:
+ copy_codebase(args)
+
+ # start the sync proces if remote-sync is not None
+ remote_sync_process = None
+ if is_master(args) and args.remote_sync is not None:
+ # first make sure it works
+ result = remote_sync(
+ os.path.join(args.logs, args.name),
+ os.path.join(args.remote_sync, args.name),
+ args.remote_sync_protocol
+ )
+ if result:
+ logging.info('remote sync successful.')
+ else:
+ logging.info('Error: remote sync failed. Exiting.')
+ return -1
+ # if all looks good, start a process to do this every args.remote_sync_frequency seconds
+ remote_sync_process = start_sync_process(
+ args.remote_sync_frequency,
+ os.path.join(args.logs, args.name),
+ os.path.join(args.remote_sync, args.name),
+ args.remote_sync_protocol
+ )
+ remote_sync_process.start()
+
+ if args.precision == 'fp16':
+ logging.warning(
+ 'It is recommended to use AMP mixed-precision instead of FP16. '
+ 'FP16 support needs further verification and tuning, especially for train.')
+
+ if args.horovod:
+ logging.info(
+ f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.'
+ f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
+ elif args.distributed:
+ logging.info(
+ f'Running in distributed mode with multiple processes. Device: {args.device}.'
+ f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
+ else:
+ logging.info(f'Running with a single process. Device {args.device}.')
+
+ if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1:
+ # arg is nargs, single (square) image size list -> int
+ args.force_image_size = args.force_image_size[0]
+ random_seed(args.seed, 0)
+ model, preprocess_train, preprocess_val = create_model_and_transforms(
+ args.model,
+ args.pretrained,
+ precision=args.precision,
+ device=device,
+ jit=args.torchscript,
+ force_quick_gelu=args.force_quick_gelu,
+ force_custom_text=args.force_custom_text,
+ force_patch_dropout=args.force_patch_dropout,
+ force_image_size=args.force_image_size,
+ pretrained_image=args.pretrained_image,
+ image_mean=args.image_mean,
+ image_std=args.image_std,
+ aug_cfg=args.aug_cfg,
+ )
+ random_seed(args.seed, args.rank)
+
+ if args.trace:
+ model = trace_model(model, batch_size=args.batch_size, device=device)
+
+ if args.lock_image:
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+ model.lock_image_tower(
+ unlocked_groups=args.lock_image_unlocked_groups,
+ freeze_bn_stats=args.lock_image_freeze_bn_stats)
+ if args.lock_text:
+ model.lock_text_tower(
+ unlocked_layers=args.lock_text_unlocked_layers,
+ freeze_layer_norm=args.lock_text_freeze_layer_norm)
+
+ if args.grad_checkpointing:
+ model.set_grad_checkpointing()
+
+ if is_master(args):
+ logging.info("Model:")
+ logging.info(f"{str(model)}")
+ logging.info("Params:")
+ params_file = os.path.join(args.logs, args.name, "params.txt")
+ with open(params_file, "w") as f:
+ for name in sorted(vars(args)):
+ val = getattr(args, name)
+ logging.info(f" {name}: {val}")
+ f.write(f"{name}: {val}\n")
+
+ if args.distributed and not args.horovod:
+ if args.use_bn_sync:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ ddp_args = {}
+ if args.ddp_static_graph:
+ # this doesn't exist in older PyTorch, arg only added if enabled
+ ddp_args['static_graph'] = True
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)
+
+ # create optimizer and scaler
+ optimizer = None
+ scaler = None
+
+ if args.train_data or args.dataset_type == "synthetic":
+ assert not args.trace, 'Cannot train with traced model'
+
+ exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
+ include = lambda n, p: not exclude(n, p)
+
+ named_parameters = list(model.named_parameters())
+ gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
+ rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
+
+ optimizer = optim.AdamW(
+ [
+ {"params": gain_or_bias_params, "weight_decay": 0.},
+ {"params": rest_params, "weight_decay": args.wd},
+ ],
+ lr=args.lr,
+ betas=(args.beta1, args.beta2),
+ eps=args.eps,
+ )
+ if args.horovod:
+ optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
+ hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+ hvd.broadcast_optimizer_state(optimizer, root_rank=0)
+
+ scaler = GradScaler() if args.precision == "amp" else None
+
+ # optionally resume from a checkpoint
+ start_epoch = 0
+ if args.resume is not None:
+ checkpoint = pt_load(args.resume, map_location='cpu')
+ if 'epoch' in checkpoint:
+ # resuming a train checkpoint w/ epoch and optimizer state
+ start_epoch = checkpoint["epoch"]
+ sd = checkpoint["state_dict"]
+ if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
+ sd = {k[len('module.'):]: v for k, v in sd.items()}
+ model.load_state_dict(sd)
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if scaler is not None and 'scaler' in checkpoint:
+ scaler.load_state_dict(checkpoint['scaler'])
+ logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
+ else:
+ # loading a bare (model only) checkpoint for fine-tune or evaluation
+ model.load_state_dict(checkpoint)
+ logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
+
+ # initialize datasets
+ data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model))
+ assert len(data), 'At least one train or eval dataset must be specified.'
+
+ # create scheduler if train
+ scheduler = None
+ if 'train' in data and optimizer is not None:
+ total_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs
+ if args.lr_scheduler == "cosine":
+ scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
+ elif args.lr_scheduler == "const":
+ scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps)
+ elif args.lr_scheduler == "const-cooldown":
+ assert args.epochs_cooldown is not None,\
+ "Please specify the number of cooldown epochs for this lr schedule."
+ cooldown_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs_cooldown
+ scheduler = const_lr_cooldown(
+ optimizer, args.lr, args.warmup, total_steps,
+ cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end)
+ else:
+ logging.error(
+ f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.')
+ exit(1)
+
+ # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
+ args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args)
+ writer = None
+ if args.save_logs and args.tensorboard:
+ assert tensorboard is not None, "Please install tensorboard."
+ writer = tensorboard.SummaryWriter(args.tensorboard_path)
+
+ if args.wandb and is_master(args):
+ assert wandb is not None, 'Please install wandb.'
+ logging.debug('Starting wandb.')
+ args.train_sz = data["train"].dataloader.num_samples
+ if args.val_data is not None:
+ args.val_sz = data["val"].dataloader.num_samples
+ # you will have to configure this for your project!
+ wandb.init(
+ project=args.wandb_project_name,
+ name=args.name,
+ id=args.name,
+ notes=args.wandb_notes,
+ tags=[],
+ resume='auto' if args.resume == "latest" else None,
+ config=vars(args),
+ )
+ if args.debug:
+ wandb.watch(model, log='all')
+ wandb.save(params_file)
+ logging.debug('Finished loading wandb.')
+
+ if 'train' not in data:
+ evaluate(model, data, start_epoch, args, writer)
+ return
+
+ for epoch in range(start_epoch, args.epochs):
+ if is_master(args):
+ logging.info(f'Start epoch {epoch}')
+
+ train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
+ completed_epoch = epoch + 1
+
+ if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
+ evaluate(model, data, completed_epoch, args, writer)
+
+ # Saving checkpoints.
+ if args.save_logs:
+ checkpoint_dict = {
+ "epoch": completed_epoch,
+ "name": args.name,
+ "state_dict": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ if scaler is not None:
+ checkpoint_dict["scaler"] = scaler.state_dict()
+
+ if completed_epoch == args.epochs or (
+ args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
+ ):
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
+ )
+ if args.delete_previous_checkpoint:
+ previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt")
+ if os.path.exists(previous_checkpoint):
+ os.remove(previous_checkpoint)
+
+ if args.save_most_recent:
+ # try not to corrupt the latest checkpoint if save fails
+ tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt")
+ latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME)
+ torch.save(checkpoint_dict, tmp_save_path)
+ os.replace(tmp_save_path, latest_save_path)
+
+ if args.wandb and is_master(args):
+ wandb.finish()
+
+ # run a final sync.
+ if remote_sync_process is not None:
+ logging.info('Final remote sync.')
+ remote_sync_process.terminate()
+ result = remote_sync(
+ os.path.join(args.logs, args.name),
+ os.path.join(args.remote_sync, args.name),
+ args.remote_sync_protocol
+ )
+ if result:
+ logging.info('Final remote sync successful.')
+ else:
+ logging.info('Final remote sync failed.')
+
+
+def copy_codebase(args):
+ from shutil import copytree, ignore_patterns
+ new_code_path = os.path.join(args.logs, args.name, "code")
+ if os.path.exists(new_code_path):
+ print(
+ f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
+ )
+ return -1
+ print(f"Copying codebase to {new_code_path}")
+ current_code_path = os.path.realpath(__file__)
+ for _ in range(3):
+ current_code_path = os.path.dirname(current_code_path)
+ copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
+ print("Done copying code.")
+ return 1
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/open_clip/src/training/params.py b/open_clip/src/training/params.py
new file mode 100644
index 0000000000000000000000000000000000000000..44db413a5440ed1d6b151851fb95507177b4f3d2
--- /dev/null
+++ b/open_clip/src/training/params.py
@@ -0,0 +1,403 @@
+import argparse
+import ast
+
+
+def get_default_params(model_name):
+ # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
+ model_name = model_name.lower()
+ if "vit" in model_name:
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
+ else:
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
+
+
+class ParseKwargs(argparse.Action):
+ def __call__(self, parser, namespace, values, option_string=None):
+ kw = {}
+ for value in values:
+ key, value = value.split('=')
+ try:
+ kw[key] = ast.literal_eval(value)
+ except ValueError:
+ kw[key] = str(value) # fallback to string (avoid need to escape on command line)
+ setattr(namespace, self.dest, kw)
+
+
+def parse_args(args):
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--train-data",
+ type=str,
+ default=None,
+ help="Path to file(s) with training data",
+ )
+ parser.add_argument(
+ "--val-data",
+ type=str,
+ default=None,
+ help="Path to file(s) with validation data",
+ )
+ parser.add_argument(
+ "--train-num-samples",
+ type=int,
+ default=None,
+ help="Number of samples in dataset. Required for webdataset if not available in info file.",
+ )
+ parser.add_argument(
+ "--val-num-samples",
+ type=int,
+ default=None,
+ help="Number of samples in dataset. Useful for webdataset if not available in info file.",
+ )
+ parser.add_argument(
+ "--dataset-type",
+ choices=["webdataset", "csv", "synthetic", "auto"],
+ default="auto",
+ help="Which type of dataset to process."
+ )
+ parser.add_argument(
+ "--dataset-resampled",
+ default=False,
+ action="store_true",
+ help="Whether to use sampling with replacement for webdataset shard selection."
+ )
+ parser.add_argument(
+ "--csv-separator",
+ type=str,
+ default="\t",
+ help="For csv-like datasets, which separator to use."
+ )
+ parser.add_argument(
+ "--csv-img-key",
+ type=str,
+ default="filepath",
+ help="For csv-like datasets, the name of the key for the image paths."
+ )
+ parser.add_argument(
+ "--csv-caption-key",
+ type=str,
+ default="title",
+ help="For csv-like datasets, the name of the key for the captions."
+ )
+ parser.add_argument(
+ "--imagenet-val",
+ type=str,
+ default=None,
+ help="Path to imagenet val set for conducting zero shot evaluation.",
+ )
+ parser.add_argument(
+ "--imagenet-v2",
+ type=str,
+ default=None,
+ help="Path to imagenet v2 for conducting zero shot evaluation.",
+ )
+ parser.add_argument(
+ "--logs",
+ type=str,
+ default="./logs/",
+ help="Where to store tensorboard logs. Use None to avoid storing logs.",
+ )
+ parser.add_argument(
+ "--log-local",
+ action="store_true",
+ default=False,
+ help="log files on local master, otherwise global master only.",
+ )
+ parser.add_argument(
+ "--name",
+ type=str,
+ default=None,
+ help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
+ )
+ parser.add_argument(
+ "--workers", type=int, default=1, help="Number of dataloader workers per GPU."
+ )
+ parser.add_argument(
+ "--batch-size", type=int, default=64, help="Batch size per GPU."
+ )
+ parser.add_argument(
+ "--epochs", type=int, default=32, help="Number of epochs to train for."
+ )
+ parser.add_argument(
+ "--epochs-cooldown", type=int, default=None,
+ help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards."
+ )
+ parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
+ parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
+ parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
+ parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
+ parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
+ parser.add_argument(
+ "--warmup", type=int, default=10000, help="Number of steps to warmup for."
+ )
+ parser.add_argument(
+ "--use-bn-sync",
+ default=False,
+ action="store_true",
+ help="Whether to use batch norm sync.")
+ parser.add_argument(
+ "--skip-scheduler",
+ action="store_true",
+ default=False,
+ help="Use this flag to skip the learning rate decay.",
+ )
+ parser.add_argument(
+ "--lr-scheduler",
+ type=str,
+ default='cosine',
+ help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine",
+ )
+ parser.add_argument(
+ "--lr-cooldown-end", type=float, default=0.0,
+ help="End learning rate for cooldown schedule. Default: 0"
+ )
+ parser.add_argument(
+ "--lr-cooldown-power", type=float, default=1.0,
+ help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)"
+ )
+ parser.add_argument(
+ "--save-frequency", type=int, default=1, help="How often to save checkpoints."
+ )
+ parser.add_argument(
+ "--save-most-recent",
+ action="store_true",
+ default=False,
+ help="Always save the most recent model trained to epoch_latest.pt.",
+ )
+ parser.add_argument(
+ "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
+ )
+ parser.add_argument(
+ "--val-frequency", type=int, default=1, help="How often to run evaluation with val data."
+ )
+ parser.add_argument(
+ "--resume",
+ default=None,
+ type=str,
+ help="path to latest checkpoint (default: none)",
+ )
+ parser.add_argument(
+ "--precision",
+ choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
+ default="amp",
+ help="Floating point precision."
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="RN50",
+ help="Name of the vision backbone to use.",
+ )
+ parser.add_argument(
+ "--pretrained",
+ default='',
+ type=str,
+ help="Use a pretrained CLIP model weights with the specified tag or file path.",
+ )
+ parser.add_argument(
+ "--pretrained-image",
+ default=False,
+ action='store_true',
+ help="Load imagenet pretrained weights for image tower backbone if available.",
+ )
+ parser.add_argument(
+ "--lock-image",
+ default=False,
+ action='store_true',
+ help="Lock full image tower by disabling gradients.",
+ )
+ parser.add_argument(
+ "--lock-image-unlocked-groups",
+ type=int,
+ default=0,
+ help="Leave last n image tower layer groups unlocked.",
+ )
+ parser.add_argument(
+ "--lock-image-freeze-bn-stats",
+ default=False,
+ action='store_true',
+ help="Freeze BatchNorm running stats in image tower for any locked layers.",
+ )
+ parser.add_argument(
+ '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
+ help='Override default image mean value of dataset')
+ parser.add_argument(
+ '--image-std', type=float, nargs='+', default=None, metavar='STD',
+ help='Override default image std deviation of of dataset')
+ parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs)
+ parser.add_argument(
+ "--grad-checkpointing",
+ default=False,
+ action='store_true',
+ help="Enable gradient checkpointing.",
+ )
+ parser.add_argument(
+ "--local-loss",
+ default=False,
+ action="store_true",
+ help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)"
+ )
+ parser.add_argument(
+ "--gather-with-grad",
+ default=False,
+ action="store_true",
+ help="enable full distributed gradient for feature gather"
+ )
+ parser.add_argument(
+ '--force-image-size', type=int, nargs='+', default=None,
+ help='Override default image size'
+ )
+ parser.add_argument(
+ "--force-quick-gelu",
+ default=False,
+ action='store_true',
+ help="Force use of QuickGELU activation for non-OpenAI transformer models.",
+ )
+ parser.add_argument(
+ "--force-patch-dropout",
+ default=None,
+ type=float,
+ help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper",
+ )
+ parser.add_argument(
+ "--force-custom-text",
+ default=False,
+ action='store_true',
+ help="Force use of CustomTextCLIP model (separate text-tower).",
+ )
+ parser.add_argument(
+ "--torchscript",
+ default=False,
+ action='store_true',
+ help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
+ )
+ parser.add_argument(
+ "--trace",
+ default=False,
+ action='store_true',
+ help="torch.jit.trace the model for inference / eval only",
+ )
+ parser.add_argument(
+ "--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps."
+ )
+ # arguments for distributed training
+ parser.add_argument(
+ "--dist-url",
+ default="env://",
+ type=str,
+ help="url used to set up distributed training",
+ )
+ parser.add_argument(
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
+ )
+ parser.add_argument(
+ "--report-to",
+ default='',
+ type=str,
+ help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']"
+ )
+ parser.add_argument(
+ "--wandb-notes",
+ default='',
+ type=str,
+ help="Notes if logging with wandb"
+ )
+ parser.add_argument(
+ "--wandb-project-name",
+ type=str,
+ default='open-clip',
+ help="Name of the project if logging with wandb.",
+ )
+ parser.add_argument(
+ "--debug",
+ default=False,
+ action="store_true",
+ help="If true, more information is logged."
+ )
+ parser.add_argument(
+ "--copy-codebase",
+ default=False,
+ action="store_true",
+ help="If true, we copy the entire base on the log directory, and execute from there."
+ )
+ parser.add_argument(
+ "--horovod",
+ default=False,
+ action="store_true",
+ help="Use horovod for distributed training."
+ )
+ parser.add_argument(
+ "--ddp-static-graph",
+ default=False,
+ action='store_true',
+ help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
+ )
+ parser.add_argument(
+ "--no-set-device-rank",
+ default=False,
+ action="store_true",
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)."
+ )
+ parser.add_argument(
+ "--seed", type=int, default=0, help="Default random seed."
+ )
+ parser.add_argument(
+ "--grad-clip-norm", type=float, default=None, help="Gradient clip."
+ )
+ parser.add_argument(
+ "--lock-text",
+ default=False,
+ action='store_true',
+ help="Lock full text tower by disabling gradients.",
+ )
+ parser.add_argument(
+ "--lock-text-unlocked-layers",
+ type=int,
+ default=0,
+ help="Leave last n image tower layer groups unlocked.",
+ )
+ parser.add_argument(
+ "--lock-text-freeze-layer-norm",
+ default=False,
+ action='store_true',
+ help="Freeze BatchNorm running stats in image tower for any locked layers.",
+ )
+ parser.add_argument(
+ "--log-every-n-steps",
+ type=int,
+ default=100,
+ help="Log every n steps to tensorboard/console/wandb.",
+ )
+ parser.add_argument(
+ "--remote-sync",
+ type=str,
+ default=None,
+ help="Optinoally sync with a remote path specified by this arg",
+ )
+ parser.add_argument(
+ "--remote-sync-frequency",
+ type=int,
+ default=300,
+ help="How frequently to sync to a remote directly if --remote-sync is not None.",
+ )
+ parser.add_argument(
+ "--remote-sync-protocol",
+ choices=["s3", "fsspec"],
+ default="s3",
+ help="How to do the remote sync backup if --remote-sync is not None.",
+ )
+ parser.add_argument(
+ "--delete-previous-checkpoint",
+ default=False,
+ action="store_true",
+ help="If true, delete previous checkpoint after storing a new one."
+ )
+ args = parser.parse_args(args)
+
+ # If some params are not passed, we use the default values based on model name.
+ default_params = get_default_params(args.model)
+ for name, val in default_params.items():
+ if getattr(args, name) is None:
+ setattr(args, name, val)
+
+ return args
diff --git a/open_clip/src/training/precision.py b/open_clip/src/training/precision.py
new file mode 100644
index 0000000000000000000000000000000000000000..a63b92256518d13afd57261df1568e26b1622201
--- /dev/null
+++ b/open_clip/src/training/precision.py
@@ -0,0 +1,12 @@
+import torch
+from contextlib import suppress
+
+
+def get_autocast(precision):
+ if precision == 'amp':
+ return torch.cuda.amp.autocast
+ elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
+ # amp_bfloat16 is more stable than amp float16 for clip training
+ return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
+ else:
+ return suppress
diff --git a/open_clip/src/training/profile.py b/open_clip/src/training/profile.py
new file mode 100644
index 0000000000000000000000000000000000000000..f10372cdef306e5e199db432b23062df1c098cf9
--- /dev/null
+++ b/open_clip/src/training/profile.py
@@ -0,0 +1,158 @@
+import argparse
+
+import torch
+import open_clip
+import pandas as pd
+from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis
+
+
+parser = argparse.ArgumentParser(description='OpenCLIP Profiler')
+
+# benchmark specific args
+parser.add_argument('--model', metavar='NAME', default='',
+ help='model(s) to profile')
+parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
+ help='Output csv file for results')
+
+
+def profile_fvcore(
+ model,
+ image_input_size=(3, 224, 224),
+ text_input_size=(77,),
+ batch_size=1,
+ detailed=False,
+ force_cpu=False
+):
+ if force_cpu:
+ model = model.to('cpu')
+ device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
+ example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
+ example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
+ fca = FlopCountAnalysis(model, (example_image_input, example_text_input))
+ aca = ActivationCountAnalysis(model, (example_image_input, example_text_input))
+ if detailed:
+ fcs = flop_count_str(fca)
+ print(fcs)
+ return fca.total(), aca.total()
+
+
+def profile_fvcore_text(
+ model,
+ text_input_size=(77,),
+ batch_size=1,
+ detailed=False,
+ force_cpu=False
+):
+ if force_cpu:
+ model = model.to('cpu')
+ device = next(model.parameters()).device
+ example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
+ fca = FlopCountAnalysis(model, example_input)
+ aca = ActivationCountAnalysis(model, example_input)
+ if detailed:
+ fcs = flop_count_str(fca)
+ print(fcs)
+ return fca.total(), aca.total()
+
+
+def profile_fvcore_image(
+ model,
+ image_input_size=(3, 224, 224),
+ batch_size=1,
+ detailed=False,
+ force_cpu=False
+):
+ if force_cpu:
+ model = model.to('cpu')
+ device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
+ example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
+ fca = FlopCountAnalysis(model, example_input)
+ aca = ActivationCountAnalysis(model, example_input)
+ if detailed:
+ fcs = flop_count_str(fca)
+ print(fcs)
+ return fca.total(), aca.total()
+
+
+def count_params(model):
+ return sum([m.numel() for m in model.parameters()])
+
+
+def profile_model(model_name):
+ model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
+ model.eval()
+ if torch.cuda.is_available():
+ model = model.cuda()
+
+ if isinstance(model.visual.image_size, (tuple, list)):
+ image_input_size = (3,) + tuple(model.visual.image_size[-2:])
+ else:
+ image_input_size = (3, model.visual.image_size, model.visual.image_size)
+ text_input_size = (77,)
+
+ results = {}
+ results['model'] = model_name
+ results['image_size'] = image_input_size[1]
+
+ model_cfg = open_clip.get_model_config(model_name)
+ if model_cfg:
+ vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg'])
+ text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg'])
+ results['image_width'] = int(vision_cfg.width)
+ results['text_width'] = int(text_cfg.width)
+ results['embed_dim'] = int(model_cfg['embed_dim'])
+ else:
+ results['image_width'] = 0
+ results['text_width'] = 0
+ results['embed_dim'] = 0
+
+ retries = 2
+ while retries:
+ retries -= 1
+ try:
+ macs, acts = profile_fvcore(
+ model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries)
+
+ image_macs, image_acts = profile_fvcore_image(
+ model.visual, image_input_size=image_input_size, force_cpu=not retries)
+
+ text_macs, text_acts = profile_fvcore_text(
+ model.text, text_input_size=text_input_size, force_cpu=not retries)
+
+ results['gmacs'] = round(macs / 1e9, 2)
+ results['macts'] = round(acts / 1e6, 2)
+ results['mparams'] = round(count_params(model) / 1e6, 2)
+ results['image_gmacs'] = round(image_macs / 1e9, 2)
+ results['image_macts'] = round(image_acts / 1e6, 2)
+ results['image_mparams'] = round(count_params(model.visual) / 1e6, 2)
+ results['text_gmacs'] = round(text_macs / 1e9, 2)
+ results['text_macts'] = round(text_acts / 1e6, 2)
+ results['text_mparams'] = round(count_params(model.text) / 1e6, 2)
+ except RuntimeError as e:
+ pass
+ return results
+
+
+def main():
+ args = parser.parse_args()
+
+ # FIXME accept a text file name to allow lists of models in txt/csv
+ if args.model == 'all':
+ parsed_model = open_clip.list_models()
+ else:
+ parsed_model = args.model.split(',')
+
+ results = []
+ for m in parsed_model:
+ row = profile_model(m)
+ results.append(row)
+
+ df = pd.DataFrame(results, columns=results[0].keys())
+ df = df.sort_values('gmacs')
+ print(df)
+ if args.results_file:
+ df.to_csv(args.results_file, index=False)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/open_clip/src/training/scheduler.py b/open_clip/src/training/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..fba76fcf1720b11d136a5ab6d3a58ab2fbe42f74
--- /dev/null
+++ b/open_clip/src/training/scheduler.py
@@ -0,0 +1,53 @@
+import numpy as np
+
+
+def assign_learning_rate(optimizer, new_lr):
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = new_lr
+
+
+def _warmup_lr(base_lr, warmup_length, step):
+ return base_lr * (step + 1) / warmup_length
+
+
+def const_lr(optimizer, base_lr, warmup_length, steps):
+ def _lr_adjuster(step):
+ if step < warmup_length:
+ lr = _warmup_lr(base_lr, warmup_length, step)
+ else:
+ lr = base_lr
+ assign_learning_rate(optimizer, lr)
+ return lr
+ return _lr_adjuster
+
+
+def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.):
+ def _lr_adjuster(step):
+ start_cooldown_step = steps - cooldown_steps
+ if step < warmup_length:
+ lr = _warmup_lr(base_lr, warmup_length, step)
+ else:
+ if step < start_cooldown_step:
+ lr = base_lr
+ else:
+ e = step - start_cooldown_step
+ es = steps - start_cooldown_step
+ # linear decay if power == 1; polynomial decay otherwise;
+ decay = (1 - (e/es)) ** cooldown_power
+ lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
+ assign_learning_rate(optimizer, lr)
+ return lr
+ return _lr_adjuster
+
+
+def cosine_lr(optimizer, base_lr, warmup_length, steps):
+ def _lr_adjuster(step):
+ if step < warmup_length:
+ lr = _warmup_lr(base_lr, warmup_length, step)
+ else:
+ e = step - warmup_length
+ es = steps - warmup_length
+ lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
+ assign_learning_rate(optimizer, lr)
+ return lr
+ return _lr_adjuster
diff --git a/open_clip/src/training/train.py b/open_clip/src/training/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf42f147592e1a1745b067a255688aaf913e9401
--- /dev/null
+++ b/open_clip/src/training/train.py
@@ -0,0 +1,308 @@
+import json
+import logging
+import math
+import os
+import time
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+from open_clip import ClipLoss, get_cast_dtype
+from .distributed import is_master
+from .zero_shot import zero_shot_eval
+from .precision import get_autocast
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def unwrap_model(model):
+ if hasattr(model, 'module'):
+ return model.module
+ else:
+ return model
+
+
+def backward(total_loss, scaler):
+ if scaler is not None:
+ scaler.scale(total_loss).backward()
+ else:
+ total_loss.backward()
+
+
+def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
+ device = torch.device(args.device)
+ autocast = get_autocast(args.precision)
+ cast_dtype = get_cast_dtype(args.precision)
+
+ model.train()
+ loss = ClipLoss(
+ local_loss=args.local_loss,
+ gather_with_grad=args.gather_with_grad,
+ cache_labels=True,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod)
+
+ data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch
+ dataloader = data['train'].dataloader
+ num_batches_per_epoch = dataloader.num_batches // args.accum_freq
+ sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
+
+ if args.accum_freq > 1:
+ accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
+
+ loss_m = AverageMeter()
+ batch_time_m = AverageMeter()
+ data_time_m = AverageMeter()
+ end = time.time()
+ for i, batch in enumerate(dataloader):
+ i_accum = i // args.accum_freq
+ step = num_batches_per_epoch * epoch + i_accum
+
+ if not args.skip_scheduler:
+ scheduler(step)
+
+ images, texts = batch
+ images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
+ texts = texts.to(device=device, non_blocking=True)
+
+ data_time_m.update(time.time() - end)
+ optimizer.zero_grad()
+
+ if args.accum_freq == 1:
+ with autocast():
+ image_features, text_features, logit_scale = model(images, texts)
+ total_loss = loss(image_features, text_features, logit_scale)
+
+ backward(total_loss, scaler)
+ else:
+ # First, cache the features without any gradient tracking.
+ with torch.no_grad():
+ with autocast():
+ chunk_image_features, chunk_text_features, _ = model(images, texts)
+ accum_image_features.append(chunk_image_features)
+ accum_text_features.append(chunk_text_features)
+
+ accum_images.append(images)
+ accum_texts.append(texts)
+
+ # If (i + 1) % accum_freq is not zero, move on to the next batch.
+ if ((i + 1) % args.accum_freq) > 0:
+ # FIXME this makes data time logging unreliable when accumulating
+ continue
+
+ # Now, ready to take gradients for the last accum_freq batches.
+ # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives.
+ # Call backwards each time, but only step optimizer at the end.
+ optimizer.zero_grad()
+ for j in range(args.accum_freq):
+ images = accum_images[j]
+ texts = accum_texts[j]
+ with autocast():
+ chunk_image_features, chunk_text_features, logit_scale = model(images, texts)
+ image_features = torch.cat(
+ accum_image_features[:j] + [chunk_image_features] + accum_image_features[j + 1:])
+ text_features = torch.cat(
+ accum_text_features[:j] + [chunk_text_features] + accum_text_features[j + 1:])
+ total_loss = loss(image_features, text_features, logit_scale)
+ backward(total_loss, scaler)
+
+ if scaler is not None:
+ if args.horovod:
+ optimizer.synchronize()
+ scaler.unscale_(optimizer)
+ if args.grad_clip_norm is not None:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
+ with optimizer.skip_synchronize():
+ scaler.step(optimizer)
+ else:
+ if args.grad_clip_norm is not None:
+ scaler.unscale_(optimizer)
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ if args.grad_clip_norm is not None:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
+ optimizer.step()
+
+ # reset gradient accum, if enabled
+ if args.accum_freq > 1:
+ accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
+
+ # Note: we clamp to 4.6052 = ln(100), as in the original paper.
+ with torch.no_grad():
+ unwrap_model(model).logit_scale.clamp_(0, math.log(100))
+
+ batch_time_m.update(time.time() - end)
+ end = time.time()
+ batch_count = i_accum + 1
+ if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
+ batch_size = len(images)
+ num_samples = batch_count * batch_size * args.accum_freq * args.world_size
+ samples_per_epoch = dataloader.num_samples
+ percent_complete = 100.0 * batch_count / num_batches_per_epoch
+
+ # NOTE loss is coarsely sampled, just master node and per log update
+ loss_m.update(total_loss.item(), batch_size)
+ logit_scale_scalar = logit_scale.item()
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f}, {args.accum_freq * args.batch_size * args.world_size / batch_time_m.val:#g}/s "
+ f"LR: {optimizer.param_groups[0]['lr']:5f} "
+ f"Logit Scale: {logit_scale_scalar:.3f}"
+ )
+
+ # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "samples_per_second": args.accum_freq * args.batch_size * args.world_size / batch_time_m.val,
+ "scale": logit_scale_scalar,
+ "lr": optimizer.param_groups[0]["lr"]
+ }
+ for name, val in log_data.items():
+ name = "train/" + name
+ if tb_writer is not None:
+ tb_writer.add_scalar(name, val, step)
+ if args.wandb:
+ assert wandb is not None, 'Please install wandb.'
+ wandb.log({name: val, 'step': step})
+
+ # resetting batch / data time meters per log window
+ batch_time_m.reset()
+ data_time_m.reset()
+ # end for
+
+
+def evaluate(model, data, epoch, args, tb_writer=None):
+ metrics = {}
+ if not is_master(args):
+ return metrics
+ device = torch.device(args.device)
+ model.eval()
+
+ zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
+ metrics.update(zero_shot_metrics)
+
+ autocast = get_autocast(args.precision)
+ cast_dtype = get_cast_dtype(args.precision)
+
+ if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
+ dataloader = data['val'].dataloader
+ num_samples = 0
+ samples_per_val = dataloader.num_samples
+
+ # FIXME this does not scale past small eval datasets
+ # all_image_features @ all_text_features will blow up memory and compute very quickly
+ cumulative_loss = 0.0
+ all_image_features, all_text_features = [], []
+ with torch.no_grad():
+ for i, batch in enumerate(dataloader):
+ images, texts = batch
+ images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
+ texts = texts.to(device=device, non_blocking=True)
+
+ with autocast():
+ image_features, text_features, logit_scale = model(images, texts)
+ # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly
+ # however, system RAM is easily exceeded and compute time becomes problematic
+ all_image_features.append(image_features.cpu())
+ all_text_features.append(text_features.cpu())
+ logit_scale = logit_scale.mean()
+ logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logits_per_image.t()
+
+ batch_size = images.shape[0]
+ labels = torch.arange(batch_size, device=device).long()
+ total_loss = (
+ F.cross_entropy(logits_per_image, labels) +
+ F.cross_entropy(logits_per_text, labels)
+ ) / 2
+
+ cumulative_loss += total_loss * batch_size
+ num_samples += batch_size
+ if is_master(args) and (i % 100) == 0:
+ logging.info(
+ f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
+ f"Loss: {cumulative_loss / num_samples:.6f}\t")
+
+ val_metrics = get_metrics(
+ image_features=torch.cat(all_image_features),
+ text_features=torch.cat(all_text_features),
+ logit_scale=logit_scale.cpu(),
+ )
+ loss = cumulative_loss / num_samples
+ metrics.update(
+ {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
+ )
+
+ if not metrics:
+ return metrics
+
+ logging.info(
+ f"Eval Epoch: {epoch} "
+ + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
+ )
+
+ if args.save_logs:
+ for name, val in metrics.items():
+ if tb_writer is not None:
+ tb_writer.add_scalar(f"val/{name}", val, epoch)
+
+ with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
+ f.write(json.dumps(metrics))
+ f.write("\n")
+
+ if args.wandb:
+ assert wandb is not None, 'Please install wandb.'
+ for name, val in metrics.items():
+ wandb.log({f"val/{name}": val, 'epoch': epoch})
+
+ return metrics
+
+
+def get_metrics(image_features, text_features, logit_scale):
+ metrics = {}
+ logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
+ logits_per_text = logits_per_image.t().detach().cpu()
+
+ logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
+ ground_truth = torch.arange(len(text_features)).view(-1, 1)
+
+ for name, logit in logits.items():
+ ranking = torch.argsort(logit, descending=True)
+ preds = torch.where(ranking == ground_truth)[1]
+ preds = preds.detach().cpu().numpy()
+ metrics[f"{name}_mean_rank"] = preds.mean() + 1
+ metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
+ for k in [1, 5, 10]:
+ metrics[f"{name}_R@{k}"] = np.mean(preds < k)
+
+ return metrics
diff --git a/open_clip/src/training/zero_shot.py b/open_clip/src/training/zero_shot.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5768b4a3ce26f0a9a12d8ee3a6d9490e778a78a
--- /dev/null
+++ b/open_clip/src/training/zero_shot.py
@@ -0,0 +1,93 @@
+import logging
+
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+from open_clip import get_cast_dtype, get_tokenizer
+from .precision import get_autocast
+from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
+
+
+def zero_shot_classifier(model, classnames, templates, args):
+ tokenizer = get_tokenizer(args.model)
+ with torch.no_grad():
+ zeroshot_weights = []
+ for classname in tqdm(classnames):
+ texts = [template(classname) for template in templates] # format with class
+ texts = tokenizer(texts).to(args.device) # tokenize
+ if args.distributed and not args.horovod:
+ class_embeddings = model.module.encode_text(texts)
+ else:
+ class_embeddings = model.encode_text(texts)
+ class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
+ class_embedding /= class_embedding.norm()
+ zeroshot_weights.append(class_embedding)
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
+ return zeroshot_weights
+
+
+def accuracy(output, target, topk=(1,)):
+ pred = output.topk(max(topk), 1, True, True)[1].t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+ return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
+
+
+def run(model, classifier, dataloader, args):
+ autocast = get_autocast(args.precision)
+ cast_dtype = get_cast_dtype(args.precision)
+ with torch.no_grad():
+ top1, top5, n = 0., 0., 0.
+ for images, target in tqdm(dataloader, unit_scale=args.batch_size):
+ images = images.to(args.device)
+ if cast_dtype is not None:
+ images = images.to(dtype=cast_dtype)
+ target = target.to(args.device)
+
+ with autocast():
+ # predict
+ if args.distributed and not args.horovod:
+ image_features = model.module.encode_image(images)
+ else:
+ image_features = model.encode_image(images)
+ image_features = F.normalize(image_features, dim=-1)
+ logits = 100. * image_features @ classifier
+
+ # measure accuracy
+ acc1, acc5 = accuracy(logits, target, topk=(1, 5))
+ top1 += acc1
+ top5 += acc5
+ n += images.size(0)
+
+ top1 = (top1 / n)
+ top5 = (top5 / n)
+ return top1, top5
+
+
+def zero_shot_eval(model, data, epoch, args):
+ if 'imagenet-val' not in data and 'imagenet-v2' not in data:
+ return {}
+ if args.zeroshot_frequency == 0:
+ return {}
+ if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
+ return {}
+
+ logging.info('Starting zero-shot imagenet.')
+
+ logging.info('Building zero-shot classifier')
+ classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)
+
+ logging.info('Using classifier')
+ results = {}
+ if 'imagenet-val' in data:
+ top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
+ results['imagenet-zeroshot-val-top1'] = top1
+ results['imagenet-zeroshot-val-top5'] = top5
+ if 'imagenet-v2' in data:
+ top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
+ results['imagenetv2-zeroshot-val-top1'] = top1
+ results['imagenetv2-zeroshot-val-top5'] = top5
+
+ logging.info('Finished zero-shot imagenet.')
+
+ return results
diff --git a/open_clip/tests/test_download_pretrained.py b/open_clip/tests/test_download_pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..6340918ed5b7c56fdbbfb84e2bcb26ccf662c8b5
--- /dev/null
+++ b/open_clip/tests/test_download_pretrained.py
@@ -0,0 +1,111 @@
+import requests
+import torch
+from PIL import Image
+import hashlib
+import tempfile
+import unittest
+from io import BytesIO
+from pathlib import Path
+from unittest.mock import patch
+
+from urllib3 import HTTPResponse
+from urllib3._collections import HTTPHeaderDict
+
+import open_clip
+from open_clip.pretrained import download_pretrained_from_url
+
+
+class DownloadPretrainedTests(unittest.TestCase):
+
+ def create_response(self, data, status_code=200, content_type='application/octet-stream'):
+ fp = BytesIO(data)
+ headers = HTTPHeaderDict({
+ 'Content-Type': content_type,
+ 'Content-Length': str(len(data))
+ })
+ raw = HTTPResponse(fp, preload_content=False, headers=headers, status=status_code)
+ return raw
+
+ @patch('open_clip.pretrained.urllib')
+ def test_download_pretrained_from_url_from_openaipublic(self, urllib):
+ file_contents = b'pretrained model weights'
+ expected_hash = hashlib.sha256(file_contents).hexdigest()
+ urllib.request.urlopen.return_value = self.create_response(file_contents)
+ with tempfile.TemporaryDirectory() as root:
+ url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
+ download_pretrained_from_url(url, root)
+ urllib.request.urlopen.assert_called_once()
+
+ @patch('open_clip.pretrained.urllib')
+ def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib):
+ file_contents = b'pretrained model weights'
+ expected_hash = hashlib.sha256(file_contents).hexdigest()
+ urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
+ with tempfile.TemporaryDirectory() as root:
+ url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
+ with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
+ download_pretrained_from_url(url, root)
+ urllib.request.urlopen.assert_called_once()
+
+ @patch('open_clip.pretrained.urllib')
+ def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib):
+ file_contents = b'pretrained model weights'
+ expected_hash = hashlib.sha256(file_contents).hexdigest()
+ urllib.request.urlopen.return_value = self.create_response(file_contents)
+ with tempfile.TemporaryDirectory() as root:
+ local_file = Path(root) / 'RN50.pt'
+ local_file.write_bytes(file_contents)
+ url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
+ download_pretrained_from_url(url, root)
+ urllib.request.urlopen.assert_not_called()
+
+ @patch('open_clip.pretrained.urllib')
+ def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib):
+ file_contents = b'pretrained model weights'
+ expected_hash = hashlib.sha256(file_contents).hexdigest()
+ urllib.request.urlopen.return_value = self.create_response(file_contents)
+ with tempfile.TemporaryDirectory() as root:
+ local_file = Path(root) / 'RN50.pt'
+ local_file.write_bytes(b'corrupted pretrained model')
+ url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
+ download_pretrained_from_url(url, root)
+ urllib.request.urlopen.assert_called_once()
+
+ @patch('open_clip.pretrained.urllib')
+ def test_download_pretrained_from_url_from_mlfoundations(self, urllib):
+ file_contents = b'pretrained model weights'
+ expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
+ urllib.request.urlopen.return_value = self.create_response(file_contents)
+ with tempfile.TemporaryDirectory() as root:
+ url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
+ download_pretrained_from_url(url, root)
+ urllib.request.urlopen.assert_called_once()
+
+ @patch('open_clip.pretrained.urllib')
+ def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib):
+ file_contents = b'pretrained model weights'
+ expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
+ urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
+ with tempfile.TemporaryDirectory() as root:
+ url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
+ with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
+ download_pretrained_from_url(url, root)
+ urllib.request.urlopen.assert_called_once()
+
+ @patch('open_clip.pretrained.urllib')
+ def test_download_pretrained_from_hfh(self, urllib):
+ model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:hf-internal-testing/tiny-open-clip-model')
+ tokenizer = open_clip.get_tokenizer('hf-hub:hf-internal-testing/tiny-open-clip-model')
+ img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
+ image = preprocess(Image.open(requests.get(img_url, stream=True).raw)).unsqueeze(0)
+ text = tokenizer(["a diagram", "a dog", "a cat"])
+
+ with torch.no_grad():
+ image_features = model.encode_image(image)
+ text_features = model.encode_text(text)
+ image_features /= image_features.norm(dim=-1, keepdim=True)
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+
+ text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+
+ self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), 1e-3))
diff --git a/open_clip/tests/test_hf_model.py b/open_clip/tests/test_hf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..79df2f2cf3655b1299ed236791d606c5215dae2c
--- /dev/null
+++ b/open_clip/tests/test_hf_model.py
@@ -0,0 +1,29 @@
+import pytest
+
+import torch
+from open_clip.hf_model import _POOLERS, HFTextEncoder
+from transformers import AutoConfig
+from transformers.modeling_outputs import BaseModelOutput
+# test poolers
+def test_poolers():
+ bs, sl, d = 2, 10, 5
+ h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d)
+ mask = torch.ones(bs, sl, dtype=torch.long)
+ mask[:2, 6:] = 0
+ x = BaseModelOutput(h)
+ for name, cls in _POOLERS.items():
+ pooler = cls()
+ res = pooler(x, mask)
+ assert res.shape == (bs, d), f"{name} returned wrong shape"
+
+# test HFTextEncoder
+@pytest.mark.parametrize("model_id", ["arampacha/roberta-tiny", "roberta-base", "xlm-roberta-base", "google/mt5-base"])
+def test_pretrained_text_encoder(model_id):
+ bs, sl, d = 2, 10, 64
+ cfg = AutoConfig.from_pretrained(model_id)
+ model = HFTextEncoder(model_id, d, proj='linear')
+ x = torch.randint(0, cfg.vocab_size, (bs, sl))
+ with torch.no_grad():
+ emb = model(x)
+
+ assert emb.shape == (bs, d)
diff --git a/open_clip/tests/test_inference.py b/open_clip/tests/test_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecd46d07271893e209a0050377129129fc75ef6b
--- /dev/null
+++ b/open_clip/tests/test_inference.py
@@ -0,0 +1,82 @@
+
+import os
+import pytest
+import torch
+import open_clip
+import util_test
+
+os.environ['CUDA_VISIBLE_DEVICES'] = ''
+
+models_to_test = set(open_clip.list_models())
+
+# testing excemptions
+models_to_test = models_to_test.difference({
+ # not available with timm yet
+ # see https://github.com/mlfoundations/open_clip/issues/219
+ 'convnext_xlarge',
+ 'convnext_xxlarge',
+ 'convnext_xxlarge_320',
+ 'vit_medium_patch16_gap_256',
+ # exceeds GH runner memory limit
+ 'ViT-bigG-14',
+ 'ViT-e-14',
+ 'mt5-xl-ViT-H-14',
+})
+
+if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ:
+ external_model_list = os.environ['OPEN_CLIP_TEST_REG_MODELS']
+ with open(external_model_list, 'r') as f:
+ models_to_test = set(f.read().splitlines()).intersection(models_to_test)
+ print(f"Selected models from {external_model_list}: {models_to_test}")
+
+models_to_test = list(models_to_test)
+models_to_test.sort()
+
+@pytest.mark.regression_test
+@pytest.mark.parametrize('model_name', models_to_test)
+def test_inference_with_data(
+ model_name,
+ pretrained = None,
+ pretrained_hf = False,
+ precision = 'fp32',
+ jit = False,
+ force_quick_gelu = False,
+):
+ util_test.seed_all()
+ model, _, preprocess_val = open_clip.create_model_and_transforms(
+ model_name,
+ pretrained = pretrained,
+ precision = precision,
+ jit = jit,
+ force_quick_gelu = force_quick_gelu,
+ pretrained_hf = pretrained_hf
+ )
+ model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
+ input_dir, output_dir = util_test.get_data_dirs()
+ # text
+ input_text_path = os.path.join(input_dir, 'random_text.pt')
+ gt_text_path = os.path.join(output_dir, f'{model_id}_random_text.pt')
+ if not os.path.isfile(input_text_path):
+ pytest.skip(reason = f"missing test data, expected at {input_text_path}")
+ if not os.path.isfile(gt_text_path):
+ pytest.skip(reason = f"missing test data, expected at {gt_text_path}")
+ input_text = torch.load(input_text_path)
+ gt_text = torch.load(gt_text_path)
+ y_text = util_test.inference_text(model, model_name, input_text)
+ assert (y_text == gt_text).all(), f"text output differs @ {input_text_path}"
+ # image
+ image_size = model.visual.image_size
+ if not isinstance(image_size, tuple):
+ image_size = (image_size, image_size)
+ input_image_path = os.path.join(input_dir, f'random_image_{image_size[0]}_{image_size[1]}.pt')
+ gt_image_path = os.path.join(output_dir, f'{model_id}_random_image.pt')
+ if not os.path.isfile(input_image_path):
+ pytest.skip(reason = f"missing test data, expected at {input_image_path}")
+ if not os.path.isfile(gt_image_path):
+ pytest.skip(reason = f"missing test data, expected at {gt_image_path}")
+ input_image = torch.load(input_image_path)
+ gt_image = torch.load(gt_image_path)
+ y_image = util_test.inference_image(model, preprocess_val, input_image)
+ assert (y_image == gt_image).all(), f"image output differs @ {input_image_path}"
+
+
diff --git a/open_clip/tests/test_inference_simple.py b/open_clip/tests/test_inference_simple.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb6bb49584e8e3005942493b6ed8f2449d323073
--- /dev/null
+++ b/open_clip/tests/test_inference_simple.py
@@ -0,0 +1,26 @@
+
+import torch
+from PIL import Image
+from open_clip.factory import get_tokenizer
+import pytest
+import open_clip
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+@pytest.mark.parametrize("model_type,pretrained", [("ViT-B-32-quickgelu", "laion400m_e32"), ("roberta-ViT-B-32", "laion2b_s12b_b32k")])
+def test_inference_simple(model_type, pretrained):
+ model, _, preprocess = open_clip.create_model_and_transforms(model_type, pretrained=pretrained, jit=False)
+ tokenizer = get_tokenizer(model_type)
+
+ current_dir = os.path.dirname(os.path.realpath(__file__))
+
+ image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0)
+ text = tokenizer(["a diagram", "a dog", "a cat"])
+
+ with torch.no_grad():
+ image_features = model.encode_image(image)
+ text_features = model.encode_text(text)
+
+ text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+
+ assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0]
diff --git a/open_clip/tests/test_num_shards.py b/open_clip/tests/test_num_shards.py
new file mode 100644
index 0000000000000000000000000000000000000000..70ca8feccd6ff5be4b04a5d9da7b47ab99e36fa3
--- /dev/null
+++ b/open_clip/tests/test_num_shards.py
@@ -0,0 +1,20 @@
+import pytest
+
+from training.data import get_dataset_size
+
+@pytest.mark.parametrize(
+ "shards,expected_size",
+ [
+ ('/path/to/shard.tar', 1),
+ ('/path/to/shard_{000..000}.tar', 1),
+ ('/path/to/shard_{000..009}.tar', 10),
+ ('/path/to/shard_{000..009}_{000..009}.tar', 100),
+ ('/path/to/shard.tar::/path/to/other_shard_{000..009}.tar', 11),
+ ('/path/to/shard_{000..009}.tar::/path/to/other_shard_{000..009}.tar', 20),
+ (['/path/to/shard.tar'], 1),
+ (['/path/to/shard.tar', '/path/to/other_shard.tar'], 2),
+ ]
+)
+def test_num_shards(shards, expected_size):
+ _, size = get_dataset_size(shards)
+ assert size == expected_size, f'Expected {expected_size} for {shards} but found {size} instead.'
diff --git a/open_clip/tests/test_training_simple.py b/open_clip/tests/test_training_simple.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe55b33286215ada9e6494272ad7bfc8ef5b8aea
--- /dev/null
+++ b/open_clip/tests/test_training_simple.py
@@ -0,0 +1,63 @@
+
+import os
+import sys
+import pytest
+from PIL import Image
+import torch
+from training.main import main
+
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
+def test_training():
+ main([
+ '--save-frequency', '1',
+ '--zeroshot-frequency', '1',
+ '--dataset-type', "synthetic",
+ '--train-num-samples', '16',
+ '--warmup', '1',
+ '--batch-size', '4',
+ '--lr', '1e-3',
+ '--wd', '0.1',
+ '--epochs', '1',
+ '--workers', '2',
+ '--model', 'RN50'
+ ])
+
+@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
+def test_training_mt5():
+ main([
+ '--save-frequency', '1',
+ '--zeroshot-frequency', '1',
+ '--dataset-type', "synthetic",
+ '--train-num-samples', '16',
+ '--warmup', '1',
+ '--batch-size', '4',
+ '--lr', '1e-3',
+ '--wd', '0.1',
+ '--epochs', '1',
+ '--workers', '2',
+ '--model', 'mt5-base-ViT-B-32',
+ '--lock-text',
+ '--lock-text-unlocked-layers', '2'
+ ])
+
+
+
+@pytest.mark.skipif(sys.platform.startswith('darwin'), reason="macos pickle bug with locals")
+def test_training_unfreezing_vit():
+ main([
+ '--save-frequency', '1',
+ '--zeroshot-frequency', '1',
+ '--dataset-type', "synthetic",
+ '--train-num-samples', '16',
+ '--warmup', '1',
+ '--batch-size', '4',
+ '--lr', '1e-3',
+ '--wd', '0.1',
+ '--epochs', '1',
+ '--workers', '2',
+ '--model', 'ViT-B-32',
+ '--lock-image',
+ '--lock-image-unlocked-groups', '5'
+ ])
\ No newline at end of file
diff --git a/open_clip/tests/util_test.py b/open_clip/tests/util_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2a2c9c3d4726672ca421d71b768de225c4d34b5
--- /dev/null
+++ b/open_clip/tests/util_test.py
@@ -0,0 +1,287 @@
+import os
+import random
+import numpy as np
+from PIL import Image
+import torch
+
+if __name__ != '__main__':
+ import open_clip
+
+os.environ['CUDA_VISIBLE_DEVICES'] = ''
+
+def seed_all(seed = 0):
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.use_deterministic_algorithms(True, warn_only=False)
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+def inference_text(model, model_name, batches):
+ y = []
+ tokenizer = open_clip.get_tokenizer(model_name)
+ with torch.no_grad():
+ for x in batches:
+ x = tokenizer(x)
+ y.append(model.encode_text(x))
+ return torch.stack(y)
+
+def inference_image(model, preprocess_val, batches):
+ y = []
+ with torch.no_grad():
+ for x in batches:
+ x = torch.stack([preprocess_val(img) for img in x])
+ y.append(model.encode_image(x))
+ return torch.stack(y)
+
+def random_image_batch(batch_size, size):
+ h, w = size
+ data = np.random.randint(255, size = (batch_size, h, w, 3), dtype = np.uint8)
+ return [ Image.fromarray(d) for d in data ]
+
+def random_text_batch(batch_size, min_length = 75, max_length = 75):
+ t = open_clip.tokenizer.SimpleTokenizer()
+ # every token decoded as string, exclude SOT and EOT, replace EOW with space
+ token_words = [
+ x[1].replace('', ' ')
+ for x in t.decoder.items()
+ if x[0] not in t.all_special_ids
+ ]
+ # strings of randomly chosen tokens
+ return [
+ ''.join(random.choices(
+ token_words,
+ k = random.randint(min_length, max_length)
+ ))
+ for _ in range(batch_size)
+ ]
+
+def create_random_text_data(
+ path,
+ min_length = 75,
+ max_length = 75,
+ batches = 1,
+ batch_size = 1
+):
+ text_batches = [
+ random_text_batch(batch_size, min_length, max_length)
+ for _ in range(batches)
+ ]
+ print(f"{path}")
+ torch.save(text_batches, path)
+
+def create_random_image_data(path, size, batches = 1, batch_size = 1):
+ image_batches = [
+ random_image_batch(batch_size, size)
+ for _ in range(batches)
+ ]
+ print(f"{path}")
+ torch.save(image_batches, path)
+
+def get_data_dirs(make_dir = True):
+ data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data')
+ input_dir = os.path.join(data_dir, 'input')
+ output_dir = os.path.join(data_dir, 'output')
+ if make_dir:
+ os.makedirs(input_dir, exist_ok = True)
+ os.makedirs(output_dir, exist_ok = True)
+ assert os.path.isdir(data_dir), f"data directory missing, expected at {input_dir}"
+ assert os.path.isdir(data_dir), f"data directory missing, expected at {output_dir}"
+ return input_dir, output_dir
+
+def create_test_data_for_model(
+ model_name,
+ pretrained = None,
+ precision = 'fp32',
+ jit = False,
+ pretrained_hf = False,
+ force_quick_gelu = False,
+ create_missing_input_data = True,
+ batches = 1,
+ batch_size = 1,
+ overwrite = False
+):
+ model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
+ input_dir, output_dir = get_data_dirs()
+ output_file_text = os.path.join(output_dir, f'{model_id}_random_text.pt')
+ output_file_image = os.path.join(output_dir, f'{model_id}_random_image.pt')
+ text_exists = os.path.exists(output_file_text)
+ image_exists = os.path.exists(output_file_image)
+ if not overwrite and text_exists and image_exists:
+ return
+ seed_all()
+ model, _, preprocess_val = open_clip.create_model_and_transforms(
+ model_name,
+ pretrained = pretrained,
+ precision = precision,
+ jit = jit,
+ force_quick_gelu = force_quick_gelu,
+ pretrained_hf = pretrained_hf
+ )
+ # text
+ if overwrite or not text_exists:
+ input_file_text = os.path.join(input_dir, 'random_text.pt')
+ if create_missing_input_data and not os.path.exists(input_file_text):
+ create_random_text_data(
+ input_file_text,
+ batches = batches,
+ batch_size = batch_size
+ )
+ assert os.path.isfile(input_file_text), f"missing input data, expected at {input_file_text}"
+ input_data_text = torch.load(input_file_text)
+ output_data_text = inference_text(model, model_name, input_data_text)
+ print(f"{output_file_text}")
+ torch.save(output_data_text, output_file_text)
+ # image
+ if overwrite or not image_exists:
+ size = model.visual.image_size
+ if not isinstance(size, tuple):
+ size = (size, size)
+ input_file_image = os.path.join(input_dir, f'random_image_{size[0]}_{size[1]}.pt')
+ if create_missing_input_data and not os.path.exists(input_file_image):
+ create_random_image_data(
+ input_file_image,
+ size,
+ batches = batches,
+ batch_size = batch_size
+ )
+ assert os.path.isfile(input_file_image), f"missing input data, expected at {input_file_image}"
+ input_data_image = torch.load(input_file_image)
+ output_data_image = inference_image(model, preprocess_val, input_data_image)
+ print(f"{output_file_image}")
+ torch.save(output_data_image, output_file_image)
+
+def create_test_data(
+ models,
+ batches = 1,
+ batch_size = 1,
+ overwrite = False
+):
+ models = list(set(models).difference({
+ # not available with timm
+ # see https://github.com/mlfoundations/open_clip/issues/219
+ 'timm-convnext_xlarge',
+ 'timm-vit_medium_patch16_gap_256'
+ }).intersection(open_clip.list_models()))
+ models.sort()
+ print(f"generating test data for:\n{models}")
+ for model_name in models:
+ print(model_name)
+ create_test_data_for_model(
+ model_name,
+ batches = batches,
+ batch_size = batch_size,
+ overwrite = overwrite
+ )
+ return models
+
+def _sytem_assert(string):
+ assert os.system(string) == 0
+
+def main(args):
+ global open_clip
+ import importlib
+ import shutil
+ import subprocess
+ import argparse
+ parser = argparse.ArgumentParser(description = "Populate test data directory")
+ parser.add_argument(
+ '-a', '--all',
+ action = 'store_true',
+ help = "create test data for all models"
+ )
+ parser.add_argument(
+ '-m', '--model',
+ type = str,
+ default = [],
+ nargs = '+',
+ help = "model(s) to create test data for"
+ )
+ parser.add_argument(
+ '-f', '--model_list',
+ type = str,
+ help = "path to a text file containing a list of model names, one model per line"
+ )
+ parser.add_argument(
+ '-s', '--save_model_list',
+ type = str,
+ help = "path to save the list of models that data was generated for"
+ )
+ parser.add_argument(
+ '-g', '--git_revision',
+ type = str,
+ help = "git revision to generate test data for"
+ )
+ parser.add_argument(
+ '--overwrite',
+ action = 'store_true',
+ help = "overwrite existing output data"
+ )
+ parser.add_argument(
+ '-n', '--num_batches',
+ default = 1,
+ type = int,
+ help = "amount of data batches to create (default: 1)"
+ )
+ parser.add_argument(
+ '-b', '--batch_size',
+ default = 1,
+ type = int,
+ help = "test data batch size (default: 1)"
+ )
+ args = parser.parse_args(args)
+ model_list = []
+ if args.model_list is not None:
+ with open(args.model_list, 'r') as f:
+ model_list = f.read().splitlines()
+ if not args.all and len(args.model) < 1 and len(model_list) < 1:
+ print("error: at least one model name is required")
+ parser.print_help()
+ parser.exit(1)
+ if args.git_revision is not None:
+ stash_output = subprocess.check_output(['git', 'stash']).decode().splitlines()
+ has_stash = len(stash_output) > 0 and stash_output[0] != 'No local changes to save'
+ current_branch = subprocess.check_output(['git', 'branch', '--show-current'])
+ if len(current_branch) < 1:
+ # not on a branch -> detached head
+ current_branch = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
+ current_branch = current_branch.splitlines()[0].decode()
+ try:
+ _sytem_assert(f'git checkout {args.git_revision}')
+ except AssertionError as e:
+ _sytem_assert(f'git checkout -f {current_branch}')
+ if has_stash:
+ os.system(f'git stash pop')
+ raise e
+ open_clip = importlib.import_module('open_clip')
+ models = open_clip.list_models() if args.all else args.model + model_list
+ try:
+ models = create_test_data(
+ models,
+ batches = args.num_batches,
+ batch_size = args.batch_size,
+ overwrite = args.overwrite
+ )
+ finally:
+ if args.git_revision is not None:
+ test_dir = os.path.join(os.path.dirname(__file__), 'data')
+ test_dir_ref = os.path.join(os.path.dirname(__file__), 'data_ref')
+ if os.path.exists(test_dir_ref):
+ shutil.rmtree(test_dir_ref, ignore_errors = True)
+ if os.path.exists(test_dir):
+ os.rename(test_dir, test_dir_ref)
+ _sytem_assert(f'git checkout {current_branch}')
+ if has_stash:
+ os.system(f'git stash pop')
+ os.rename(test_dir_ref, test_dir)
+ if args.save_model_list is not None:
+ print(f"Saving model list as {args.save_model_list}")
+ with open(args.save_model_list, 'w') as f:
+ for m in models:
+ print(m, file=f)
+
+
+if __name__ == '__main__':
+ import sys
+ main(sys.argv[1:])
+
diff --git a/plain_train_net.py b/plain_train_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..5412d18d729b6ec37d01d5232c62d5116e9992b0
--- /dev/null
+++ b/plain_train_net.py
@@ -0,0 +1,536 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+MaskFormer Training Script.
+
+This script is a simplified version of the training script in detectron2/tools.
+"""
+import copy
+import itertools
+import logging
+import os
+from collections import OrderedDict
+from typing import Any, Dict, List, Set
+
+import torch
+
+import detectron2.utils.comm as comm
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import get_cfg
+from detectron2.data import MetadataCatalog, build_detection_train_loader
+from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
+from detectron2.evaluation import CityscapesInstanceEvaluator, CityscapesSemSegEvaluator, \
+ COCOEvaluator, COCOPanopticEvaluator, DatasetEvaluators, SemSegEvaluator, verify_results, \
+ DatasetEvaluator
+
+from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
+from detectron2.solver.build import maybe_add_gradient_clipping
+from detectron2.utils.logger import setup_logger
+
+from detectron2.utils.file_io import PathManager
+import numpy as np
+from PIL import Image
+import glob
+
+import pycocotools.mask as mask_util
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.utils.comm import all_gather, is_main_process, synchronize
+import json
+from torch.nn.parallel import DistributedDataParallel
+from detectron2.engine.train_loop import AMPTrainer, SimpleTrainer, TrainerBase, HookBase
+import weakref
+from detectron2.utils.events import EventStorage
+from detectron2.utils.logger import _log_api_usage
+
+# from detectron2.evaluation import SemSegGzeroEvaluator
+# from mask_former.evaluation.sem_seg_evaluation_gzero import SemSegGzeroEvaluator
+
+class SemSegGzeroEvaluator(DatasetEvaluator):
+ """
+ Evaluate semantic segmentation metrics.
+ """
+
+ def __init__(
+ self, dataset_name, distributed, output_dir=None, *, num_classes=None, ignore_label=None
+ ):
+ """
+ Args:
+ dataset_name (str): name of the dataset to be evaluated.
+ distributed (True): if True, will collect results from all ranks for evaluation.
+ Otherwise, will evaluate the results in the current process.
+ output_dir (str): an output directory to dump results.
+ num_classes, ignore_label: deprecated argument
+ """
+ self._logger = logging.getLogger(__name__)
+ if num_classes is not None:
+ self._logger.warn(
+ "SemSegEvaluator(num_classes) is deprecated! It should be obtained from metadata."
+ )
+ if ignore_label is not None:
+ self._logger.warn(
+ "SemSegEvaluator(ignore_label) is deprecated! It should be obtained from metadata."
+ )
+ self._dataset_name = dataset_name
+ self._distributed = distributed
+ self._output_dir = output_dir
+
+ self._cpu_device = torch.device("cpu")
+
+ self.input_file_to_gt_file = {
+ dataset_record["file_name"]: dataset_record["sem_seg_file_name"]
+ for dataset_record in DatasetCatalog.get(dataset_name)
+ }
+
+ meta = MetadataCatalog.get(dataset_name)
+ # Dict that maps contiguous training ids to COCO category ids
+ try:
+ c2d = meta.stuff_dataset_id_to_contiguous_id
+ self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()}
+ except AttributeError:
+ self._contiguous_id_to_dataset_id = None
+ self._class_names = meta.stuff_classes
+ self._val_extra_classes = meta.val_extra_classes
+ self._num_classes = len(meta.stuff_classes)
+ if num_classes is not None:
+ assert self._num_classes == num_classes, f"{self._num_classes} != {num_classes}"
+ self._ignore_label = ignore_label if ignore_label is not None else meta.ignore_label
+
+ def reset(self):
+ self._conf_matrix = np.zeros((self._num_classes + 1, self._num_classes + 1), dtype=np.int64)
+ self._predictions = []
+
+ def process(self, inputs, outputs):
+ """
+ Args:
+ inputs: the inputs to a model.
+ It is a list of dicts. Each dict corresponds to an image and
+ contains keys like "height", "width", "file_name".
+ outputs: the outputs of a model. It is either list of semantic segmentation predictions
+ (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
+ segmentation prediction in the same format.
+ """
+ for input, output in zip(inputs, outputs):
+ output = output["sem_seg"].argmax(dim=0).to(self._cpu_device)
+ pred = np.array(output, dtype=np.int)
+ with PathManager.open(self.input_file_to_gt_file[input["file_name"]], "rb") as f:
+ gt = np.array(Image.open(f), dtype=np.int)
+
+ gt[gt == self._ignore_label] = self._num_classes
+
+ self._conf_matrix += np.bincount(
+ (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
+ minlength=self._conf_matrix.size,
+ ).reshape(self._conf_matrix.shape)
+
+ self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))
+
+ def evaluate(self):
+ """
+ Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):
+
+ * Mean intersection-over-union averaged across classes (mIoU)
+ * Frequency Weighted IoU (fwIoU)
+ * Mean pixel accuracy averaged across classes (mACC)
+ * Pixel Accuracy (pACC)
+ """
+ if self._distributed:
+ synchronize()
+ conf_matrix_list = all_gather(self._conf_matrix)
+ self._predictions = all_gather(self._predictions)
+ self._predictions = list(itertools.chain(*self._predictions))
+ if not is_main_process():
+ return
+
+ self._conf_matrix = np.zeros_like(self._conf_matrix)
+ for conf_matrix in conf_matrix_list:
+ self._conf_matrix += conf_matrix
+
+ if self._output_dir:
+ PathManager.mkdirs(self._output_dir)
+ file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
+ with PathManager.open(file_path, "w") as f:
+ f.write(json.dumps(self._predictions))
+
+ acc = np.full(self._num_classes, np.nan, dtype=np.float)
+ iou = np.full(self._num_classes, np.nan, dtype=np.float)
+ tp = self._conf_matrix.diagonal()[:-1].astype(np.float)
+ pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float)
+ class_weights = pos_gt / np.sum(pos_gt)
+ pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float)
+ acc_valid = pos_gt > 0
+ acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
+ iou_valid = (pos_gt + pos_pred) > 0
+ union = pos_gt + pos_pred - tp
+ iou[acc_valid] = tp[acc_valid] / union[acc_valid]
+ macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
+ miou = np.sum(iou[acc_valid]) / np.sum(iou_valid)
+ fiou = np.sum(iou[acc_valid] * class_weights[acc_valid])
+ pacc = np.sum(tp) / np.sum(pos_gt)
+ seen_IoU = 0
+ unseen_IoU = 0
+ seen_acc = 0
+ unseen_acc = 0
+ res = {}
+ res["mIoU"] = 100 * miou
+ res["fwIoU"] = 100 * fiou
+ for i, name in enumerate(self._class_names):
+ res["IoU-{}".format(name)] = 100 * iou[i]
+ if name in self._val_extra_classes:
+ unseen_IoU = unseen_IoU + 100 * iou[i]
+ else:
+ seen_IoU = seen_IoU + 100 * iou[i]
+ unseen_IoU = unseen_IoU / len(self._val_extra_classes)
+ seen_IoU = seen_IoU / (self._num_classes - len(self._val_extra_classes))
+ res["mACC"] = 100 * macc
+ res["pACC"] = 100 * pacc
+ for i, name in enumerate(self._class_names):
+ res["ACC-{}".format(name)] = 100 * acc[i]
+ if name in self._val_extra_classes:
+ unseen_acc = unseen_acc + 100 * iou[i]
+ else:
+ seen_acc = seen_acc + 100 * iou[i]
+ unseen_acc = unseen_acc / len(self._val_extra_classes)
+ seen_acc = seen_acc / (self._num_classes - len(self._val_extra_classes))
+ res["seen_IoU"] = seen_IoU
+ res["unseen_IoU"] = unseen_IoU
+ res["harmonic mean"] = 2 * seen_IoU * unseen_IoU / (seen_IoU + unseen_IoU)
+ # res["unseen_acc"] = unseen_acc
+ # res["seen_acc"] = seen_acc
+ if self._output_dir:
+ file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
+ with PathManager.open(file_path, "wb") as f:
+ torch.save(res, f)
+ results = OrderedDict({"sem_seg": res})
+ self._logger.info(results)
+ return results
+
+ def encode_json_sem_seg(self, sem_seg, input_file_name):
+ """
+ Convert semantic segmentation to COCO stuff format with segments encoded as RLEs.
+ See http://cocodataset.org/#format-results
+ """
+ json_list = []
+ for label in np.unique(sem_seg):
+ if self._contiguous_id_to_dataset_id is not None:
+ # import ipdb; ipdb.set_trace()
+ assert (
+ label in self._contiguous_id_to_dataset_id
+ ), "Label {} is not in the metadata info for {}".format(label, self._dataset_name)
+ dataset_id = self._contiguous_id_to_dataset_id[label]
+ else:
+ dataset_id = int(label)
+ mask = (sem_seg == label).astype(np.uint8)
+ mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0]
+ mask_rle["counts"] = mask_rle["counts"].decode("utf-8")
+ json_list.append(
+ {"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle}
+ )
+ return json_list
+
+
+# MaskFormer
+from cat_seg import (
+ DETRPanopticDatasetMapper,
+ MaskFormerPanopticDatasetMapper,
+ MaskFormerSemanticDatasetMapper,
+ SemanticSegmentorWithTTA,
+ add_mask_former_config,
+)
+
+
+def create_ddp_model(model, *, fp16_compression=False, **kwargs):
+ """
+ Create a DistributedDataParallel model if there are >1 processes.
+
+ Args:
+ model: a torch.nn.Module
+ fp16_compression: add fp16 compression hooks to the ddp object.
+ See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
+ kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
+ """ # noqa
+ if comm.get_world_size() == 1:
+ return model
+ if "device_ids" not in kwargs:
+ kwargs["device_ids"] = [comm.get_local_rank()]
+ ddp = DistributedDataParallel(model, **kwargs)
+ if fp16_compression:
+ from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
+
+ ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
+ return ddp
+
+class Trainer(DefaultTrainer):
+ """
+ Extension of the Trainer class adapted to DETR.
+ """
+
+ def __init__(self, cfg):
+ # super().__init__(cfg)
+ self._hooks: List[HookBase] = []
+ self.iter: int = 0
+ self.start_iter: int = 0
+ self.max_iter: int
+ self.storage: EventStorage
+ _log_api_usage("trainer." + self.__class__.__name__)
+
+ logger = logging.getLogger("detectron2")
+ if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
+ setup_logger()
+ cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
+
+ # Assume these objects must be constructed in this order.
+ model = self.build_model(cfg)
+ optimizer = self.build_optimizer(cfg, model)
+ data_loader = self.build_train_loader(cfg)
+
+ model = create_ddp_model(model, broadcast_buffers=False, find_unused_parameters=True)
+ self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
+ model, data_loader, optimizer
+ )
+
+ self.scheduler = self.build_lr_scheduler(cfg, optimizer)
+ self.checkpointer = DetectionCheckpointer(
+ # Assume you want to save checkpoints together with logs/statistics
+ model,
+ cfg.OUTPUT_DIR,
+ trainer=weakref.proxy(self),
+ )
+ self.start_iter = 0
+ self.max_iter = cfg.SOLVER.MAX_ITER
+ self.cfg = cfg
+
+ self.register_hooks(self.build_hooks())
+
+ @classmethod
+ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
+ """
+ Create evaluator(s) for a given dataset.
+ This uses the special metadata "evaluator_type" associated with each
+ builtin dataset. For your own dataset, you can simply create an
+ evaluator manually in your script and do not have to worry about the
+ hacky if-else logic here.
+ """
+ if output_folder is None:
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
+ evaluator_list = []
+ evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
+ if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
+ evaluator_list.append(
+ SemSegEvaluator(
+ dataset_name,
+ distributed=True,
+ output_dir=output_folder,
+ )
+ )
+ # import pdb; pdb.set_trace()
+ if evaluator_type == "sem_seg_gzero":
+
+ evaluator_list.append(
+ SemSegGzeroEvaluator(
+ dataset_name,
+ distributed=True,
+ output_dir=output_folder,
+ )
+ )
+ if evaluator_type == "coco":
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
+ if evaluator_type in [
+ "coco_panoptic_seg",
+ "ade20k_panoptic_seg",
+ "cityscapes_panoptic_seg",
+ ]:
+ evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
+ if evaluator_type == "cityscapes_instance":
+ assert (
+ torch.cuda.device_count() >= comm.get_rank()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ return CityscapesInstanceEvaluator(dataset_name)
+ if evaluator_type == "cityscapes_sem_seg":
+ assert (
+ torch.cuda.device_count() >= comm.get_rank()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ return CityscapesSemSegEvaluator(dataset_name)
+ if evaluator_type == "cityscapes_panoptic_seg":
+ assert (
+ torch.cuda.device_count() >= comm.get_rank()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
+ if len(evaluator_list) == 0:
+ raise NotImplementedError(
+ "no Evaluator for the dataset {} with the type {}".format(
+ dataset_name, evaluator_type
+ )
+ )
+ elif len(evaluator_list) == 1:
+ return evaluator_list[0]
+ return DatasetEvaluators(evaluator_list)
+
+ @classmethod
+ def build_train_loader(cls, cfg):
+ # Semantic segmentation dataset mapper
+ if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic":
+ mapper = MaskFormerSemanticDatasetMapper(cfg, True)
+ # Panoptic segmentation dataset mapper
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_panoptic":
+ mapper = MaskFormerPanopticDatasetMapper(cfg, True)
+ # DETR-style dataset mapper for COCO panoptic segmentation
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "detr_panoptic":
+ mapper = DETRPanopticDatasetMapper(cfg, True)
+ else:
+ mapper = None
+ return build_detection_train_loader(cfg, mapper=mapper)
+
+ @classmethod
+ def build_lr_scheduler(cls, cfg, optimizer):
+ """
+ It now calls :func:`detectron2.solver.build_lr_scheduler`.
+ Overwrite it if you'd like a different scheduler.
+ """
+ return build_lr_scheduler(cfg, optimizer)
+
+ @classmethod
+ def build_optimizer(cls, cfg, model):
+ weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
+ weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED
+
+ defaults = {}
+ defaults["lr"] = cfg.SOLVER.BASE_LR
+ defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY
+
+ norm_module_types = (
+ torch.nn.BatchNorm1d,
+ torch.nn.BatchNorm2d,
+ torch.nn.BatchNorm3d,
+ torch.nn.SyncBatchNorm,
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
+ torch.nn.GroupNorm,
+ torch.nn.InstanceNorm1d,
+ torch.nn.InstanceNorm2d,
+ torch.nn.InstanceNorm3d,
+ torch.nn.LayerNorm,
+ torch.nn.LocalResponseNorm,
+ )
+
+ params: List[Dict[str, Any]] = []
+ memo: Set[torch.nn.parameter.Parameter] = set()
+ for module_name, module in model.named_modules():
+ for module_param_name, value in module.named_parameters(recurse=False):
+ if not value.requires_grad:
+ continue
+ # Avoid duplicating parameters
+ if value in memo:
+ continue
+ memo.add(value)
+
+ hyperparams = copy.copy(defaults)
+ if "backbone" in module_name:
+ hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
+ if (
+ "relative_position_bias_table" in module_param_name
+ or "absolute_pos_embed" in module_param_name
+ ):
+ print(module_param_name)
+ hyperparams["weight_decay"] = 0.0
+ if isinstance(module, norm_module_types):
+ hyperparams["weight_decay"] = weight_decay_norm
+ if isinstance(module, torch.nn.Embedding):
+ hyperparams["weight_decay"] = weight_decay_embed
+ params.append({"params": [value], **hyperparams})
+
+ def maybe_add_full_model_gradient_clipping(optim):
+ # detectron2 doesn't have full model gradient clipping now
+ clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
+ enable = (
+ cfg.SOLVER.CLIP_GRADIENTS.ENABLED
+ and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
+ and clip_norm_val > 0.0
+ )
+
+ class FullModelGradientClippingOptimizer(optim):
+ def step(self, closure=None):
+ all_params = itertools.chain(*[x["params"] for x in self.param_groups])
+ torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
+ super().step(closure=closure)
+
+ return FullModelGradientClippingOptimizer if enable else optim
+
+ optimizer_type = cfg.SOLVER.OPTIMIZER
+ if optimizer_type == "SGD":
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
+ params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
+ )
+ elif optimizer_type == "ADAMW":
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
+ params, cfg.SOLVER.BASE_LR
+ )
+ else:
+ raise NotImplementedError(f"no optimizer type {optimizer_type}")
+ if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
+ optimizer = maybe_add_gradient_clipping(cfg, optimizer)
+ return optimizer
+
+ @classmethod
+ def test_with_TTA(cls, cfg, model):
+ logger = logging.getLogger("detectron2.trainer")
+ # In the end of training, run an evaluation with TTA.
+ logger.info("Running inference with test-time augmentation ...")
+ model = SemanticSegmentorWithTTA(cfg, model)
+ evaluators = [
+ cls.build_evaluator(
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
+ )
+ for name in cfg.DATASETS.TEST
+ ]
+ res = cls.test(cfg, model, evaluators)
+ res = OrderedDict({k + "_TTA": v for k, v in res.items()})
+ return res
+
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg()
+ # for poly lr schedule
+ add_deeplab_config(cfg)
+ add_mask_former_config(cfg)
+ cfg.merge_from_file(args.config_file)
+ cfg.merge_from_list(args.opts)
+ cfg.freeze()
+ default_setup(cfg, args)
+ # Setup logger for "mask_former" module
+ setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="mask_former")
+ return cfg
+
+
+def main(args):
+ cfg = setup(args)
+
+ if args.eval_only:
+ model = Trainer.build_model(cfg)
+ DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
+ cfg.MODEL.WEIGHTS, resume=args.resume
+ )
+ res = Trainer.test(cfg, model)
+ if cfg.TEST.AUG.ENABLED:
+ res.update(Trainer.test_with_TTA(cfg, model))
+ if comm.is_main_process():
+ verify_results(cfg, res)
+ return res
+
+ trainer = Trainer(cfg)
+ trainer.resume_or_load(resume=args.resume)
+ return trainer.train()
+
+
+if __name__ == "__main__":
+ args = default_argument_parser().parse_args()
+ print("Command Line Args:", args)
+ launch(
+ main,
+ args.num_gpus,
+ num_machines=args.num_machines,
+ machine_rank=args.machine_rank,
+ dist_url=args.dist_url,
+ args=(args,),
+ )
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4f7b181f6e11ca4646934c2a812611f0561a6232
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+scipy==1.7.0
+ftfy==6.0.1
+opencv-python==4.5.1.48
+setuptools==59.5.0
+pillow==8.2.0
+imageio==2.4.1
+timm==0.8.3.dev0
+regex
+einops
+torch==1.13.1+cu117
+torchvision==0.14.1+cu117
+torchaudio==0.13.1
+gradio
+--extra-index-url https://download.pytorch.org/whl/cu117
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9dc719114b97b429c4eb7713a19a7c939aca0333
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,29 @@
+#!/bin/sh
+
+gpus=4
+config=$1
+output=$2
+
+if [ -z $config ]
+then
+ echo "No config file found! Run with "sh run.sh [CONFIG_FILE] [OUTPUT_DIR] [OPTS]""
+ exit 0
+fi
+
+if [ -z $output ]
+then
+ echo "No output directory found! Run with "sh run.sh [CONFIG_FILE] [OUTPUT_DIR] [OPTS]""
+ exit 0
+fi
+
+shift 2
+opts=${@}
+
+python train_net.py --config $config \
+ --num-gpus $gpus \
+ --dist-url "auto" \
+ --resume \
+ OUTPUT_DIR $output \
+ $opts
+
+sh eval.sh $config $output $opts
\ No newline at end of file
diff --git a/sam_vit_h_4b8939.pth b/sam_vit_h_4b8939.pth
new file mode 100644
index 0000000000000000000000000000000000000000..8523acce9ddab1cf7e355628a08b1aab8ce08a72
--- /dev/null
+++ b/sam_vit_h_4b8939.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
diff --git a/script.sh b/script.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0887030933c04cd8754cc10fe3bfcfbe274202df
--- /dev/null
+++ b/script.sh
@@ -0,0 +1,3 @@
+#!/bin/sh
+sh eval.sh configs/catsam_ade150.yaml 1 output/iou_focal MODEL.WEIGHTS model_final_cls.pth
+sh eval.sh configs/catsam_ade150.yaml 1 output/iou_focal MODEL.WEIGHTS model_final_cls.pth
\ No newline at end of file
diff --git a/train_net.py b/train_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..cab55aad19003957fd12b5a650f75d2f360a2d36
--- /dev/null
+++ b/train_net.py
@@ -0,0 +1,324 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+MaskFormer Training Script.
+
+This script is a simplified version of the training script in detectron2/tools.
+"""
+import copy
+import itertools
+import logging
+import os
+from collections import OrderedDict
+from typing import Any, Dict, List, Set
+
+import torch
+
+import detectron2.utils.comm as comm
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import get_cfg
+from detectron2.data import MetadataCatalog, build_detection_train_loader
+from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
+from detectron2.evaluation import CityscapesInstanceEvaluator, CityscapesSemSegEvaluator, \
+ COCOEvaluator, COCOPanopticEvaluator, DatasetEvaluators, SemSegEvaluator, verify_results, \
+ DatasetEvaluator
+
+from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
+from detectron2.solver.build import maybe_add_gradient_clipping
+from detectron2.utils.logger import setup_logger
+
+from detectron2.utils.file_io import PathManager
+import numpy as np
+from PIL import Image
+import glob
+
+import pycocotools.mask as mask_util
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.utils.comm import all_gather, is_main_process, synchronize
+import json
+
+# from detectron2.evaluation import SemSegGzeroEvaluator
+# from mask_former.evaluation.sem_seg_evaluation_gzero import SemSegGzeroEvaluator
+
+class VOCbEvaluator(SemSegEvaluator):
+ """
+ Evaluate semantic segmentation metrics.
+ """
+ def process(self, inputs, outputs):
+ """
+ Args:
+ inputs: the inputs to a model.
+ It is a list of dicts. Each dict corresponds to an image and
+ contains keys like "height", "width", "file_name".
+ outputs: the outputs of a model. It is either list of semantic segmentation predictions
+ (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
+ segmentation prediction in the same format.
+ """
+ for input, output in zip(inputs, outputs):
+ output = output["sem_seg"].argmax(dim=0).to(self._cpu_device)
+ pred = np.array(output, dtype=np.int)
+ pred[pred >= 20] = 20
+ with PathManager.open(self.input_file_to_gt_file[input["file_name"]], "rb") as f:
+ gt = np.array(Image.open(f), dtype=np.int)
+
+ gt[gt == self._ignore_label] = self._num_classes
+
+ self._conf_matrix += np.bincount(
+ (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
+ minlength=self._conf_matrix.size,
+ ).reshape(self._conf_matrix.shape)
+
+ self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))
+
+# MaskFormer
+from cat_seg import (
+ DETRPanopticDatasetMapper,
+ MaskFormerPanopticDatasetMapper,
+ MaskFormerSemanticDatasetMapper,
+ SemanticSegmentorWithTTA,
+ add_cat_seg_config,
+)
+
+
+class Trainer(DefaultTrainer):
+ """
+ Extension of the Trainer class adapted to DETR.
+ """
+
+ @classmethod
+ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
+ """
+ Create evaluator(s) for a given dataset.
+ This uses the special metadata "evaluator_type" associated with each
+ builtin dataset. For your own dataset, you can simply create an
+ evaluator manually in your script and do not have to worry about the
+ hacky if-else logic here.
+ """
+ if output_folder is None:
+ output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
+ evaluator_list = []
+ evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
+ if evaluator_type in ["sem_seg", "ade20k_panoptic_seg"]:
+ evaluator_list.append(
+ SemSegEvaluator(
+ dataset_name,
+ distributed=True,
+ output_dir=output_folder,
+ )
+ )
+
+ if evaluator_type == "sem_seg_background":
+ evaluator_list.append(
+ VOCbEvaluator(
+ dataset_name,
+ distributed=True,
+ output_dir=output_folder,
+ )
+ )
+ if evaluator_type == "coco":
+ evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
+ if evaluator_type in [
+ "coco_panoptic_seg",
+ "ade20k_panoptic_seg",
+ "cityscapes_panoptic_seg",
+ ]:
+ evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
+ if evaluator_type == "cityscapes_instance":
+ assert (
+ torch.cuda.device_count() >= comm.get_rank()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ return CityscapesInstanceEvaluator(dataset_name)
+ if evaluator_type == "cityscapes_sem_seg":
+ assert (
+ torch.cuda.device_count() >= comm.get_rank()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ return CityscapesSemSegEvaluator(dataset_name)
+ if evaluator_type == "cityscapes_panoptic_seg":
+ assert (
+ torch.cuda.device_count() >= comm.get_rank()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ evaluator_list.append(CityscapesSemSegEvaluator(dataset_name))
+ if len(evaluator_list) == 0:
+ raise NotImplementedError(
+ "no Evaluator for the dataset {} with the type {}".format(
+ dataset_name, evaluator_type
+ )
+ )
+ elif len(evaluator_list) == 1:
+ return evaluator_list[0]
+ return DatasetEvaluators(evaluator_list)
+
+ @classmethod
+ def build_train_loader(cls, cfg):
+ # Semantic segmentation dataset mapper
+ if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic":
+ mapper = MaskFormerSemanticDatasetMapper(cfg, True)
+ # Panoptic segmentation dataset mapper
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_panoptic":
+ mapper = MaskFormerPanopticDatasetMapper(cfg, True)
+ # DETR-style dataset mapper for COCO panoptic segmentation
+ elif cfg.INPUT.DATASET_MAPPER_NAME == "detr_panoptic":
+ mapper = DETRPanopticDatasetMapper(cfg, True)
+ else:
+ mapper = None
+ return build_detection_train_loader(cfg, mapper=mapper)
+
+ @classmethod
+ def build_lr_scheduler(cls, cfg, optimizer):
+ """
+ It now calls :func:`detectron2.solver.build_lr_scheduler`.
+ Overwrite it if you'd like a different scheduler.
+ """
+ return build_lr_scheduler(cfg, optimizer)
+
+ @classmethod
+ def build_optimizer(cls, cfg, model):
+ weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
+ weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED
+
+ defaults = {}
+ defaults["lr"] = cfg.SOLVER.BASE_LR
+ defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY
+
+ norm_module_types = (
+ torch.nn.BatchNorm1d,
+ torch.nn.BatchNorm2d,
+ torch.nn.BatchNorm3d,
+ torch.nn.SyncBatchNorm,
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
+ torch.nn.GroupNorm,
+ torch.nn.InstanceNorm1d,
+ torch.nn.InstanceNorm2d,
+ torch.nn.InstanceNorm3d,
+ torch.nn.LayerNorm,
+ torch.nn.LocalResponseNorm,
+ )
+
+ params: List[Dict[str, Any]] = []
+ memo: Set[torch.nn.parameter.Parameter] = set()
+ # import ipdb;
+ # ipdb.set_trace()
+ for module_name, module in model.named_modules():
+ for module_param_name, value in module.named_parameters(recurse=False):
+ if not value.requires_grad:
+ continue
+ # Avoid duplicating parameters
+ if value in memo:
+ continue
+ memo.add(value)
+ hyperparams = copy.copy(defaults)
+ if "backbone" in module_name:
+ hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
+ if "clip_model" in module_name:
+ hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.CLIP_MULTIPLIER
+ # for deformable detr
+
+ if (
+ "relative_position_bias_table" in module_param_name
+ or "absolute_pos_embed" in module_param_name
+ ):
+ print(module_param_name)
+ hyperparams["weight_decay"] = 0.0
+ if isinstance(module, norm_module_types):
+ hyperparams["weight_decay"] = weight_decay_norm
+ if isinstance(module, torch.nn.Embedding):
+ hyperparams["weight_decay"] = weight_decay_embed
+ params.append({"params": [value], **hyperparams})
+
+ def maybe_add_full_model_gradient_clipping(optim):
+ # detectron2 doesn't have full model gradient clipping now
+ clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
+ enable = (
+ cfg.SOLVER.CLIP_GRADIENTS.ENABLED
+ and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
+ and clip_norm_val > 0.0
+ )
+
+ class FullModelGradientClippingOptimizer(optim):
+ def step(self, closure=None):
+ all_params = itertools.chain(*[x["params"] for x in self.param_groups])
+ torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
+ super().step(closure=closure)
+
+ return FullModelGradientClippingOptimizer if enable else optim
+
+ optimizer_type = cfg.SOLVER.OPTIMIZER
+ if optimizer_type == "SGD":
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
+ params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
+ )
+ elif optimizer_type == "ADAMW":
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
+ params, cfg.SOLVER.BASE_LR
+ )
+ else:
+ raise NotImplementedError(f"no optimizer type {optimizer_type}")
+ if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
+ optimizer = maybe_add_gradient_clipping(cfg, optimizer)
+ return optimizer
+
+ @classmethod
+ def test_with_TTA(cls, cfg, model):
+ logger = logging.getLogger("detectron2.trainer")
+ # In the end of training, run an evaluation with TTA.
+ logger.info("Running inference with test-time augmentation ...")
+ model = SemanticSegmentorWithTTA(cfg, model)
+ evaluators = [
+ cls.build_evaluator(
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
+ )
+ for name in cfg.DATASETS.TEST
+ ]
+ res = cls.test(cfg, model, evaluators)
+ res = OrderedDict({k + "_TTA": v for k, v in res.items()})
+ return res
+
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg()
+ # for poly lr schedule
+ add_deeplab_config(cfg)
+ add_cat_seg_config(cfg)
+ cfg.merge_from_file(args.config_file)
+ cfg.merge_from_list(args.opts)
+ cfg.freeze()
+ default_setup(cfg, args)
+ # Setup logger for "mask_former" module
+ setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="mask_former")
+ return cfg
+
+
+def main(args):
+ cfg = setup(args)
+ torch.set_float32_matmul_precision("high")
+ if args.eval_only:
+ model = Trainer.build_model(cfg)
+ DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
+ cfg.MODEL.WEIGHTS, resume=args.resume
+ )
+ res = Trainer.test(cfg, model)
+ if cfg.TEST.AUG.ENABLED:
+ res.update(Trainer.test_with_TTA(cfg, model))
+ if comm.is_main_process():
+ verify_results(cfg, res)
+ return res
+
+ trainer = Trainer(cfg)
+ trainer.resume_or_load(resume=args.resume)
+ return trainer.train()
+
+
+if __name__ == "__main__":
+ args = default_argument_parser().parse_args()
+ print("Command Line Args:", args)
+ launch(
+ main,
+ args.num_gpus,
+ num_machines=args.num_machines,
+ machine_rank=args.machine_rank,
+ dist_url=args.dist_url,
+ args=(args,),
+ )