diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..ecc0896bf1784f869b355fae4b9eab26550a4fef
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 OpenMMLab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/annotator/canny/__init__.py b/annotator/canny/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e651dd69088cb1c68f6039804227d5fb8fee9327
--- /dev/null
+++ b/annotator/canny/__init__.py
@@ -0,0 +1,6 @@
+import cv2
+
+
+class CannyDetector:
+    def __call__(self, img, low_threshold=100, high_threshold=200):
+        return cv2.Canny(img, low_threshold, high_threshold)
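+
+# Illustrative usage (a sketch, assuming `img` is an 8-bit numpy array such as
+# one loaded with cv2.imread):
+#
+#   detector = CannyDetector()
+#   edges = detector(img, low_threshold=100, high_threshold=200)  # uint8 edge map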
diff --git a/annotator/cielab/__init__.py b/annotator/cielab/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e5016db2b6c0e248af810c4123a6eaed7b29ca2
--- /dev/null
+++ b/annotator/cielab/__init__.py
@@ -0,0 +1,47 @@
+import os
+import sys
+sys.path.append(os.getcwd())
+sys.path.append(os.path.join(os.getcwd(), 'rayleigh'))
+
+import numpy as np
+from skimage.color import rgb2lab
+from .rayleigh import Palette
+from .rayleigh.util import histogram_colors_strict, smooth_histogram, color_hist_to_palette_image
+
+
+class CIELabDetector:
+
+    MAX_DIMENSION = 240 + 1
+
+    def __init__(self, sigma=10, num_hues=11, num_light=5, num_sat=5):
+        self.sigma = sigma
+        self.palette = Palette(num_hues=num_hues, light_range=num_light, sat_range=num_sat)
+
+    def __call__(self, img):
+        # Handle grayscale and RGBA images.
+        # TODO: Should be smarter here in the future, but for now simply remove
+        # the alpha channel if present.
+        if img.ndim == 2:
+            img = np.tile(img[:, :, np.newaxis], (1, 1, 3))
+        # An RGBA image still has ndim == 3 (shape (H, W, 4)), so simply drop
+        # any channels beyond the first three.
+        img = img[:, :, :3]
+
+        h, w, d = tuple(img.shape)
+        h_stride = int(h / self.MAX_DIMENSION + 1)
+        w_stride = int(w / self.MAX_DIMENSION + 1)
+        img = img[::h_stride, ::w_stride, :]
+
+        # Convert to L*a*b colors.
+        h, w, d = img.shape
+        lab_array = rgb2lab(img).reshape((h * w, d))
+
+        # compute hist
+        hist = histogram_colors_strict(lab_array, self.palette)
+        hist = smooth_histogram(hist, self.palette, self.sigma)
+        return hist
+
+    def hist_to_palette(self, hist):
+        # hist to image
+        plt = color_hist_to_palette_image(hist, self.palette)
+        return (plt * 255).astype(np.uint8)
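+
+
+# Illustrative usage (a sketch; assumes `img` is an RGB uint8 array, since the
+# histogram is computed via skimage's rgb2lab):
+#
+#   detector = CIELabDetector(sigma=10, num_hues=11)
+#   hist = detector(img)                           # (K,) smoothed palette histogram
+#   palette_img = detector.hist_to_palette(hist)   # uint8 strip of the dominant colors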
diff --git a/annotator/cielab/rayleigh/__init__.py b/annotator/cielab/rayleigh/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..677c1aa575a10ea34c97f182da610999851a31ea
--- /dev/null
+++ b/annotator/cielab/rayleigh/__init__.py
@@ -0,0 +1,8 @@
+"""
+Rayleigh is an open-source system for quickly searching medium-sized image
+collections by multiple colors given as a palette or derived from a query image.
+"""
+
+
+from .palette import Palette
+from .util import *
diff --git a/annotator/cielab/rayleigh/palette.py b/annotator/cielab/rayleigh/palette.py
new file mode 100644
index 0000000000000000000000000000000000000000..059f8e43b3b8b1d3964576a1562774de1555900a
--- /dev/null
+++ b/annotator/cielab/rayleigh/palette.py
@@ -0,0 +1,132 @@
+"""
+Encapsulate the list of hex colors and array of Lab values representations
+of a palette (codebook) of colors.
+
+Provide methods to work with color conversion and the Palette class.
+
+Provide a parametrized method to generate a palette that covers the range
+of colors.
+"""
+
+import os
+import numpy as np
+from skimage.color import hsv2rgb, rgb2lab
+from skimage.io import imsave
+from sklearn.metrics import euclidean_distances
+
+from .util import rgb2hex
+
+
+class Palette(object):
+    """
+    Create a color palette (codebook) in the form of a 2D grid of colors,
+    as described in the parameters list below.
+    Further, the rightmost column has num_hues gradations from black to white.
+
+    Parameters
+    ----------
+    num_hues : int
+        number of colors with full lightness and saturation, in the middle
+    sat_range : int
+        number of rows above middle row that show
+        the same hues with decreasing saturation.
+    light_range : int
+        number of rows below middle row that show
+        the same hues with decreasing lightness.
+
+    Returns
+    -------
+    palette: rayleigh.Palette
+    """
+
+    def __init__(self, num_hues=8, sat_range=2, light_range=2):
+        height = 1 + sat_range + (2 * light_range - 1)
+        # generate num_hues+1 hues, but don't take the last one:
+        # hues are on a circle, and we would be oversampling the origin
+        hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (height, 1))
+        if num_hues == 8:
+            hues = np.tile(np.array(
+                [0.,  0.10,  0.15,  0.28, 0.51, 0.58, 0.77,  0.85]), (height, 1))
+        elif num_hues == 9:
+            hues = np.tile(np.array(
+                [0.,  0.10,  0.15,  0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (height, 1))
+        elif num_hues == 10:
+            hues = np.tile(np.array(
+                [0.,  0.10,  0.15,  0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (height, 1))
+        elif num_hues == 11:
+            hues = np.tile(np.array(
+                [0.0, 0.0833, 0.166, 0.25,
+                 0.333, 0.5, 0.56333,
+                 0.666, 0.73, 0.803,
+                 0.916]), (height, 1))
+        
+        sats = np.hstack((
+            np.linspace(0, 1, sat_range + 2)[1:-1],
+            1,
+            [1] * (light_range),
+            [.4] * (light_range - 1),
+        ))
+        lights = np.hstack((
+            [1] * sat_range,
+            1,
+            np.linspace(1, 0.2, light_range + 2)[1:-1],
+            np.linspace(1, 0.2, light_range + 2)[1:-2],
+        ))
+
+        sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))
+        lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))
+        colors = hsv2rgb(np.dstack((hues, sats, lights)))
+        grays = np.tile(
+            np.linspace(1, 0, height)[:, np.newaxis, np.newaxis], (1, 1, 3))
+
+        self.rgb_image = np.hstack((colors, grays))
+
+        # Make a nice histogram ordering of the hues and grays
+        h, w, d = colors.shape
+        color_array = colors.T.reshape((d, w * h)).T
+        h, w, d = grays.shape
+        gray_array = grays.T.reshape((d, w * h)).T
+        
+        self.rgb_array = np.vstack((color_array, gray_array))
+        self.lab_array = rgb2lab(self.rgb_array[None, :, :]).squeeze()
+        self.hex_list = [rgb2hex(row) for row in self.rgb_array]
+        #assert(np.all(self.rgb_array == self.rgb_array[None, :, :].squeeze()))
+
+        self.distances = euclidean_distances(self.lab_array, squared=True)
+
+    def output(self, dirname, html=False):
+        """
+        Output an image of the palette and, optionally, an HTML color
+        picker for it.
+
+        Parameters
+        ----------
+        dirname : string
+            directory for the files to be output
+        """
+        def get_palette_html():
+            """
+            Return HTML for a color picker using the given palette.
+            """
+            html = """
+            <style>
+                span {
+                    width: 20px;
+                    height: 20px;
+                    margin: 2px;
+                    padding: 0px;
+                    display: inline-block;
+                }
+            </style>
+            """
+            for row in self.rgb_image:
+                for rgb_color in row:
+                    s = '<a id="{0}"><span style="background-color: {0}" /></a>\n'
+                    html += s.format(rgb2hex(rgb_color))
+                html += "<br />\n"
+            return html
+
+        imsave(os.path.join(dirname, 'palette.png'), (self.rgb_image*255).astype(np.uint8))
+        if html:
+            with open(os.path.join(dirname, 'palette.html'), 'w') as f:
+                f.write(get_palette_html())
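+
+
+# For reference, a sketch of the resulting geometry with the defaults above
+# (num_hues=8, sat_range=2, light_range=2): height = 1 + sat_range +
+# (2 * light_range - 1) = 6 rows and num_hues + 1 = 9 columns (the last column
+# being the gray ramp), so rgb_image has shape (6, 9, 3) and lab_array holds
+# 6 * 9 = 54 palette colors.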
diff --git a/annotator/cielab/rayleigh/util.py b/annotator/cielab/rayleigh/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c46bb33f1f77bbb1daa6c17888e20e0b2bfba01
--- /dev/null
+++ b/annotator/cielab/rayleigh/util.py
@@ -0,0 +1,270 @@
+import os
+import base64
+import numpy as np
+import tempfile
+import matplotlib.pyplot as plt
+from sklearn.metrics import euclidean_distances
+from skimage.io import imsave
+
+
+def rgb2hex(rgb_number):
+    """
+    Args:
+        - rgb_number (sequence of float)
+
+    Returns:
+        - hex_number (string)
+    """
+    return '#%02x%02x%02x' % tuple([int(np.round(val * 255)) for val in rgb_number])
+
+
+def hex2rgb(hexcolor_str):
+    """
+    Args:
+        - hexcolor_str (string): e.g. '#ffffff' or '33cc00'
+
+    Returns:
+        - rgb_color (sequence of floats): e.g. (0.2, 0.3, 0)
+    """
+    color = hexcolor_str.strip('#')
+    rgb = lambda x: round(int(x, 16) / 255., 5)
+    return (rgb(color[:2]), rgb(color[2:4]), rgb(color[4:6]))
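+
+# Illustrative round trip (sketch): rgb2hex((1.0, 0.8, 0.0)) == '#ffcc00' and
+# hex2rgb('#ffcc00') == (1.0, 0.8, 0.0); hex2rgb rounds to 5 decimal places.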
+
+
+def color_hist_to_palette_image(color_hist, palette, percentile=90,
+                                width=200, height=50, filename=None):
+    """
+    Output the main colors in the histogram to a "palette image."
+
+    Parameters
+    ----------
+    color_hist : (K,) ndarray
+    palette : rayleigh.Palette
+    percentile : int, optional:
+        Output only colors above this percentile of prevalence in the histogram.
+    filename : string, optional:
+        If given, save the resulting image to file.
+
+    Returns
+    -------
+    rgb_image : ndarray
+    """
+    ind = np.argsort(-color_hist)
+    ind = ind[color_hist[ind] > np.percentile(color_hist, percentile)]
+    hex_list = np.take(palette.hex_list, ind)
+    values = color_hist[ind]
+    rgb_image = palette_query_to_rgb_image(dict(zip(hex_list, values)),
+                                            width=width, height=height)
+    if filename:
+        imsave(filename, rgb_image)
+    return rgb_image
+
+
+def palette_query_to_rgb_image(palette_query, width=200, height=50):
+    """
+    Convert a list of hex colors and their values to an RGB image of given
+    width and height.
+
+    Args:
+        - palette_query (dict):
+            a dictionary of hex colors to unnormalized values,
+            e.g. {'#ffffff': 20, '#33cc00': 0.4}.
+    """
+    hex_list, values = zip(*palette_query.items())
+    values = np.array(values, dtype=float)
+    values /= values.sum()
+    nums = np.array(values * width, dtype=int)
+    rgb_arrays = (np.tile(np.array(hex2rgb(x)), (num, 1))
+                  for x, num in zip(hex_list, nums))
+    rgb_array = np.vstack(list(rgb_arrays))
+    rgb_image = rgb_array[np.newaxis, :, :]
+    rgb_image = np.tile(rgb_image, (height, 1, 1))
+    return rgb_image
+
+
+def plot_histogram(color_hist, palette, plot_filename=None):
+    """
+    Return Figure containing the color palette histogram.
+
+    Args:
+        - color_hist (K, ndarray)
+
+        - palette (Palette)
+
+        - plot_filename (string) [default=None]:
+                Save histogram to this file, if given.
+
+    Returns:
+        - fig (Figure)
+    """
+    fig = plt.figure(figsize=(5, 3), dpi=150)
+    ax = fig.add_subplot(111)
+    ax.bar(
+        range(len(color_hist)), color_hist,
+        color=palette.hex_list, edgecolor='black')
+    ax.set_ylim((0, 0.3))
+    ax.xaxis.set_ticks([])
+    ax.set_xlim((0, len(palette.hex_list)))
+    if plot_filename:
+        fig.savefig(plot_filename, dpi=150, facecolor='none')
+    return fig
+
+
+def output_histogram_base64(color_hist, palette):
+    """
+    Return base64-encoded image containing the color palette histogram.
+
+    Args:
+        - color_hist (K, ndarray)
+
+        - palette (Palette)
+
+    Returns:
+        - data_uri (base64 encoded string)
+    """
+    _, tfname = tempfile.mkstemp('.png')
+    plot_histogram(color_hist, palette, tfname)
+    with open(tfname, 'rb') as f:
+        data_uri = base64.b64encode(f.read()).decode('ascii')
+    os.remove(tfname)
+    return data_uri
+
+
+def histogram_colors_strict(lab_array, palette, plot_filename=None):
+    """
+    Return a palette histogram of colors in the image.
+
+    Parameters
+    ----------
+    lab_array : (N,3) ndarray
+        The L*a*b color of each of N pixels.
+    palette : rayleigh.Palette
+        Containing K colors.
+    plot_filename : string, optional
+        If given, save histogram to this filename.
+
+    Returns
+    -------
+    color_hist : (K,) ndarray
+    """
+    # This is the fastest way that I've found.
+    # >>> %%timeit -n 200 from sklearn.metrics import euclidean_distances
+    # >>> euclidean_distances(palette, lab_array, squared=True)
+    dist = euclidean_distances(palette.lab_array, lab_array, squared=True).T
+    min_ind = np.argmin(dist, axis=1)
+    num_colors = palette.lab_array.shape[0]
+    num_pixels = lab_array.shape[0]
+    color_hist = 1. * np.bincount(min_ind, minlength=num_colors) / num_pixels
+    if plot_filename is not None:
+        plot_histogram(color_hist, palette, plot_filename)
+    return color_hist
+
+
+def histogram_colors_smoothed(lab_array, palette, sigma=10,
+                              plot_filename=None, direct=True):
+    """
+    Returns a palette histogram of colors in the image, smoothed with
+    a Gaussian. Can smooth directly per-pixel, or after computing a strict
+    histogram.
+
+    Parameters
+    ----------
+    lab_array : (N,3) ndarray
+        The L*a*b color of each of N pixels.
+    palette : rayleigh.Palette
+        Containing K colors.
+    sigma : float
+        Variance of the smoothing Gaussian.
+    direct : bool, optional
+        If True, constructs a smoothed histogram directly from pixels.
+        If False, constructs a nearest-color histogram and then smoothes it.
+
+    Returns
+    -------
+    color_hist : (K,) ndarray
+    """
+    if direct:
+        color_hist_smooth = histogram_colors_with_smoothing(
+            lab_array, palette, sigma)
+    else:
+        color_hist_strict = histogram_colors_strict(lab_array, palette)
+        color_hist_smooth = smooth_histogram(color_hist_strict, palette, sigma)
+    if plot_filename is not None:
+        plot_histogram(color_hist_smooth, palette, plot_filename)
+    return color_hist_smooth
+
+
+def smooth_histogram(color_hist, palette, sigma=10):
+    """
+    Smooth the given palette histogram with a Gaussian of variance sigma.
+
+    Parameters
+    ----------
+    color_hist : (K,) ndarray
+    palette : rayleigh.Palette
+        containing K colors.
+
+    Returns
+    -------
+    color_hist_smooth : (K,) ndarray
+    """
+    n = 2. * sigma ** 2
+    weights = np.exp(-palette.distances / n)
+    norm_weights = weights / weights.sum(1)[:, np.newaxis]
+    color_hist_smooth = (norm_weights * color_hist).sum(1)
+    color_hist_smooth[color_hist_smooth < 1e-5] = 0
+    return color_hist_smooth
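+
+# Worked example (sketch): with sigma=10, a palette color at squared Lab
+# distance 200 from a histogram bin gets relative weight
+# exp(-200 / (2 * 10**2)) = exp(-1) ~= 0.37 before row normalization, so mass
+# mostly leaks to nearby palette entries.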
+
+
+def histogram_colors_with_smoothing(lab_array, palette, sigma=10):
+    """
+    Assign colors in the image to nearby colors in the palette, weighted by
+    distance in Lab color space.
+
+    Parameters
+    ----------
+    lab_array (N,3) ndarray:
+        N is the number of data points, columns are L, a, b values.
+    palette : rayleigh.Palette
+        containing K colors.
+    sigma : float
+        Controls the steepness of the exponential falloff; larger values
+        spread each pixel's weight over more palette colors.
+        To see the effect:
+
+    >>> from pylab import *
+    >>> ds = linspace(0,5000) # squared distance
+    >>> sigma=10; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma)
+    >>> sigma=20; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma)
+    >>> sigma=40; plot(ds, exp(-ds/(2*sigma**2)), label='$\sigma=%.1f$'%sigma)
+    >>> ylim([0,1]); legend();
+    >>> xlabel('Squared distance'); ylabel('Weight');
+    >>> title('Exponential smoothing')
+    >>> #plt.savefig('exponential_smoothing.png', dpi=300)
+
+        sigma=20 seems reasonable: hits 0 around squared distance of 4000.
+
+    Returns
+    -------
+    color_hist : (K,) ndarray
+        The normalized, smoothed histogram of colors.
+    """
+    dist = euclidean_distances(palette.lab_array, lab_array, squared=True).T
+    n = 2. * sigma ** 2
+    weights = np.exp(-dist / n)
+    
+    # Normalize per pixel: if a pixel is close to several palette colors, its
+    # unit of weight is split among them rather than counted once per color.
+    normalizing = weights.sum(1)
+    normalizing[normalizing == 0] = 1e16
+    normalized_weights = weights / normalizing[:, np.newaxis]
+
+    color_hist = normalized_weights.sum(0)
+    color_hist /= lab_array.shape[0]
+    color_hist[color_hist < 1e-5] = 0
+    return color_hist
+
+
+def makedirs(dirname):
+    "Does what mkdir -p does, and returns dirname."
+    os.makedirs(dirname, exist_ok=True)
+    return dirname
diff --git a/annotator/content/__init__.py b/annotator/content/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a16a467131f114841ef57d780326699fb098cee
--- /dev/null
+++ b/annotator/content/__init__.py
@@ -0,0 +1,23 @@
+import cv2
+import numpy as np
+from PIL import Image
+
+import torch
+from transformers import AutoProcessor, CLIPModel
+
+from annotator.util import annotator_ckpts_path
+
+
+class ContentDetector:
+    def __init__(self, model_name="openai/clip-vit-large-patch14"):
+
+        self.model = CLIPModel.from_pretrained(model_name, cache_dir=annotator_ckpts_path).cuda().eval()
+        self.processor = AutoProcessor.from_pretrained(model_name, cache_dir=annotator_ckpts_path)
+
+    def __call__(self, img):
+        with torch.no_grad():
+            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+            inputs = self.processor(images=[img], return_tensors="pt").to('cuda')
+            image_features = self.model.get_image_features(**inputs)
+            content_emb = image_features[0].detach().cpu().numpy()
+        return content_emb
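+
+
+# Illustrative usage (a sketch; assumes a CUDA device and a BGR uint8 image as
+# returned by cv2.imread):
+#
+#   detector = ContentDetector()
+#   emb = detector(img)   # 1-D CLIP image embedding (768-dim for ViT-L/14)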
diff --git a/annotator/entityseg/__init__.py b/annotator/entityseg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..61473ece64d0ad7b096cb12fd73b5ca1a247000f
--- /dev/null
+++ b/annotator/entityseg/__init__.py
@@ -0,0 +1,93 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
+import argparse
+import glob
+import multiprocessing as mp
+import os
+import sys
+sys.path.insert(1, os.getcwd())
+
+import tempfile
+import time
+import warnings
+
+import cv2
+import numpy as np
+import tqdm
+import torch
+
+from detectron2.config import get_cfg
+from detectron2.data.detection_utils import read_image
+from detectron2.projects.deeplab import add_deeplab_config
+from detectron2.utils.logger import setup_logger
+
+from mask2former import add_maskformer2_config
+from predictor import VisualizationDemo
+
+from annotator.util import annotator_ckpts_path
+
+
+model_url = "https://huggingface.co/datasets/qqlu1992/Adobe_EntitySeg/resolve/main/CropFormer_model/Entity_Segmentation/CropFormer_hornet_3x.pth"
+
+
+def make_colors():
+    from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+    colors = []
+    for cate in COCO_CATEGORIES:
+        colors.append(cate["color"])
+    return colors
+
+
+class EntitysegDetector:
+
+    def __init__(self, confidence_threshold=0.5):
+        cfg = get_cfg()
+        add_deeplab_config(cfg)
+        add_maskformer2_config(cfg)
+
+        workdir = os.getcwd()
+        config_file = f"{workdir}/annotator/entityseg/configs/cropformer_hornet_3x.yaml"
+        model_path = f'{annotator_ckpts_path}/CropFormer_hornet_3x_03823a.pth'
+        # Authentication required
+        # if not os.path.exists(model_path):
+        #     from basicsr.utils.download_util import load_file_from_url
+        #     load_file_from_url(model_url, model_dir=annotator_ckpts_path)
+
+        cfg.merge_from_file(config_file)
+        opts = ['MODEL.WEIGHTS', model_path]
+        cfg.merge_from_list(opts)
+        cfg.freeze()
+
+        self.confidence_threshold = confidence_threshold
+
+        self.colors = make_colors()
+        self.demo = VisualizationDemo(cfg)
+
+
+    def __call__(self, image): 
+        predictions = self.demo.run_on_image(image)
+        ##### color_mask
+        pred_masks = predictions["instances"].pred_masks
+        pred_scores = predictions["instances"].scores
+        
+        # select by confidence threshold
+        selected_indexes = (pred_scores >= self.confidence_threshold)
+        selected_scores = pred_scores[selected_indexes]
+        selected_masks  = pred_masks[selected_indexes]
+        _, m_H, m_W = selected_masks.shape
+        mask_id = np.zeros((m_H, m_W), dtype=np.uint8)
+
+        # rank
+        selected_scores, ranks = torch.sort(selected_scores)
+        ranks = ranks + 1
+        for index in ranks:
+            mask_id[(selected_masks[index-1]==1).cpu().numpy()] = int(index)
+        unique_mask_id = np.unique(mask_id)
+
+        color_mask = np.zeros(image.shape, dtype=np.uint8)
+        for count in unique_mask_id:
+            if count == 0:
+                continue
+            color_mask[mask_id==count] = self.colors[count % len(self.colors)]
+        
+        return color_mask
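+
+
+# Illustrative usage (a sketch; assumes the CropFormer config and checkpoint
+# referenced above are available and detectron2 is installed with CUDA):
+#
+#   detector = EntitysegDetector(confidence_threshold=0.5)
+#   color_mask = detector(image)   # uint8 color-coded entity map, same HxW as input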
diff --git a/annotator/entityseg/configs/Base-Mask2Former.yaml b/annotator/entityseg/configs/Base-Mask2Former.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c298c610f002fe3d8d7635ede808070f4b5bffed
--- /dev/null
+++ b/annotator/entityseg/configs/Base-Mask2Former.yaml
@@ -0,0 +1,49 @@
+ENTITY:
+  ENABLE: True
+MODEL:
+  BACKBONE:
+    FREEZE_AT: 0
+    NAME: "build_resnet_backbone"
+  WEIGHTS: "R-50.pkl"
+  PIXEL_MEAN: [123.675, 116.280, 103.530]
+  PIXEL_STD: [58.395, 57.120, 57.375]
+  RESNETS:
+    DEPTH: 50
+    STEM_TYPE: "basic"  # not used
+    STEM_OUT_CHANNELS: 64
+    STRIDE_IN_1X1: False
+    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+    # NORM: "SyncBN"
+    RES5_MULTI_GRID: [1, 1, 1]  # not used
+DATASETS:
+  TRAIN: ("entityv2_entity_train_01",)
+  TEST: ("entityv2_entity_val_01",)
+SOLVER:
+  STEPS: (30525, 33138)
+  MAX_ITER: 34375
+  IMS_PER_BATCH: 16
+  BASE_LR: 0.0001
+  WARMUP_FACTOR: 1.0
+  WARMUP_ITERS: 0
+  WEIGHT_DECAY: 0.05
+  OPTIMIZER: "ADAMW"
+  LR_SCHEDULER_NAME: "WarmupPolyLR"
+  BACKBONE_MULTIPLIER: 0.1
+  CLIP_GRADIENTS:
+    ENABLED: True
+    CLIP_TYPE: "full_model"
+    CLIP_VALUE: 0.01
+    NORM_TYPE: 2.0
+  AMP:
+    ENABLED: True
+INPUT:
+  MASK_FORMAT: "bitmask"
+  FORMAT: "RGB"
+  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
+  DATASET_MAPPER_NAME: "entity_crop"
+TEST:
+  EVAL_PERIOD: 400000
+DATALOADER:
+  FILTER_EMPTY_ANNOTATIONS: True
+  NUM_WORKERS: 32
+VERSION: 2
\ No newline at end of file
diff --git a/annotator/entityseg/configs/cropformer_hornet_3x.yaml b/annotator/entityseg/configs/cropformer_hornet_3x.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0462eac3c2646d72460e982888d92b259f9588b7
--- /dev/null
+++ b/annotator/entityseg/configs/cropformer_hornet_3x.yaml
@@ -0,0 +1,70 @@
+_BASE_: Base-Mask2Former.yaml
+DATALOADER:
+  NUM_WORKERS: 32
+DATASETS:
+  TRAIN: ("entityv2_entity_train_01","entityv2_entity_train_02","entityv2_entity_train_03",)
+  TEST: ("entityv2_entity_val_all",)
+  # TEST: ("entityv2_entity_val_all_lr",)
+SOLVER:
+  # STEPS: (91575, 99414)
+  # MAX_ITER: 103125
+  IMS_PER_BATCH: 8
+  STEPS: (183150, 198828)
+  MAX_ITER: 206250
+MODEL:
+  BACKBONE:
+    NAME: "D2HorNet"
+  PIXEL_MEAN: [123.675, 116.28, 103.53]
+  PIXEL_STD: [58.395, 57.120, 57.375]
+  SWIN:
+    EMBED_DIM: 192
+    DEPTHS: [2, 2, 18, 2]
+    NUM_HEADS: [6, 12, 24, 48]
+    WINDOW_SIZE: 7
+    APE: False
+    DROP_PATH_RATE: 0.3
+    PATCH_NORM: True
+    PRETRAIN_IMG_SIZE: 384
+  WEIGHTS: "hornet_l_pretrained.pth"
+  META_ARCHITECTURE: "CropFormer"
+  SEM_SEG_HEAD:
+    NAME: "MaskFormerHead"
+    IGNORE_VALUE: 255
+    NUM_CLASSES: 1
+    LOSS_WEIGHT: 1.0
+    CONVS_DIM: 256
+    MASK_DIM: 256
+    NORM: "GN"
+    # pixel decoder
+    PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
+    IN_FEATURES: ["res2", "res3", "res4", "res5"]
+    DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+    COMMON_STRIDE: 4
+    TRANSFORMER_ENC_LAYERS: 6
+  MASK_FORMER:
+    TRANSFORMER_DECODER_NAME: "CropSharedMultiScaleMaskedTransformerDecoder"
+    TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
+    DEEP_SUPERVISION: True
+    NO_OBJECT_WEIGHT: 0.1
+    CLASS_WEIGHT: 2.0
+    MASK_WEIGHT: 5.0
+    DICE_WEIGHT: 5.0
+    HIDDEN_DIM: 256
+    NUM_OBJECT_QUERIES: 200
+    NHEADS: 8
+    DROPOUT: 0.0
+    DIM_FEEDFORWARD: 2048
+    ENC_LAYERS: 0
+    PRE_NORM: False
+    ENFORCE_INPUT_PROJ: False
+    SIZE_DIVISIBILITY: 32
+    DEC_LAYERS: 10  # 9 decoder layers, add one for the loss on learnable query
+    TRAIN_NUM_POINTS: 12544
+    OVERSAMPLE_RATIO: 3.0
+    IMPORTANCE_SAMPLE_RATIO: 0.75
+    TEST:
+      SEMANTIC_ON: False
+      INSTANCE_ON: True
+      PANOPTIC_ON: False
+      OVERLAP_THRESHOLD: 0.8
+      OBJECT_MASK_THRESHOLD: 0.8
\ No newline at end of file
diff --git a/annotator/entityseg/mask2former/__init__.py b/annotator/entityseg/mask2former/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..480340621ed6906adb9cf9c6d08a0eabc267483b
--- /dev/null
+++ b/annotator/entityseg/mask2former/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from . import data  # register all new datasets
+from . import modeling
+
+# config
+from .config import add_maskformer2_config
+
+# models
+from .maskformer_model import MaskFormer
+from .cropformer_model import CropFormer
+from .test_time_augmentation import SemanticSegmentorWithTTA
diff --git a/annotator/entityseg/mask2former/config.py b/annotator/entityseg/mask2former/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc47f240c03ebbf010204f6fad1fcd8e7c1d4267
--- /dev/null
+++ b/annotator/entityseg/mask2former/config.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+from detectron2.config import CfgNode as CN
+
+
+def add_maskformer2_config(cfg):
+    """
+    Add config for MASK_FORMER.
+    """
+    # NOTE: configs from original maskformer
+    # data config
+    # select the dataset mapper
+    cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
+    # Color augmentation
+    cfg.INPUT.COLOR_AUG_SSD = False
+    # We retry random cropping until no single category in semantic segmentation GT occupies more
+    # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
+    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
+    # Pad image and segmentation GT in dataset mapper.
+    cfg.INPUT.SIZE_DIVISIBILITY = -1
+
+    # solver config
+    # weight decay on embedding
+    cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
+    # optimizer
+    cfg.SOLVER.OPTIMIZER = "ADAMW"
+    cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
+
+    # mask_former model config
+    cfg.MODEL.MASK_FORMER = CN()
+
+    # loss
+    cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
+    cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
+    cfg.MODEL.MASK_FORMER.CLASS_WEIGHT = 1.0
+    cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
+    cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0
+
+    # transformer config
+    cfg.MODEL.MASK_FORMER.NHEADS = 8
+    cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
+    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
+    cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
+    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
+    cfg.MODEL.MASK_FORMER.PRE_NORM = False
+
+    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100
+
+    cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
+    cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False
+
+    # mask_former inference config
+    cfg.MODEL.MASK_FORMER.TEST = CN()
+    cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True
+    cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = False
+    cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
+    cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
+    cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
+    cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
+
+    # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
+    # you can use this config to override
+    cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
+
+    # pixel decoder config
+    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+    # adding transformer in pixel decoder
+    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
+    # pixel decoder
+    cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
+
+    # swin transformer backbone
+    cfg.MODEL.SWIN = CN()
+    cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
+    cfg.MODEL.SWIN.PATCH_SIZE = 4
+    cfg.MODEL.SWIN.EMBED_DIM = 96
+    cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+    cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+    cfg.MODEL.SWIN.WINDOW_SIZE = 7
+    cfg.MODEL.SWIN.MLP_RATIO = 4.0
+    cfg.MODEL.SWIN.QKV_BIAS = True
+    cfg.MODEL.SWIN.QK_SCALE = None
+    cfg.MODEL.SWIN.DROP_RATE = 0.0
+    cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
+    cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
+    cfg.MODEL.SWIN.APE = False
+    cfg.MODEL.SWIN.PATCH_NORM = True
+    cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
+    cfg.MODEL.SWIN.USE_CHECKPOINT = False
+
+    # NOTE: maskformer2 extra configs
+    # transformer module
+    cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME = "MultiScaleMaskedTransformerDecoder"
+
+    # LSJ aug
+    cfg.INPUT.IMAGE_SIZE = 1024
+    cfg.INPUT.MIN_SCALE = 0.1
+    cfg.INPUT.MAX_SCALE = 2.0
+
+    # MSDeformAttn encoder configs
+    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
+    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
+    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8
+
+    # point loss configs
+    # Number of points sampled during training for a mask point head.
+    cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS = 112 * 112
+    # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
+    # original paper.
+    cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO = 3.0
+    # Importance sampling parameter for PointRend point sampling during training. Parameter `beta` in
+    # the original paper.
+    cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
+
+    ## For Entity
+    cfg.ENTITY = CN()
+    cfg.ENTITY.ENABLE = False
+    cfg.ENTITY.CROP_AREA_RATIO = 0.7
+    cfg.ENTITY.CROP_STRIDE_RATIO = 0.6
+    cfg.ENTITY.CROP_SAMPLE_NUM_TRAIN = 1
+    cfg.ENTITY.CROP_SAMPLE_NUM_TEST = 4
+
+    ## fuse frame embeddings to batch embedding
+    cfg.ENTITY.FUSE_NUM_LAYERS = 1
+    cfg.ENTITY.FUSE_ENC_HIDDIEN_DIM = 256
+    cfg.ENTITY.FUSE_ENC_NHEADS = 8
+    cfg.ENTITY.FUSE_ENC_PRE_NORM = False
+    cfg.ENTITY.FUSE_ENC_DIM_FEEDFORWARD = 2048
+    cfg.ENTITY.FUSE_ENC_LAST_LAYERS = 1
+    cfg.ENTITY.FUSE_DEC_NUM_LAYERS = 3
+
+    ## Hornet backbone
+    cfg.MODEL.HORNET = CN()
+    cfg.MODEL.HORNET.DEPTHS = [2, 3, 18, 2]
+    cfg.MODEL.HORNET.BASE_DIM = 192
+    cfg.MODEL.HORNET.GCONV = ['partial(gnconv, order=2, s=1/3)', 'partial(gnconv, order=3, s=1/3)', 'partial(gnconv, order=4, s=1/3, h=24, w=13, gflayer=GlobalLocalFilter)', 'partial(gnconv, order=5, s=1/3, h=12, w=7, gflayer=GlobalLocalFilter)']
+    cfg.MODEL.HORNET.DROP_PATH_RATE=0.6
+    cfg.MODEL.HORNET.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
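+
+
+# A minimal sketch of how this hook is consumed (mirroring
+# EntitysegDetector.__init__ in annotator/entityseg/__init__.py); the weights
+# path below is a placeholder:
+#
+#   from detectron2.config import get_cfg
+#   from detectron2.projects.deeplab import add_deeplab_config
+#
+#   cfg = get_cfg()
+#   add_deeplab_config(cfg)
+#   add_maskformer2_config(cfg)
+#   cfg.merge_from_file("annotator/entityseg/configs/cropformer_hornet_3x.yaml")
+#   cfg.merge_from_list(["MODEL.WEIGHTS", "/path/to/CropFormer_hornet_3x.pth"])
+#   cfg.freeze()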
diff --git a/annotator/entityseg/mask2former/cropformer_model.py b/annotator/entityseg/mask2former/cropformer_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..601594246643e8780d9901f5c18d0fdc5ebd8910
--- /dev/null
+++ b/annotator/entityseg/mask2former/cropformer_model.py
@@ -0,0 +1,678 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+import pdb
+import numpy as np
+import cv2
+import os
+
+from detectron2.config import configurable
+from detectron2.data import MetadataCatalog
+from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
+from detectron2.modeling.backbone import Backbone
+from detectron2.modeling.postprocessing import sem_seg_postprocess
+from detectron2.structures import Boxes, ImageList, Instances, BitMasks
+from detectron2.utils.memory import retry_if_cuda_oom
+from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
+
+from .modeling.criterion import SetCriterion
+from .modeling.matcher import HungarianMatcher
+from .modeling.criterion_view import ViewSetCriterion
+from .modeling.matcher_view import ViewHungarianMatcher
+import copy
+
+@META_ARCH_REGISTRY.register()
+class CropFormer(nn.Module):
+    """
+    Main class for mask classification semantic segmentation architectures.
+    """
+    @configurable
+    def __init__(
+        self,
+        *,
+        cfg,
+        backbone: Backbone,
+        sem_seg_head: nn.Module,
+        criterion_2d: nn.Module,
+        criterion_3d: nn.Module,
+        num_queries: int,
+        object_mask_threshold: float,
+        overlap_threshold: float,
+        metadata,
+        size_divisibility: int,
+        sem_seg_postprocess_before_inference: bool,
+        pixel_mean: Tuple[float],
+        pixel_std: Tuple[float],
+        # inference
+        semantic_on: bool,
+        panoptic_on: bool,
+        instance_on: bool,
+        test_topk_per_image: int,
+    ):
+        """
+        Args:
+            backbone: a backbone module, must follow detectron2's backbone interface
+            sem_seg_head: a module that predicts semantic segmentation from backbone features
+            criterion: a module that defines the loss
+            num_queries: int, number of queries
+            object_mask_threshold: float, threshold to filter query based on classification score
+                for panoptic segmentation inference
+            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
+            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
+                segmentation inference
+            size_divisibility: Some backbones require the input height and width to be divisible by a
+                specific integer. We can use this to override such requirement.
+            sem_seg_postprocess_before_inference: whether to resize the prediction back
+                to original input size before semantic segmentation inference or after.
+                For high-resolution dataset like Mapillary, resizing predictions before
+                inference will cause OOM error.
+            pixel_mean, pixel_std: list or tuple with #channels element, representing
+                the per-channel mean and std to be used to normalize the input image
+            semantic_on: bool, whether to output semantic segmentation prediction
+            instance_on: bool, whether to output instance segmentation prediction
+            panoptic_on: bool, whether to output panoptic segmentation prediction
+            test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
+        """
+        super().__init__()
+        self.cfg = cfg
+        self.backbone = backbone
+        self.sem_seg_head = sem_seg_head
+        self.criterion_2d = criterion_2d
+        self.criterion_3d = criterion_3d
+        ## colors
+        self.colors = [info["color"] for info in COCO_CATEGORIES]
+
+        self.num_queries = num_queries
+        self.overlap_threshold = overlap_threshold
+        self.object_mask_threshold = object_mask_threshold
+        self.metadata = metadata
+        if size_divisibility < 0:
+            # use backbone size_divisibility if not set
+            size_divisibility = self.backbone.size_divisibility
+        self.size_divisibility = size_divisibility
+        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+        # additional args
+        self.semantic_on = semantic_on
+        self.instance_on = instance_on
+        self.panoptic_on = panoptic_on
+        self.test_topk_per_image = test_topk_per_image
+
+        if not self.semantic_on:
+            assert self.sem_seg_postprocess_before_inference
+
+    @classmethod
+    def from_config(cls, cfg):
+        backbone = build_backbone(cfg)
+        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
+
+        # Loss parameters:
+        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
+        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
+
+        # loss weights
+        class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
+        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
+        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
+
+        # building criterion
+        matcher_2d = HungarianMatcher(
+            cost_class=class_weight,
+            cost_mask=mask_weight,
+            cost_dice=dice_weight,
+            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
+        )
+
+        matcher_3d = ViewHungarianMatcher(
+            cost_class=class_weight,
+            cost_mask=mask_weight,
+            cost_dice=dice_weight,
+            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
+        )
+
+        weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
+
+        if deep_supervision:
+            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
+            aux_weight_dict = {}
+            for i in range(dec_layers - 1):
+                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
+            weight_dict.update(aux_weight_dict)
+
+        losses = ["labels", "masks"]
+
+        criterion_2d = SetCriterion(
+            sem_seg_head.num_classes,
+            matcher=matcher_2d,
+            weight_dict=weight_dict,
+            eos_coef=no_object_weight,
+            losses=losses,
+            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
+            oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
+            importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
+        )
+
+        criterion_3d = ViewSetCriterion(
+            sem_seg_head.num_classes,
+            matcher=matcher_3d,
+            weight_dict=weight_dict,
+            eos_coef=no_object_weight,
+            losses=losses,
+            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
+            oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
+            importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
+        )
+
+        return {
+            "cfg": cfg,
+            "backbone": backbone,
+            "sem_seg_head": sem_seg_head,
+            "criterion_2d": criterion_2d,
+            "criterion_3d": criterion_3d,
+            "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
+            "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
+            "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
+            "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
+            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
+            "sem_seg_postprocess_before_inference": (
+                cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
+                or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
+                or cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
+            ),
+            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
+            "pixel_std": cfg.MODEL.PIXEL_STD,
+            # inference
+            "semantic_on": cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON,
+            "instance_on": cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON,
+            "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
+            "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
+        }
+
+    @property
+    def device(self):
+        return self.pixel_mean.device
+
+    def forward(self, batched_inputs):
+        """
+        Args:
+            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
+                Each item in the list contains the inputs for one image.
+                For now, each item in the list is a dict that contains:
+                   * "image": Tensor, image in (C, H, W) format.
+                   * "instances": per-region ground truth
+                   * Other information that's included in the original dicts, such as:
+                     "height", "width" (int): the output resolution of the model (may be different
+                     from input resolution), used in inference.
+        Returns:
+            list[dict]:
+                each dict has the results for one image. The dict contains the following keys:
+
+                * "sem_seg":
+                    A Tensor that represents the
+                    per-pixel segmentation predicted by the head.
+                    The prediction has shape KxHxW that represents the logits of
+                    each class for each pixel.
+                * "panoptic_seg":
+                    A tuple that represents the panoptic output
+                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
+                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
+                        Each dict contains keys "id", "category_id", "isthing".
+        """
+        ## make new images
+        batched_inputs_new = []
+        for batched_input in batched_inputs:
+            ori_infos = {"height": batched_input["height"],
+                        "width": batched_input["width"], 
+                        "image": batched_input["image"],
+                        # "file_name": batched_input["file_name"],
+                        }
+            if "instances" in batched_input.keys():
+                ori_instances = batched_input["instances"]
+                ori_instances.original_indices = torch.arange(0, len(ori_instances)).long()
+                ori_infos["instances"] = ori_instances
+            batched_inputs_new.append(ori_infos)
+            ## cropped patches
+            crop_region = batched_input["crop_region"]
+            crop_images = batched_input["image_crop"]
+            crop_o_width  = int(crop_region[0][2]-crop_region[0][0])
+            crop_o_height = int(crop_region[0][3]-crop_region[0][1])
+
+            if "instances_crop" in batched_input.keys():
+                crop_instances = batched_input["instances_crop"]
+            else:
+                crop_instances = None
+
+            for crop_index, crop_image in enumerate(crop_images):
+                crop_infos = {"height": crop_o_height, "width": crop_o_width, "image": crop_image}
+                if crop_instances is not None:
+                    crop_instance = crop_instances[crop_index]
+                    crop_instance.original_indices = torch.arange(0, len(crop_instance)).long()
+                    crop_infos["instances"] = crop_instance
+                batched_inputs_new.append(crop_infos)
+
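+        # batched_inputs_new interleaves each original image with its cropped
+        # views, so every num_views-th entry (index % num_views == 0) is a full
+        # image; the stride check below relies on this layout.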
+        images = [x["image"].to(self.device) for x in batched_inputs_new]
+        ## num_views = crops per image + 1, the +1 being the original full image
+        num_views = self.cfg.ENTITY.CROP_SAMPLE_NUM_TRAIN+1 if self.training else self.cfg.ENTITY.CROP_SAMPLE_NUM_TEST+1
+        for i in range(len(images)):
+            if i%num_views==0:
+                continue
+            _, c_h, c_w = images[i].shape
+            if "instances" in batched_inputs_new[i].keys():
+                batched_inputs_new[i]["instances"]._image_size = (c_h, c_w)
+        
+        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+        images = ImageList.from_tensors(images, self.size_divisibility)
+
+        features = self.backbone(images.tensor)
+        outputs_2d, outputs_3d = self.sem_seg_head(features)
+
+        if self.training:
+            if self.cfg.ENTITY.ENABLE:
+                for i in range(len(batched_inputs_new)):
+                    batched_inputs_new[i]["instances"].gt_classes[:] = 0
+            
+            if "instances" in batched_inputs[0]:
+                gt_instances = [x["instances"].to(self.device) for x in batched_inputs_new]
+                targets_2d = self.prepare_targets_2d(copy.deepcopy(gt_instances), copy.deepcopy(images))
+                targets_3d = self.prepare_targets_3d(copy.deepcopy(gt_instances), copy.deepcopy(images), num_views)
+            else:
+                targets = None
+
+            # bipartite matching-based loss
+            losses = {}
+            losses_2d = self.criterion_2d(outputs_2d, targets_2d)
+            losses_3d = self.criterion_3d(outputs_3d, targets_3d)
+
+            for k in list(losses_2d.keys()):
+                if k in self.criterion_2d.weight_dict:
+                    losses[k+"_2d"] = losses_2d[k] * self.criterion_2d.weight_dict[k] * 0.5
+                else:
+                    # remove this loss if not specified in `weight_dict`
+                    losses_2d.pop(k)
+            
+            for k in list(losses_3d.keys()):
+                if k in self.criterion_3d.weight_dict:
+                    losses[k+"_3d"] = losses_3d[k] * self.criterion_3d.weight_dict[k]
+                else:
+                    # remove this loss if not specified in `weight_dict`
+                    losses_3d.pop(k)
+            return losses
+        else:
+            mask_cls_results_3d  = outputs_3d["pred_logits"][0] ## 100,2
+            mask_pred_results_3d = outputs_3d["pred_masks"][0]  ## 100,5,200, 304
+
+            mask_cls_results_2d  = outputs_2d["pred_logits"]
+            mask_pred_results_2d = outputs_2d["pred_masks"]
+            # upsample masks
+            
+            mask_pred_results_3d = retry_if_cuda_oom(F.interpolate)(
+                mask_pred_results_3d,
+                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+            mask_pred_results_2d = F.interpolate(
+                mask_pred_results_2d,
+                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+            del outputs_2d, outputs_3d
+            
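+            # NOTE: at inference time the batch is assumed to hold a single
+            # image, so `batched_input` from the loop above is that image.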
+            crop_regions = batched_input["crop_region"][:num_views-1]
+            processed_results = retry_if_cuda_oom(self.inference_whole_views)(
+                                mask_cls_results_3d,
+                                mask_pred_results_3d,
+                                mask_cls_results_2d,
+                                mask_pred_results_2d,
+                                batched_inputs_new,
+                                images.image_sizes, 
+                                crop_regions)
+
+            # processed_results = retry_if_cuda_oom(self.instance_inference_nonoverlap)(
+            #                             mask_cls_results_2d[0], 
+            #                             mask_pred_results_2d[0],
+            #                             batched_inputs_new[0], 
+            #                             images.image_sizes[0])
+
+            return [{"instances": processed_results}]
+
+    def prepare_targets_2d(self, targets, images):
+        h_pad, w_pad = images.tensor.shape[-2:]
+        new_targets = []
+        for targets_per_image in targets:
+            gt_masks = targets_per_image.gt_masks.tensor           
+            gt_valid = targets_per_image.gt_boxes_valid
+            padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
+            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
+            valid_index = torch.nonzero(gt_valid).flatten()
+            new_targets.append(
+                {
+                    "labels": targets_per_image.gt_classes[valid_index],
+                    "masks": padded_masks[valid_index],
+                }
+            )
+        return new_targets
+    
+    def prepare_targets_3d(self, targets_ori, images, num_views):
+        T = num_views
+        B = int(len(targets_ori) / T)
+        h_pad, w_pad = images.tensor.shape[-2:]
+        
+        ## reshape to new targets
+        new_targets = []
+        for count, target in enumerate(targets_ori):
+            b_index, t_index = int(count // T), int(count % T)
+            if t_index == 0:
+                new_targets.append([target])
+            else:
+                new_targets[b_index].append(target)
+
+        gt_instances = []
+        for count, targets in enumerate(new_targets):
+            _num_instance = len(targets[0])
+            mask_shape = [_num_instance, T, h_pad, w_pad]
+            gt_masks_per_view = torch.zeros(mask_shape, dtype=torch.bool, device=self.device)
+
+            for v_i, targets_per_view in enumerate(targets):
+                assert torch.all(targets[0].original_indices == targets_per_view.original_indices)
+            
+            gt_ids_per_view   = []
+            gt_ids_per_valid  = []
+            gt_ids_categories = []
+            ## view first, then entities
+            for v_i, targets_per_view in enumerate(targets):
+                targets_per_view = targets_per_view.to(self.device)
+                h, w = targets_per_view.image_size
+                for i_i, (instance_mask, instance_valid) in enumerate(zip(targets_per_view.gt_masks.tensor, targets_per_view.gt_boxes_valid)):
+                    if instance_valid == 1:
+                        gt_masks_per_view[i_i, v_i, :h, :w] = instance_mask
+                gt_ids_per_valid.append(targets_per_view.gt_boxes_valid[None,:])
+                gt_ids_per_view.append(targets_per_view.original_indices[None,:])
+                gt_ids_categories.append(targets_per_view.gt_classes[None, :])
+            ## (num_instances, num_views)
+            gt_ids_per_valid = torch.cat(gt_ids_per_valid, dim=0).permute((1,0))
+            gt_ids_per_view = torch.cat(gt_ids_per_view, dim=0).permute((1,0))
+            gt_ids_categories = torch.cat(gt_ids_categories, dim=0).permute((1,0))
+            
+            gt_ids_per_view[gt_ids_per_valid == 0] = -1
+            valid_idx = (gt_ids_per_view != 1).any(dim=-1)
+            ## categories
+            gt_classes_per_group = gt_ids_categories[:,0]   ## N
+            gt_ids_per_group = gt_ids_per_view   ## N, num_views
+            gt_masks_per_group = gt_masks_per_view.float() ## N, num_views, H, W
+
+            ## 
+            gt_instances.append({"labels": gt_classes_per_group, 
+                                "ids": gt_ids_per_group,
+                                "masks": gt_masks_per_group})
+
+        return gt_instances
+
+    def semantic_inference(self, mask_cls, mask_pred):
+        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
+        mask_pred = mask_pred.sigmoid()
+        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
+        return semseg
+
+    def panoptic_inference(self, mask_cls, mask_pred):
+        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
+        mask_pred = mask_pred.sigmoid()
+
+        keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
+        cur_scores = scores[keep]
+        cur_classes = labels[keep]
+        cur_masks = mask_pred[keep]
+        cur_mask_cls = mask_cls[keep]
+        cur_mask_cls = cur_mask_cls[:, :-1]
+
+        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
+
+        h, w = cur_masks.shape[-2:]
+        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
+        segments_info = []
+
+        current_segment_id = 0
+
+        if cur_masks.shape[0] == 0:
+            # We didn't detect any mask :(
+            return panoptic_seg, segments_info
+        else:
+            # take argmax
+            cur_mask_ids = cur_prob_masks.argmax(0)
+            stuff_memory_list = {}
+            for k in range(cur_classes.shape[0]):
+                pred_class = cur_classes[k].item()
+                isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
+                mask_area = (cur_mask_ids == k).sum().item()
+                original_area = (cur_masks[k] >= 0.5).sum().item()
+                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
+
+                if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
+                    if mask_area / original_area < self.overlap_threshold:
+                        continue
+
+                    # merge stuff regions
+                    if not isthing:
+                        if int(pred_class) in stuff_memory_list.keys():
+                            panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
+                            continue
+                        else:
+                            stuff_memory_list[int(pred_class)] = current_segment_id + 1
+
+                    current_segment_id += 1
+                    panoptic_seg[mask] = current_segment_id
+
+                    segments_info.append(
+                        {
+                            "id": current_segment_id,
+                            "isthing": bool(isthing),
+                            "category_id": int(pred_class),
+                        }
+                    )
+            return panoptic_seg, segments_info
+    
+    def instance_inference_nonoverlap(self, mask_cls, mask_pred):
+        # mask_pred is already processed to have the same shape as original input
+        image_size = mask_pred.shape[-2:]
+
+        # [Q, K]
+        scores = F.softmax(mask_cls, dim=-1)[:, :-1]
+        labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
+        # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
+        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
+        labels_per_image = labels[topk_indices]
+
+        topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode="trunc")
+        # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
+        mask_pred = mask_pred[topk_indices]
+
+        ###### ranks
+        pred_masks = (mask_pred>0).float()
+        pred_masks_logits = mask_pred.sigmoid()
+        pred_scores = scores_per_image
+
+        _, m_H, m_W = pred_masks.shape
+        mask_id = torch.zeros((m_H, m_W), dtype=torch.int).to(pred_masks.device)
+        sorted_scores, ranks = torch.sort(pred_scores)
+        ranks = ranks + 1
+        for index in ranks:
+            mask_id[(pred_masks[index-1]==1)] = int(index)
+        # re-generate mask
+        new_scores = []
+        new_masks  = []
+        new_masks_logits = []
+        entity_nums = len(ranks)
+        for ii in range(entity_nums):
+            index = int(ranks[entity_nums-ii-1])
+            score = sorted_scores[entity_nums-ii-1]
+            new_scores.append(score)
+            new_masks.append((mask_id==index).float())
+            new_masks_logits.append(pred_masks_logits[index-1])
+        
+        new_scores = torch.stack(new_scores)
+        new_masks  = torch.stack(new_masks)
+        new_masks_logits = torch.stack(new_masks_logits)
+
+        result = Instances(image_size)
+        # mask (before sigmoid)
+        result.pred_masks = new_masks
+        result.pred_boxes = Boxes(torch.zeros(new_masks.size(0), 4))
+        # Uncomment the following to get boxes from masks (this is slow)
+
+        # calculate average mask prob
+        mask_scores_per_image = (new_masks_logits.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
+        result.scores = new_scores * mask_scores_per_image
+        result.pred_classes = labels_per_image
+        return result
+
+    def instance_inference(self, mask_cls, mask_pred):
+        # mask_pred is already processed to have the same shape as original input
+        image_size = mask_pred.shape[-2:]
+        
+        # [Q, K]
+        scores = F.softmax(mask_cls, dim=-1)[:, :-1]
+        labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
+        # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
+        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
+        labels_per_image = labels[topk_indices]
+        
+        topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode="trunc")
+        # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
+        mask_pred = mask_pred[topk_indices]
+
+        # if this is panoptic segmentation, we only keep the "thing" classes
+        if self.panoptic_on:
+            keep = torch.zeros_like(scores_per_image).bool()
+            for i, lab in enumerate(labels_per_image):
+                keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
+
+            scores_per_image = scores_per_image[keep]
+            labels_per_image = labels_per_image[keep]
+            mask_pred = mask_pred[keep]
+
+        result = Instances(image_size)
+        # mask (before sigmoid)
+        result.pred_masks = (mask_pred > 0).float()
+        result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
+        # Uncomment the following to get boxes from masks (this is slow)
+        # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
+
+        # calculate average mask prob
+        mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
+        # pdb.set_trace()
+        result.scores = scores_per_image * mask_scores_per_image
+        result.pred_classes = labels_per_image
+        return result
+    
+    def inference_whole_views(self, pred_cls, pred_masks, pred_cls_2d, pred_masks_2d, batched_inputs, image_sizes, crop_regions):
+        ## pred_masks: [100, 5, 800, 1216]
+        ## pred_masks_2d: [5, 100, 800, 1216]
+        scores = F.softmax(pred_cls, dim=-1)[:,:-1]   # 100,1
+        scores_2d = F.softmax(pred_cls_2d, dim=-1)[:, :, :-1]  # 5, 100, 1
+        
+        # scores = (scores+scores_2d[0])/2
+        labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
+        ### keep all the indices
+        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
+        labels_per_image = labels[topk_indices]
+        # topk_indices = topk_indices // self.sem_seg_head.num_classes
+        topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode="trunc")
+        pred_masks = pred_masks[topk_indices]
+        pred_masks = pred_masks.permute((1,0,2,3))
+
+        new_pred_masks = []
+        for view_index, (pred_masks_per_view, batched_input_per_view, image_size_per_view) in enumerate(zip(pred_masks, batched_inputs, image_sizes)):
+            O_H = batched_input_per_view["height"]
+            O_W = batched_input_per_view["width"]
+
+            SO_H, SO_W = image_size_per_view
+
+            pred_masks_per_view = pred_masks_per_view[..., : SO_H, :SO_W]
+            pred_masks_per_view = F.interpolate(pred_masks_per_view[None], size=(O_H, O_W), mode="bilinear", align_corners=False)
+
+            new_pred_masks.append(pred_masks_per_view[0].sigmoid())
+        
+        ## fuse the masks
+        full_image_masks  = new_pred_masks[0]
+        
+        ## fuse crop image
+        fused_image_masks = torch.zeros_like(full_image_masks).float()
+        fused_image_masks_valid = torch.zeros_like(full_image_masks).float() + 1e-16
+        for crop_region_per_view, pred_masks_per_view in zip(crop_regions, new_pred_masks[1:]):
+            x0, y0, x1, y1 = crop_region_per_view
+            fused_image_masks[..., y0:y1, x0:x1] += pred_masks_per_view
+            fused_image_masks_valid[..., y0:y1, x0:x1] += 1
+        
+        # add original masks
+        fused_image_masks += full_image_masks
+        fused_image_masks_valid += 1
+
+        ## average
+        fuse_image_masks = fused_image_masks / fused_image_masks_valid
+
+        ###### merge the views into a single image, then run non-overlap suppression
+        ##  ranks
+        pred_masks_logits = fuse_image_masks
+        pred_masks = (fuse_image_masks>0.5).float()
+        pred_scores = scores_per_image
+
+        _, m_H, m_W = pred_masks.shape
+        ## for visualization
+        mask_id = torch.zeros((m_H, m_W), dtype=torch.int).to(pred_masks.device)
+        
+        # mask_id_colors = np.zeros((m_H, m_W, 3), dtype=np.uint8)
+        # pred_masks_np = pred_masks.cpu().numpy()
+
+        sorted_scores, ranks = torch.sort(pred_scores)
+        ranks = ranks + 1
+        for index in ranks:
+            mask_id[(pred_masks[index-1]==1)] = int(index)
+            # mask_id_colors[(pred_masks_np[index-1]==1)] = self.colors[index]
+        # base_path = "/group/20018/gavinqi/vis_entityv2_release_debug"
+        # pdb.set_trace()
+        # file_name = batched_inputs[0]["file_name"]
+        # split_index, img_name = file_name.split("/")[-2:]
+        # save_name = img_name.split(".")[0]+".png"
+        # if not os.path.exists(os.path.join(base_path, save_name)):
+        #     cv2.imwrite(os.path.join(base_path, save_name), mask_id_colors)
+        # re-generate mask
+        new_scores = []
+        new_masks  = []
+        new_masks_logits = []
+        entity_nums = len(ranks)
+        for ii in range(entity_nums):
+            index = int(ranks[entity_nums-ii-1])
+            score = sorted_scores[entity_nums-ii-1]
+            new_scores.append(score)
+            new_masks.append((mask_id==index).float())
+            new_masks_logits.append(pred_masks_logits[index-1])
+        
+        new_scores = torch.stack(new_scores)
+        new_masks  = torch.stack(new_masks)
+        new_masks_logits = torch.stack(new_masks_logits)
+        # make result
+        image_size = (batched_inputs[0]["height"], batched_inputs[0]["width"])
+        result = Instances(image_size)
+        # mask (before sigmoid)
+        result.pred_masks = new_masks
+        result.pred_boxes = Boxes(torch.zeros(new_masks.size(0), 4))
+        # Uncomment the following to get boxes from masks (this is slow)
+
+        # calculate average mask prob
+        mask_scores_per_image = (new_masks_logits.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
+        result.scores = new_scores * mask_scores_per_image
+        result.pred_classes = labels_per_image
+        return result
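+
+# ---------------------------------------------------------------------------
+# Illustrative sketch (not part of the model): the crop-fusion step in
+# `inference_whole_views` pastes each cropped view's sigmoid masks back into
+# full-image coordinates, counts how many views cover each pixel, and divides
+# to get an average probability. Shapes and names below are toy values chosen
+# for illustration only.
+if __name__ == "__main__":
+    full = torch.rand(3, 8, 8)                      # Q x H x W masks from the full view
+    crops = [torch.rand(3, 4, 4), torch.rand(3, 4, 4)]
+    regions = [(0, 0, 4, 4), (4, 4, 8, 8)]          # (x0, y0, x1, y1) in full-image coords
+    acc = torch.zeros_like(full)
+    cnt = torch.zeros_like(full) + 1e-16
+    for (x0, y0, x1, y1), m in zip(regions, crops):
+        acc[..., y0:y1, x0:x1] += m
+        cnt[..., y0:y1, x0:x1] += 1
+    acc += full                                     # the full view covers every pixel
+    cnt += 1
+    fused = acc / cnt                               # averaged mask probabilities
+    print(fused.shape)                              # torch.Size([3, 8, 8])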
diff --git a/annotator/entityseg/mask2former/data/__init__.py b/annotator/entityseg/mask2former/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/annotator/entityseg/mask2former/data/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/annotator/entityseg/mask2former/data/dataset_mappers/__init__.py b/annotator/entityseg/mask2former/data/dataset_mappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/annotator/entityseg/mask2former/data/dataset_mappers/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/annotator/entityseg/mask2former/data/dataset_mappers/crop_augmentations.py b/annotator/entityseg/mask2former/data/dataset_mappers/crop_augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c0b139b84b3ef27c95e522b7a919467cf7b3c81
--- /dev/null
+++ b/annotator/entityseg/mask2former/data/dataset_mappers/crop_augmentations.py
@@ -0,0 +1,421 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+"""
+Implement many useful :class:`Augmentation`.
+"""
+import numpy as np
+import sys
+from typing import Tuple
+from PIL import Image
+import random
+
+from fvcore.transforms.transform import NoOpTransform, Transform
+
+from detectron2.data.transforms.augmentation import Augmentation
+import pdb
+import math
+
+import logging
+import pycocotools.mask as mask_util
+import torch
+from collections import defaultdict
+import copy
+from detectron2.data import transforms as T
+from detectron2.structures import (
+    BitMasks,
+    Boxes,
+    BoxMode,
+    Instances,
+    Keypoints,
+    PolygonMasks,
+    RotatedBoxes,
+    polygons_to_bitmask,
+)
+from detectron2.utils.file_io import PathManager
+
+__all__ = [
+    "BatchResizeShortestEdge",
+    "EntityCrop",
+]
+
+class BatchResizeTransform(Transform):
+    """
+    Resize the image to a target size.
+    """
+
+    def __init__(self, h, w, new_h, new_w, interp=None):
+        """
+        Args:
+            h, w (int): original image size
+            new_h, new_w (int): new image size
+            interp: PIL interpolation methods, defaults to bilinear.
+        """
+        # TODO decide on PIL vs opencv
+        super().__init__()
+        if interp is None:
+            interp = Image.BILINEAR
+        self._set_attributes(locals())
+
+    def apply_image(self, imgs, interp=None):
+        dim_num = len(imgs.shape)
+        assert dim_num == 4
+        interp_method = interp if interp is not None else self.interp
+        resized_imgs = []
+        for img in imgs:
+            if len(img.shape) > 2 and img.shape[2] == 1:
+                pil_image = Image.fromarray(img[:, :, 0], mode="L")
+            else:
+                pil_image = Image.fromarray(img)
+            pil_image = pil_image.resize((self.new_w, self.new_h), interp_method)
+            ret = np.asarray(pil_image)
+            if len(img.shape) > 2 and img.shape[2] == 1:
+                ret = np.expand_dims(ret, -1)
+            resized_imgs.append(ret)
+        resized_imgs = np.stack(resized_imgs)
+        return resized_imgs
+
+    def apply_coords(self, coords):
+        coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w)
+        coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h)
+        return coords
+    
+    def apply_box(self, boxes):
+        boxes = boxes[0]
+        new_boxes = super(BatchResizeTransform, self).apply_box(boxes[:,:4])
+        boxes[...,:4] = new_boxes
+        return boxes[None]
+
+    def apply_segmentation(self, segmentation):
+        if len(segmentation.shape)==3:
+            segmentation = segmentation[..., None]
+            segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
+            segmentation = segmentation[..., 0]
+        else:
+            segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
+        return segmentation
+
+class EntityCropTransform(Transform):
+    """
+    Consecutively crop the images
+    """
+    def __init__(self, crop_axises, crop_indexes):
+        super().__init__()
+        self._set_attributes(locals())
+    
+    def apply_image(self, img):
+        """
+        Args:
+            img (ndarray): of shape NxHxWxC, or HxWxC or HxW. The array can be
+                of type uint8 in range [0, 255], or floating point in range
+                [0, 1] or [0, 255]
+        Returns:
+            ndarray: cropped images
+        """
+        dim_num = len(img.shape)
+        imgs = []
+        
+        for crop_axis in self.crop_axises:
+            x0, y0, x1, y1 = crop_axis
+            if dim_num <= 3:
+                crop_img = img[y0:y1, x0:x1]
+            else:
+                crop_img = img[..., y0:y1, x0:x1, :]
+            imgs.append(crop_img)
+
+        if dim_num <= 3:
+            imgs = np.stack(imgs, axis=0)
+        else:
+            imgs = np.concatenate(imgs, axis=0)
+        return imgs
+    
+    def apply_coords(self, coords: np.ndarray, x0, y0):
+        coords[:, 0] -= x0
+        coords[:, 1] -= y0
+        return coords
+    
+    def apply_box(self, box: np.ndarray) -> np.ndarray:
+        """
+        box: Nx4, [x0, y0, x1, y1]
+        """
+        idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
+        coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
+        split_boxes = []
+        crop_ws, crop_hs = [], []
+        for crop_axis in self.crop_axises:
+            startw, starth, endw, endh = crop_axis
+            coords_new = self.apply_coords(copy.deepcopy(coords), startw, starth).reshape((-1, 4, 2))
+            minxy = coords_new.min(axis=1)
+            maxxy = coords_new.max(axis=1)
+            trans_boxes = np.concatenate((minxy, maxxy), axis=1)
+            
+            crop_ws.append(endw-startw)
+            crop_hs.append(endh-starth)
+            split_boxes.append(trans_boxes)
+        split_boxes = np.stack(split_boxes, axis=1)
+        ### clip to the image boundary
+        ## assert each crop size is equal
+        for crop_index, (crop_w, crop_h) in enumerate(zip(crop_ws, crop_hs)):
+            assert crop_w == crop_ws[0], "crop widths are not equal, crop_{}: {}, crop_0: {}".format(crop_index, crop_w, crop_ws[0])
+            assert crop_h == crop_hs[0], "crop heights are not equal, crop_{}: {}, crop_0: {}".format(crop_index, crop_h, crop_hs[0])
+        crop_w = crop_ws[0]
+        crop_h = crop_hs[0]
+        # pdb.set_trace()
+        split_boxes[...,0::2] = np.clip(split_boxes[...,0::2], 0, crop_w)
+        split_boxes[...,1::2] = np.clip(split_boxes[...,1::2], 0, crop_h)
+        valid_inds = (split_boxes[...,2]>split_boxes[...,0]) & (split_boxes[...,3]>split_boxes[...,1])
+        split_infos = np.concatenate((split_boxes, valid_inds[...,None]), axis=-1)
+        return split_infos
+
+class BatchResizeShortestEdge(Augmentation):
+    """
+    Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
+    If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
+    """
+
+    def __init__(
+        self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
+    ):
+        """
+        Args:
+            short_edge_length (list[int]): If ``sample_style=="range"``,
+                a [min, max] interval from which to sample the shortest edge length.
+                If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
+            max_size (int): maximum allowed longest edge length.
+            sample_style (str): either "range" or "choice".
+        """
+        super().__init__()
+        assert sample_style in ["range", "choice"], sample_style
+
+        self.is_range = sample_style == "range"
+        if isinstance(short_edge_length, int):
+            short_edge_length = (short_edge_length, short_edge_length)
+        if self.is_range:
+            assert len(short_edge_length) == 2, (
+                "short_edge_length must be two values using 'range' sample style."
+                f" Got {short_edge_length}!"
+            )
+        self._init(locals())
+
+    def get_transform(self, image):
+        dim_num = len(image.shape)
+        assert dim_num == 4, "the tensor should be in [B, H, W, C]"
+        h, w = image.shape[1:3]
+        if self.is_range:
+            size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1)
+        else:
+            size = np.random.choice(self.short_edge_length)
+        if size == 0:
+            return NoOpTransform()
+
+        scale = size * 1.0 / min(h, w)
+        if h < w:
+            newh, neww = size, scale * w
+        else:
+            newh, neww = scale * h, size
+        if max(newh, neww) > self.max_size:
+            scale = self.max_size * 1.0 / max(newh, neww)
+            newh = newh * scale
+            neww = neww * scale
+        neww = int(neww + 0.5)
+        newh = int(newh + 0.5)
+        return BatchResizeTransform(h, w, newh, neww, self.interp)
+
+class EntityCrop(Augmentation):
+    def __init__(self, crop_ratio, stride_ratio, sample_num, is_train):
+        super().__init__()
+        self._init(locals())
+
+    def get_transform(self, image):
+        h, w = image.shape[:2]
+        crop_axises, crop_indexes = self.get_crop_axises((h, w))
+        transform = EntityCropTransform(crop_axises, crop_indexes)
+        return transform
+    
+    def get_crop_axises(self, image_size):
+        h, w = image_size
+        crop_w = int(self.crop_ratio*w)
+        crop_h = int(self.crop_ratio*h)
+        # if self.is_train:
+        stride_w = int(self.stride_ratio*w)
+        stride_h = int(self.stride_ratio*h)
+        # pdb.set_trace()
+
+        crop_axises  = []
+        for starth in range(0, h, stride_h):
+            for startw in range(0, w, stride_w):
+                endh = min(starth+crop_h, h)
+                endw = min(startw+crop_w, w)
+                starth = int(endh-crop_h)
+                startw = int(endw-crop_w)
+                crop_axises.append([startw, starth, endw, endh])
+        if self.is_train:
+            crop_indexes = random.sample([i for i in range(len(crop_axises))], self.sample_num)
+            crop_axises = [crop_axises[i] for i in crop_indexes]
+        else:
+            crop_indexes = [i for i in range(self.sample_num)]
+        # left_upper   = [0, 0, crop_w, crop_h]
+        # right_upper  = [w-crop_w, 0, w, crop_h]
+        # left_bottom  = [0, h-crop_h, crop_w, h]
+        # right_bottom = [w-crop_w, h-crop_h, w, h]
+        
+        # crop_axises = [left_upper, right_upper, left_bottom, right_bottom]
+        # crop_indexes = [0,1,2,3]
+        assert len(crop_axises)==len(crop_indexes)
+        return crop_axises, crop_indexes
+
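+# ---------------------------------------------------------------------------
+# Illustrative sketch: a stand-alone reimplementation of the sliding-window
+# grid produced by EntityCrop.get_crop_axises, assuming crop_ratio=0.5 and
+# stride_ratio=0.5. Windows that would run past the border are shifted back so
+# every crop has the same size, which is why the start coordinates are derived
+# from the clamped end coordinates. The helper name is hypothetical.
+def _sketch_crop_grid(h=100, w=200, crop_ratio=0.5, stride_ratio=0.5):
+    crop_h, crop_w = int(crop_ratio * h), int(crop_ratio * w)
+    stride_h, stride_w = int(stride_ratio * h), int(stride_ratio * w)
+    boxes = []
+    for y in range(0, h, stride_h):
+        for x in range(0, w, stride_w):
+            y1, x1 = min(y + crop_h, h), min(x + crop_w, w)
+            boxes.append([x1 - crop_w, y1 - crop_h, x1, y1])   # [x0, y0, x1, y1]
+    # e.g. [[0, 0, 100, 50], [100, 0, 200, 50], [0, 50, 100, 100], [100, 50, 200, 100]]
+    return boxes
+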
+def transform_instance_annotations_crop(
+    annotation, transforms, image_size, *, keypoint_hflip_indices=None
+):
+    """
+    Apply transforms to box, segmentation and keypoints annotations of a single instance.
+
+    It will use `transforms.apply_box` for the box, and
+    `transforms.apply_coords` for segmentation polygons & keypoints.
+    If you need anything more specially designed for each data structure,
+    you'll need to implement your own version of this function or the transforms.
+
+    Args:
+        annotation (dict): dict of instance annotations for a single instance.
+            It will be modified in-place.
+        transforms (TransformList or list[Transform]):
+        image_size (tuple): the height, width of the transformed image
+        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
+
+    Returns:
+        dict:
+            the same input dict with fields "bbox", "segmentation", "keypoints"
+            transformed according to `transforms`.
+            The "bbox_mode" field will be set to XYXY_ABS.
+    """
+    if isinstance(transforms, (tuple, list)):
+        transforms = T.TransformList(transforms)
+    # bbox is 1d (per-instance bounding box)
+    bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
+    
+    # clip transformed bbox to image size
+    bboxes_info = transforms.apply_box(np.array([bbox]))[0].clip(min=0)
+    annotation["bbox"] = bboxes_info[..., :4]
+    annotation["bbox_mode"] = BoxMode.XYXY_ABS
+    annotation["bbox_valid"] = bboxes_info[...,4]
+    for transform_type in transforms:
+        if isinstance(transform_type, EntityCropTransform):
+            annotation["crop_axises"] = transform_type.crop_axises
+            annotation["crop_indexes"] = transform_type.crop_indexes
+
+    if "segmentation" in annotation:
+        segm = annotation["segmentation"]
+        assert isinstance(segm, dict), "segmentation must be RLE-encoded (a dict)"
+        # RLE
+        mask = mask_util.decode(segm)
+        mask = transforms.apply_segmentation(mask)
+        annotation["segmentation"] = mask
+    return annotation
+
+def annotations_to_instances_crop(annos, image_size, mask_format="polygon", return_indexes=False):
+    """
+    Create an :class:`Instances` object used by the models,
+    from instance annotations in the dataset dict.
+
+    Args:
+        annos (list[dict]): a list of instance annotations in one image, each
+            element for one instance.
+        image_size (tuple): height, width
+
+    Returns:
+        Instances:
+            It will contain fields "gt_boxes", "gt_classes",
+            "gt_masks", "gt_keypoints", if they can be obtained from `annos`.
+            This is the format that builtin models expect.
+    """
+    ###
+    all_boxes = []
+    all_boxes_valid = []
+    all_classes = []
+    all_segmentations = []
+    all_iscrowds = []
+    # pdb.set_trace()
+    annos_num = len(annos)
+    patches_num = len(annos[0]["bbox"])
+    for ann_index, obj in enumerate(annos):
+        for split_index in range(len(obj["bbox"])):
+            all_boxes.append(BoxMode.convert(obj["bbox"][split_index], obj["bbox_mode"], BoxMode.XYXY_ABS))
+            all_boxes_valid.append(obj["bbox_valid"][split_index])
+            all_classes.append(obj["category_id"])
+            all_segmentations.append(obj["segmentation"][split_index])
+            all_iscrowds.append(obj["iscrowd"])
+            # print("ann_index:{}, split_index:{}".format(ann_index, split_index))
+    
+    new_targets = []
+    crop_axises = annos[0]["crop_axises"]
+    # pdb.set_trace()
+    crop_size = (crop_axises[0][3], crop_axises[0][2])
+    crop_axises = torch.tensor(crop_axises)
+    
+    for split_index in range(patches_num):
+        new_targets.append(Instances(crop_size))
+        # pdb.set_trace()
+        ## boxes
+        new_targets[-1].gt_boxes = Boxes(all_boxes[split_index::patches_num])
+        new_targets[-1].gt_boxes_valid = torch.tensor(all_boxes_valid[split_index::patches_num], dtype=torch.int64)
+        ## categories
+        new_targets[-1].gt_classes = torch.tensor(all_classes[split_index::patches_num], dtype=torch.int64)
+
+        ## masks
+        if "segmentation" in annos[0]:
+            new_targets[-1].gt_masks = BitMasks(torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in all_segmentations[split_index::patches_num]]))
+        
+    # pdb.set_trace()
+    if return_indexes:
+        return new_targets, crop_axises, annos[0]["crop_indexes"]
+    else:
+        return new_targets, crop_axises
+
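+# ---------------------------------------------------------------------------
+# Illustrative sketch: annotations_to_instances_crop flattens the annotations
+# instance-by-instance, crop-by-crop, so that `flat[k::patches_num]` recovers
+# everything belonging to crop k. A toy example with 2 instances and 3 crops
+# (the helper and its labels are made up for illustration only):
+def _sketch_regroup(patches_num=3):
+    flat = [(inst, crop) for inst in ("a", "b") for crop in range(patches_num)]
+    # crop 0 -> [("a", 0), ("b", 0)], crop 1 -> [("a", 1), ("b", 1)], ...
+    return [flat[k::patches_num] for k in range(patches_num)]
+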
+class EntityCascadedCrop(Augmentation):
+    def __init__(self, crop_ratio, stride_ratio, sample_num, cascade_num, is_train):
+        super().__init__()
+        self._init(locals())
+
+    def get_transform(self, image):
+        h, w = image.shape[:2]
+        crop_axises, crop_indexes = self.get_crop_axises((h, w))
+        transform = EntityCropTransform(crop_axises, crop_indexes)
+        return transform
+    
+    def get_crop_axises(self, image_size):
+        h, w = image_size
+        # for i in range(self.cascade_num):
+        #     crop_w = int((self.crop_ratio**(i+1))*w)
+        #     crop_h = int((self.crop_ratio**(i+1))*h)
+        #     stride_w = int((self.stride_ratio**(i+1))*w)
+        #     stride_h = int((self.stride_ratio**(i+1))*h)
+        #     crop_axises = []
+        #     if i==0:
+
+        #     for starth in range(0, )
+
+
+        # NOTE: the cascaded grid sketched in the comment above is unfinished;
+        # fall back to the same single-level grid as EntityCrop so the loop
+        # below has crop/stride sizes defined (otherwise it raises a NameError).
+        crop_w = int(self.crop_ratio*w)
+        crop_h = int(self.crop_ratio*h)
+        stride_w = int(self.stride_ratio*w)
+        stride_h = int(self.stride_ratio*h)
+
+        crop_axises  = []
+        for starth in range(0, h, stride_h):
+            for startw in range(0, w, stride_w):
+                endh = min(starth+crop_h, h)
+                endw = min(startw+crop_w, w)
+                starth = int(endh-crop_h)
+                startw = int(endw-crop_w)
+                crop_axises.append([startw, starth, endw, endh])
+        if self.is_train:
+            crop_indexes = random.sample([i for i in range(len(crop_axises))], self.sample_num)
+            crop_axises = [crop_axises[i] for i in crop_indexes]
+        else:
+            crop_indexes = [i for i in range(self.sample_num)]
+        # left_upper   = [0, 0, crop_w, crop_h]
+        # right_upper  = [w-crop_w, 0, w, crop_h]
+        # left_bottom  = [0, h-crop_h, crop_w, h]
+        # right_bottom = [w-crop_w, h-crop_h, w, h]
+        
+        # crop_axises = [left_upper, right_upper, left_bottom, right_bottom]
+        # crop_indexes = [0,1,2,3]
+        assert len(crop_axises)==len(crop_indexes)
+        return crop_axises, crop_indexes
\ No newline at end of file
diff --git a/annotator/entityseg/mask2former/maskformer_model.py b/annotator/entityseg/mask2former/maskformer_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ebfc932ab31e959e338b9f69b34a3bce027980f
--- /dev/null
+++ b/annotator/entityseg/mask2former/maskformer_model.py
@@ -0,0 +1,446 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.data import MetadataCatalog
+from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
+from detectron2.modeling.backbone import Backbone
+from detectron2.modeling.postprocessing import sem_seg_postprocess
+from detectron2.structures import Boxes, ImageList, Instances, BitMasks
+from detectron2.utils.memory import retry_if_cuda_oom
+
+from .modeling.criterion import SetCriterion
+from .modeling.matcher import HungarianMatcher
+
+
+@META_ARCH_REGISTRY.register()
+class MaskFormer(nn.Module):
+    """
+    Main class for mask classification semantic segmentation architectures.
+    """
+
+    @configurable
+    def __init__(
+        self,
+        *,
+        cfg,
+        backbone: Backbone,
+        sem_seg_head: nn.Module,
+        criterion: nn.Module,
+        num_queries: int,
+        object_mask_threshold: float,
+        overlap_threshold: float,
+        metadata,
+        size_divisibility: int,
+        sem_seg_postprocess_before_inference: bool,
+        pixel_mean: Tuple[float],
+        pixel_std: Tuple[float],
+        # inference
+        semantic_on: bool,
+        panoptic_on: bool,
+        instance_on: bool,
+        test_topk_per_image: int,
+    ):
+        """
+        Args:
+            backbone: a backbone module, must follow detectron2's backbone interface
+            sem_seg_head: a module that predicts semantic segmentation from backbone features
+            criterion: a module that defines the loss
+            num_queries: int, number of queries
+            object_mask_threshold: float, threshold to filter query based on classification score
+                for panoptic segmentation inference
+            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
+            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
+                segmentation inference
+            size_divisibility: Some backbones require the input height and width to be divisible by a
+                specific integer. We can use this to override such a requirement.
+            sem_seg_postprocess_before_inference: whether to resize the prediction back
+                to original input size before semantic segmentation inference or after.
+                For high-resolution dataset like Mapillary, resizing predictions before
+                inference will cause OOM error.
+            pixel_mean, pixel_std: list or tuple with #channels element, representing
+                the per-channel mean and std to be used to normalize the input image
+            semantic_on: bool, whether to output semantic segmentation prediction
+            instance_on: bool, whether to output instance segmentation prediction
+            panoptic_on: bool, whether to output panoptic segmentation prediction
+            test_topk_per_image: int, instance segmentation parameter, keep topk instances per image
+        """
+        super().__init__()
+        self.cfg = cfg
+        self.backbone = backbone
+        self.sem_seg_head = sem_seg_head
+        self.criterion = criterion
+        self.num_queries = num_queries
+        self.overlap_threshold = overlap_threshold
+        self.entity_enable = self.cfg.ENTITY.ENABLE
+        self.object_mask_threshold = object_mask_threshold
+        self.metadata = metadata
+        if size_divisibility < 0:
+            # use backbone size_divisibility if not set
+            size_divisibility = self.backbone.size_divisibility
+        self.size_divisibility = size_divisibility
+        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+        # additional args
+        self.semantic_on = semantic_on
+        self.instance_on = instance_on
+        self.panoptic_on = panoptic_on
+        self.test_topk_per_image = test_topk_per_image
+
+        if not self.semantic_on:
+            assert self.sem_seg_postprocess_before_inference
+
+    @classmethod
+    def from_config(cls, cfg):
+        backbone = build_backbone(cfg)
+        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
+
+        # Loss parameters:
+        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
+        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
+
+        # loss weights
+        class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
+        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
+        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
+
+        # building criterion
+        matcher = HungarianMatcher(
+            cost_class=class_weight,
+            cost_mask=mask_weight,
+            cost_dice=dice_weight,
+            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
+        )
+
+        weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
+
+        if deep_supervision:
+            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
+            aux_weight_dict = {}
+            for i in range(dec_layers - 1):
+                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
+            weight_dict.update(aux_weight_dict)
+
+        losses = ["labels", "masks"]
+
+        criterion = SetCriterion(
+            sem_seg_head.num_classes,
+            matcher=matcher,
+            weight_dict=weight_dict,
+            eos_coef=no_object_weight,
+            losses=losses,
+            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
+            oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
+            importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
+        )
+
+        return {
+            "cfg": cfg,
+            "backbone": backbone,
+            "sem_seg_head": sem_seg_head,
+            "criterion": criterion,
+            "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
+            "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
+            "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
+            "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
+            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
+            "sem_seg_postprocess_before_inference": (
+                cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
+                or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
+                or cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON
+            ),
+            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
+            "pixel_std": cfg.MODEL.PIXEL_STD,
+            # inference
+            "semantic_on": cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON,
+            "instance_on": cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON,
+            "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
+            "test_topk_per_image": cfg.TEST.DETECTIONS_PER_IMAGE,
+        }
+
+    @property
+    def device(self):
+        return self.pixel_mean.device
+
+    def forward(self, batched_inputs):
+        """
+        Args:
+            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
+                Each item in the list contains the inputs for one image.
+                For now, each item in the list is a dict that contains:
+                   * "image": Tensor, image in (C, H, W) format.
+                   * "instances": per-region ground truth
+                   * Other information that's included in the original dicts, such as:
+                     "height", "width" (int): the output resolution of the model (may be different
+                     from input resolution), used in inference.
+        Returns:
+            list[dict]:
+                each dict has the results for one image. The dict contains the following keys:
+
+                * "sem_seg":
+                    A Tensor that represents the
+                    per-pixel segmentation predicted by the head.
+                    The prediction has shape KxHxW and represents the logits of
+                    each class for each pixel.
+                * "panoptic_seg":
+                    A tuple that represents the panoptic output.
+                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
+                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
+                        Each dict contains keys "id", "category_id", "isthing".
+        """
+        images = [x["image"].to(self.device) for x in batched_inputs]
+        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+        images = ImageList.from_tensors(images, self.size_divisibility)
+
+        features = self.backbone(images.tensor)
+        outputs = self.sem_seg_head(features)
+
+        if self.training:
+            # mask classification target
+            if "instances" in batched_inputs[0]:
+                if self.cfg.ENTITY.ENABLE:
+                    for i in range(len(batched_inputs)):
+                        batched_inputs[i]["instances"].gt_classes[:] = 0
+                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
+                targets = self.prepare_targets(gt_instances, images)
+            else:
+                targets = None
+
+            # bipartite matching-based loss
+            losses = self.criterion(outputs, targets)
+
+            for k in list(losses.keys()):
+                if k in self.criterion.weight_dict:
+                    losses[k] *= self.criterion.weight_dict[k]
+                else:
+                    # remove this loss if not specified in `weight_dict`
+                    losses.pop(k)
+            return losses
+        else:
+            mask_cls_results = outputs["pred_logits"]
+            mask_pred_results = outputs["pred_masks"]
+            # upsample masks
+            mask_pred_results = F.interpolate(
+                mask_pred_results,
+                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+            del outputs
+
+            processed_results = []
+            for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
+                mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
+            ):
+                height = input_per_image.get("height", image_size[0])
+                width = input_per_image.get("width", image_size[1])
+                processed_results.append({})
+
+                if self.sem_seg_postprocess_before_inference:
+                    mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
+                        mask_pred_result, image_size, height, width
+                    )
+                    mask_cls_result = mask_cls_result.to(mask_pred_result)
+
+                # semantic segmentation inference
+                if self.semantic_on:
+                    r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result)
+                    if not self.sem_seg_postprocess_before_inference:
+                        r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width)
+                    processed_results[-1]["sem_seg"] = r
+
+                # panoptic segmentation inference
+                if self.panoptic_on:
+                    panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result)
+                    processed_results[-1]["panoptic_seg"] = panoptic_r
+                
+                # instance segmentation and entity segmentation inference
+                if self.instance_on and self.cfg.ENTITY.ENABLE:
+                    instance_r = retry_if_cuda_oom(self.instance_inference_nonoverlap)(mask_cls_result, mask_pred_result)
+                    processed_results[-1]["instances"] = instance_r
+                else:
+                    instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result)
+                    processed_results[-1]["instances"] = instance_r
+
+            return processed_results
+
+    def prepare_targets(self, targets, images):
+        h_pad, w_pad = images.tensor.shape[-2:]
+        new_targets = []
+        for targets_per_image in targets:
+            # pad gt
+            gt_masks = targets_per_image.gt_masks
+            padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
+            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
+            new_targets.append(
+                {
+                    "labels": targets_per_image.gt_classes,
+                    "masks": padded_masks,
+                }
+            )
+        return new_targets
+
+    def semantic_inference(self, mask_cls, mask_pred):
+        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
+        mask_pred = mask_pred.sigmoid()
+        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
+        return semseg
+
+    def panoptic_inference(self, mask_cls, mask_pred):
+        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
+        mask_pred = mask_pred.sigmoid()
+
+        keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold)
+        cur_scores = scores[keep]
+        cur_classes = labels[keep]
+        cur_masks = mask_pred[keep]
+        cur_mask_cls = mask_cls[keep]
+        cur_mask_cls = cur_mask_cls[:, :-1]
+
+        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
+
+        h, w = cur_masks.shape[-2:]
+        panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
+        segments_info = []
+
+        current_segment_id = 0
+
+        if cur_masks.shape[0] == 0:
+            # We didn't detect any mask :(
+            return panoptic_seg, segments_info
+        else:
+            # take argmax
+            cur_mask_ids = cur_prob_masks.argmax(0)
+            stuff_memory_list = {}
+            for k in range(cur_classes.shape[0]):
+                pred_class = cur_classes[k].item()
+                isthing = pred_class in self.metadata.thing_dataset_id_to_contiguous_id.values()
+                mask_area = (cur_mask_ids == k).sum().item()
+                original_area = (cur_masks[k] >= 0.5).sum().item()
+                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)
+
+                if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
+                    if mask_area / original_area < self.overlap_threshold:
+                        continue
+
+                    # merge stuff regions
+                    if not isthing:
+                        if int(pred_class) in stuff_memory_list.keys():
+                            panoptic_seg[mask] = stuff_memory_list[int(pred_class)]
+                            continue
+                        else:
+                            stuff_memory_list[int(pred_class)] = current_segment_id + 1
+
+                    current_segment_id += 1
+                    panoptic_seg[mask] = current_segment_id
+
+                    segments_info.append(
+                        {
+                            "id": current_segment_id,
+                            "isthing": bool(isthing),
+                            "category_id": int(pred_class),
+                        }
+                    )
+
+            return panoptic_seg, segments_info
+
+    def instance_inference(self, mask_cls, mask_pred):
+        # mask_pred is already processed to have the same shape as original input
+        image_size = mask_pred.shape[-2:]
+
+        # [Q, K]
+        scores = F.softmax(mask_cls, dim=-1)[:, :-1]
+        labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
+        # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
+        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
+        labels_per_image = labels[topk_indices]
+
+        # topk_indices = topk_indices // self.sem_seg_head.num_classes
+        topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode='trunc')
+        # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
+        mask_pred = mask_pred[topk_indices]
+
+        # if this is panoptic segmentation, we only keep the "thing" classes
+        if self.panoptic_on:
+            keep = torch.zeros_like(scores_per_image).bool()
+            for i, lab in enumerate(labels_per_image):
+                keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()
+
+            scores_per_image = scores_per_image[keep]
+            labels_per_image = labels_per_image[keep]
+            mask_pred = mask_pred[keep]
+
+        result = Instances(image_size)
+        # mask (before sigmoid)
+        result.pred_masks = (mask_pred > 0).float()
+        result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
+        # Uncomment the following to get boxes from masks (this is slow)
+        # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()
+
+        # calculate average mask prob
+        mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
+        result.scores = scores_per_image * mask_scores_per_image
+        result.pred_classes = labels_per_image
+        return result
+    
+    def instance_inference_nonoverlap(self, mask_cls, mask_pred):
+        # mask_pred is already processed to have the same shape as original input
+        image_size = mask_pred.shape[-2:]
+
+        # [Q, K]
+        scores = F.softmax(mask_cls, dim=-1)[:, :-1]
+        labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
+        # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
+        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
+        labels_per_image = labels[topk_indices]
+
+        # topk_indices = topk_indices // self.sem_seg_head.num_classes
+        topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode='trunc')
+        # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
+        mask_pred = mask_pred[topk_indices]
+
+        ###### ranks
+        pred_masks = (mask_pred>0).float()
+        pred_masks_logits = mask_pred.sigmoid()
+        pred_scores = scores_per_image
+
+        _, m_H, m_W = pred_masks.shape
+        mask_id = torch.zeros((m_H, m_W), dtype=torch.int).to(pred_masks.device)
+        sorted_scores, ranks = torch.sort(pred_scores)
+        ranks = ranks + 1
+        for index in ranks:
+            mask_id[(pred_masks[index-1]==1)] = int(index)
+        # re-generate mask
+        new_scores = []
+        new_masks  = []
+        new_masks_logits = []
+        entity_nums = len(ranks)
+        for ii in range(entity_nums):
+            index = int(ranks[entity_nums-ii-1])
+            score = sorted_scores[entity_nums-ii-1]
+            new_scores.append(score)
+            new_masks.append((mask_id==index).float())
+            new_masks_logits.append(pred_masks_logits[index-1])
+        
+        new_scores = torch.stack(new_scores)
+        new_masks  = torch.stack(new_masks)
+        new_masks_logits = torch.stack(new_masks_logits)
+
+        result = Instances(image_size)
+        # mask (before sigmoid)
+        result.pred_masks = new_masks
+        result.pred_boxes = Boxes(torch.zeros(new_masks.size(0), 4))
+        # Uncomment the following to get boxes from masks (this is slow)
+
+        # calculate average mask prob
+        mask_scores_per_image = (new_masks_logits.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
+        result.scores = new_scores * mask_scores_per_image
+        result.pred_classes = labels_per_image
+        return result
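+
+# ---------------------------------------------------------------------------
+# Illustrative sketch (not used by the model): instance_inference_nonoverlap
+# resolves overlaps by painting masks into `mask_id` in ascending score order,
+# so the highest-scoring query wins every contested pixel, and then one
+# disjoint mask is rebuilt per query. Toy tensors, hypothetical names.
+if __name__ == "__main__":
+    masks = (torch.rand(3, 6, 6) > 0.5).float()     # 3 overlapping binary masks
+    scores = torch.tensor([0.2, 0.9, 0.5])
+    mask_id = torch.zeros(6, 6, dtype=torch.int)
+    sorted_scores, ranks = torch.sort(scores)       # low scores are painted first
+    for index in ranks + 1:
+        mask_id[masks[index - 1] == 1] = int(index)
+    disjoint = torch.stack([(mask_id == int(r) + 1).float() for r in ranks])
+    print(disjoint.sum(0).max())                    # <= 1: the masks no longer overlap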
diff --git a/annotator/entityseg/mask2former/modeling/__init__.py b/annotator/entityseg/mask2former/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f41eb32467aa8a6cf0209396938547b306db8ef5
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .backbone.swin import D2SwinTransformer
+from .backbone.hornet import D2HorNet
+from .pixel_decoder.fpn import BasePixelDecoder
+from .pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder
+from .meta_arch.mask_former_head import MaskFormerHead
+from .meta_arch.per_pixel_baseline import PerPixelBaselineHead, PerPixelBaselinePlusHead
diff --git a/annotator/entityseg/mask2former/modeling/backbone/__init__.py b/annotator/entityseg/mask2former/modeling/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/backbone/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/annotator/entityseg/mask2former/modeling/backbone/hornet.py b/annotator/entityseg/mask2former/modeling/backbone/hornet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7762de87c457795d83a2ac020bc862b8c86d87c7
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/backbone/hornet.py
@@ -0,0 +1,363 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.registry import register_model
+import os
+import sys
+import torch.fft
+import math
+
+import traceback
+import torch.utils.checkpoint as checkpoint
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+
+if 'DWCONV_IMPL' in os.environ:
+    try:
+        sys.path.append(os.environ['DWCONV_IMPL'])
+        from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
+        def get_dwconv(dim, kernel, bias):
+            return DepthWiseConv2dImplicitGEMM(dim, kernel, bias)
+        print('Using Megvii large kernel dw conv impl')
+    except Exception:
+        print(traceback.format_exc())
+        def get_dwconv(dim, kernel, bias):
+            return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2, bias=bias, groups=dim)
+
+        print('[failed to use Megvii large kernel] Using PyTorch large kernel dw conv impl')
+else:
+    def get_dwconv(dim, kernel, bias):
+        return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2, bias=bias, groups=dim)
+
+    print('Using PyTorch large kernel dw conv impl')
+
+class GlobalLocalFilter(nn.Module):
+    def __init__(self, dim, h=14, w=8):
+        super().__init__()
+        self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, padding=1, bias=False, groups=dim // 2)
+        self.complex_weight = nn.Parameter(torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
+        trunc_normal_(self.complex_weight, std=.02)
+        self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
+        self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
+
+    def forward(self, x):
+        x = self.pre_norm(x)
+        x1, x2 = torch.chunk(x, 2, dim=1)
+        x1 = self.dw(x1)
+
+        x2 = x2.to(torch.float32)
+        B, C, a, b = x2.shape
+        x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')
+
+        weight = self.complex_weight
+        if not weight.shape[1:3] == x2.shape[2:4]:
+            weight = F.interpolate(weight.permute(3,0,1,2), size=x2.shape[2:4], mode='bilinear', align_corners=True).permute(1,2,3,0)
+
+        weight = torch.view_as_complex(weight.contiguous())
+
+        x2 = x2 * weight
+        x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')
+
+        x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(B, 2 * C, a, b)
+        x = self.post_norm(x)
+        return x
+
+
+class gnconv(nn.Module):
+    def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
+        super().__init__()
+        self.order = order
+        self.dims = [dim // 2 ** i for i in range(order)]
+        self.dims.reverse()
+        self.proj_in = nn.Conv2d(dim, 2*dim, 1)
+
+        if gflayer is None:
+            self.dwconv = get_dwconv(sum(self.dims), 7, True)
+        else:
+            self.dwconv = gflayer(sum(self.dims), h=h, w=w)
+        
+        self.proj_out = nn.Conv2d(dim, dim, 1)
+
+        self.pws = nn.ModuleList(
+            [nn.Conv2d(self.dims[i], self.dims[i+1], 1) for i in range(order-1)]
+        )
+
+        self.scale = s
+
+        print('[gconv]', order, 'order with dims=', self.dims, 'scale=%.4f'%self.scale)
+
+
+    def forward(self, x, mask=None, dummy=False):
+        B, C, H, W = x.shape
+
+        fused_x = self.proj_in(x)
+        pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
+
+        dw_abc = self.dwconv(abc) * self.scale
+
+        dw_list = torch.split(dw_abc, self.dims, dim=1)
+        x = pwa * dw_list[0]
+
+        for i in range(self.order -1):
+            x = self.pws[i](x) * dw_list[i+1]
+
+        x = self.proj_out(x)
+
+        return x
+
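+# ---------------------------------------------------------------------------
+# Illustrative sketch: the channel bookkeeping above, restated for dim=64 and
+# order=5. proj_in emits 2*dim channels which are split into (dims[0], sum(dims));
+# pwa gates the first depthwise chunk and each pws[i] 1x1 conv widens the running
+# product to the next width. The helper below is hypothetical and only checks the
+# split; the equality is exact whenever dim is divisible by 2 ** (order - 1).
+def _sketch_gnconv_channel_split(dim=64, order=5):
+    dims = [dim // 2 ** i for i in range(order)]
+    dims.reverse()                                  # [4, 8, 16, 32, 64] for the defaults
+    assert dims[0] + sum(dims) == 2 * dim           # matches proj_in's output channels
+    return dims
+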
+class Block(nn.Module):
+    r""" HorNet block
+    """
+    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, gnconv=gnconv):
+        super().__init__()
+
+        self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')
+        self.gnconv = gnconv(dim)  # recursive gated convolution (gnconv)
+        self.norm2 = LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(4 * dim, dim)
+
+        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), 
+                                    requires_grad=True) if layer_scale_init_value > 0 else None
+
+        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
+                                    requires_grad=True) if layer_scale_init_value > 0 else None
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        B, C, H, W  = x.shape
+        if self.gamma1 is not None:
+            gamma1 = self.gamma1.view(C, 1, 1)
+        else:
+            gamma1 = 1
+        x = x + self.drop_path(gamma1 * self.gnconv(self.norm1(x)))
+
+        input = x
+        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+        x = self.norm2(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.gamma2 is not None:
+            x = self.gamma2 * x
+        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
+
+
+class HorNet(nn.Module):
+    r""" HorNet
+        A PyTorch impl of: `HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions`
+
+    Args:
+        in_chans (int): Number of input image channels. Default: 3
+        num_classes (int): Number of classes for classification head. Default: 1000
+        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+        base_dim (int): Base feature dimension; stage dims are [base_dim, 2*base_dim, 4*base_dim, 8*base_dim]. Default: 96
+        drop_path_rate (float): Stochastic depth rate. Default: 0.
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+    """
+    def __init__(self, in_chans=3, num_classes=1000, 
+                 depths=[3, 3, 9, 3], base_dim=96, drop_path_rate=0.,
+                 layer_scale_init_value=1e-6, head_init_scale=1.,
+                 gnconv=gnconv, block=Block,
+                 pretrained=None,
+                 use_checkpoint=False,
+                 ):
+        super().__init__()
+
+        self.pretrained = pretrained
+        self.use_checkpoint = use_checkpoint
+
+        dims = [base_dim, base_dim*2, base_dim*4, base_dim*8]
+
+        self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
+        stem = nn.Sequential(
+            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+        )
+        self.downsample_layers.append(stem)
+        for i in range(3):
+            downsample_layer = nn.Sequential(
+                    LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                    nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+            )
+            self.downsample_layers.append(downsample_layer)
+
+        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
+        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 
+
+
+        if not isinstance(gnconv, list):
+            gnconv = [gnconv, gnconv, gnconv, gnconv]
+        else:
+            assert len(gnconv) == 4
+
+        if isinstance(gnconv[0], str):
+            print('[GConvNet]: convert str gconv to func')
+            gnconv = [eval(g) for g in gnconv]
+
+        if isinstance(block, str):
+            block = eval(block)
+
+        cur = 0
+        num_features = []
+        for i in range(4):
+            stage = nn.Sequential(
+                *[block(dim=dims[i], drop_path=dp_rates[cur + j], 
+                layer_scale_init_value=layer_scale_init_value, gnconv=gnconv[i]) for j in range(depths[i])]
+            )
+            self.stages.append(stage)
+            cur += depths[i]
+            num_features.append(dims[i])
+        self.num_features = num_features
+
+        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
+        for i_layer in range(4):
+            layer = norm_layer(dims[i_layer])
+            layer_name = f'norm{i_layer}'
+            self.add_module(layer_name, layer)
+
+    def init_weights(self):
+        """Initialize the weights in the backbone.
+
+        Pretrained weights, if any, are expected to be loaded externally
+        (e.g. by the detectron2 checkpointer), so only the default
+        truncated-normal / LayerNorm initialization is applied here.
+        """
+
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=.02)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+        self.apply(_init_weights)
+
+    def forward_features(self, x):
+        outs = dict()
+        for i in range(4):
+            x = self.downsample_layers[i](x)
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint_sequential(self.stages[i], len(self.stages[i]), x)
+            else:
+                x = self.stages[i](x)
+            norm_layer = getattr(self, f'norm{i}')
+            x_out = norm_layer(x)
+            outs["res%i"% (i+2)] = x_out
+        return outs #tuple(outs)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+
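+# A hedged usage sketch (comments only, nothing runs at import time): with the
+# default depths=[3, 3, 9, 3] and base_dim=96, the backbone maps an (N, 3, H, W)
+# image to four feature maps keyed "res2".."res5" with strides 4/8/16/32 and
+# channel counts 96/192/384/768, e.g.
+#
+#     model = HorNet(base_dim=96)   # gnconv/Block are defined earlier in this file
+#     feats = model(torch.randn(1, 3, 224, 224))
+#     # feats["res2"].shape == (1, 96, 56, 56) ... feats["res5"].shape == (1, 768, 7, 7)
+
+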
+class LayerNorm(nn.Module):
+    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
+    with shape (batch_size, channels, height, width).
+    """
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError 
+        self.normalized_shape = (normalized_shape, )
+    
+    def forward(self, x):
+        if self.data_format == "channels_last":
+            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        elif self.data_format == "channels_first":
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = self.weight[:, None, None] * x + self.bias[:, None, None]
+            return x
+
+
+@BACKBONE_REGISTRY.register()
+class D2HorNet(HorNet, Backbone):
+    def __init__(self, cfg, input_shape):
+
+        depths=cfg.MODEL.HORNET.DEPTHS
+        base_dim=cfg.MODEL.HORNET.BASE_DIM
+        gnconv=cfg.MODEL.HORNET.GCONV
+        drop_path_rate=cfg.MODEL.HORNET.DROP_PATH_RATE
+
+        super().__init__(
+            depths=depths,
+            base_dim=base_dim,
+            gnconv=gnconv,
+            drop_path_rate=drop_path_rate,
+        )
+
+        self._out_features = cfg.MODEL.HORNET.OUT_FEATURES
+
+        self._out_feature_strides = {
+            "res2": 4,
+            "res3": 8,
+            "res4": 16,
+            "res5": 32,
+        }
+        self._out_feature_channels = {
+            "res2": self.num_features[0],
+            "res3": self.num_features[1],
+            "res4": self.num_features[2],
+            "res5": self.num_features[3],
+        }
+
+    def forward(self, x):
+        """
+        Args:
+            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+        Returns:
+            dict[str->Tensor]: names and the corresponding features
+        """
+        assert (
+            x.dim() == 4
+        ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+        outputs = {}
+        y = super().forward(x)
+        for k in y.keys():
+            if k in self._out_features:
+                outputs[k] = y[k]
+        return outputs
+
+    def output_shape(self):
+        return {
+            name: ShapeSpec(
+                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+            )
+            for name in self._out_features
+        }
+
+    @property
+    def size_divisibility(self):
+        return 32
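+
+
+# Hedged configuration sketch (comments only; the exact defaults live in this
+# repo's config files, so treat the values below as placeholders):
+#
+#     cfg.MODEL.BACKBONE.NAME = "D2HorNet"
+#     cfg.MODEL.HORNET.DEPTHS = [2, 3, 18, 2]
+#     cfg.MODEL.HORNET.BASE_DIM = 128
+#     cfg.MODEL.HORNET.GCONV = ["partial(gnconv, order=2)", ...]  # string entries are eval'ed in __init__
+#     cfg.MODEL.HORNET.DROP_PATH_RATE = 0.4
+#     cfg.MODEL.HORNET.OUT_FEATURES = ["res2", "res3", "res4", "res5"]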
\ No newline at end of file
diff --git a/annotator/entityseg/mask2former/modeling/backbone/swin.py b/annotator/entityseg/mask2former/modeling/backbone/swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b099d84396ac31d22881e5b6c9e53d2d0abaef3
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/backbone/swin.py
@@ -0,0 +1,770 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu, Yutong Lin, Yixuan Wei
+# --------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+
+class Mlp(nn.Module):
+    """Multilayer perceptron."""
+
+    def __init__(
+        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
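+
+# A small worked example (comments only) of the partition/reverse round trip:
+# for x of shape (1, 14, 14, 96) and window_size=7, window_partition returns
+# (4, 7, 7, 96) -- four non-overlapping 7x7 windows -- and
+# window_reverse(windows, 7, 14, 14) restores the original (1, 14, 14, 96) layout.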
+
+
+class WindowAttention(nn.Module):
+    """Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(
+        self,
+        dim,
+        window_size,
+        num_heads,
+        qkv_bias=True,
+        qk_scale=None,
+        attn_drop=0.0,
+        proj_drop=0.0,
+    ):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+        )  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=0.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """Forward function.
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+
+        relative_position_bias = self.relative_position_bias_table[
+            self.relative_position_index.view(-1)
+        ].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(
+            2, 0, 1
+        ).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
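+
+# Worked mini-example (comments only) for the relative position bias indexing:
+# with window_size=(2, 2) the bias table has (2*2-1)*(2*2-1) = 9 rows, one per
+# possible (dh, dw) offset, and relative_position_index is a 4x4 matrix with
+# values in [0, 8]; its diagonal is 4, the index of the zero offset, so every
+# token attends to itself with the same learned bias.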
+
+
+class SwinTransformerBlock(nn.Module):
+    """Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        window_size=7,
+        shift_size=0,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim,
+            window_size=to_2tuple(self.window_size),
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
+        )
+
+        self.H = None
+        self.W = None
+
+    def forward(self, x, mask_matrix):
+        """Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            mask_matrix: Attention mask for cyclic shift.
+        """
+        B, L, C = x.shape
+        H, W = self.H, self.W
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(
+            shifted_x, self.window_size
+        )  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(
+            -1, self.window_size * self.window_size, C
+        )  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+
+class PatchMerging(nn.Module):
+    """Patch Merging Layer
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x, H, W):
+        """Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = x.view(B, H, W, C)
+
+        # padding
+        pad_input = (H % 2 == 1) or (W % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
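+
+# Shape sketch (comments only): PatchMerging halves the spatial resolution and
+# doubles the channels, e.g. an input of shape (1, 56*56, 96) with H=W=56 comes
+# out as (1, 28*28, 192).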
+
+
+class BasicLayer(nn.Module):
+    """A basic Swin Transformer layer for one stage.
+    Args:
+        dim (int): Number of feature channels
+        depth (int): Depths of this stage.
+        num_heads (int): Number of attention head.
+        window_size (int): Local window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(
+        self,
+        dim,
+        depth,
+        num_heads,
+        window_size=7,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        norm_layer=nn.LayerNorm,
+        downsample=None,
+        use_checkpoint=False,
+    ):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = window_size // 2
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList(
+            [
+                SwinTransformerBlock(
+                    dim=dim,
+                    num_heads=num_heads,
+                    window_size=window_size,
+                    shift_size=0 if (i % 2 == 0) else window_size // 2,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    qk_scale=qk_scale,
+                    drop=drop,
+                    attn_drop=attn_drop,
+                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                    norm_layer=norm_layer,
+                )
+                for i in range(depth)
+            ]
+        )
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, H, W):
+        """Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+
+        # calculate attention mask for SW-MSA
+        Hp = int(np.ceil(H / self.window_size)) * self.window_size
+        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
+        h_slices = (
+            slice(0, -self.window_size),
+            slice(-self.window_size, -self.shift_size),
+            slice(-self.shift_size, None),
+        )
+        w_slices = (
+            slice(0, -self.window_size),
+            slice(-self.window_size, -self.shift_size),
+            slice(-self.shift_size, None),
+        )
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(
+            img_mask, self.window_size
+        )  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
+            attn_mask == 0, float(0.0)
+        )
+
+        for blk in self.blocks:
+            blk.H, blk.W = H, W
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, attn_mask)
+            else:
+                x = blk(x, attn_mask)
+        if self.downsample is not None:
+            x_down = self.downsample(x, H, W)
+            Wh, Ww = (H + 1) // 2, (W + 1) // 2
+            return x, H, W, x_down, Wh, Ww
+        else:
+            return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+    """Image to Patch Embedding
+    Args:
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        self.patch_size = patch_size
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        """Forward function."""
+        # padding
+        _, _, H, W = x.size()
+        if W % self.patch_size[1] != 0:
+            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+        if H % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+        x = self.proj(x)  # B C Wh Ww
+        if self.norm is not None:
+            Wh, Ww = x.size(2), x.size(3)
+            x = x.flatten(2).transpose(1, 2)
+            x = self.norm(x)
+            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+        return x
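+
+# Shape sketch (comments only): with patch_size=4 and embed_dim=96, a
+# (1, 3, 224, 224) image becomes a (1, 96, 56, 56) token map; inputs whose
+# H or W is not a multiple of 4 are right/bottom padded first.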
+
+
+class SwinTransformer(nn.Module):
+    """Swin Transformer backbone.
+        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
+          https://arxiv.org/pdf/2103.14030
+    Args:
+        pretrain_img_size (int): Input image size for training the pretrained model,
+            used in absolute position embedding. Default 224.
+        patch_size (int | tuple(int)): Patch size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        depths (tuple[int]): Depths of each Swin Transformer stage.
+        num_heads (tuple[int]): Number of attention head of each stage.
+        window_size (int): Window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop_rate (float): Dropout rate.
+        attn_drop_rate (float): Attention dropout rate. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(
+        self,
+        pretrain_img_size=224,
+        patch_size=4,
+        in_chans=3,
+        embed_dim=96,
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 24],
+        window_size=7,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.2,
+        norm_layer=nn.LayerNorm,
+        ape=False,
+        patch_norm=True,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=-1,
+        use_checkpoint=False,
+    ):
+        super().__init__()
+
+        self.pretrain_img_size = pretrain_img_size
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None,
+        )
+
+        # absolute position embedding
+        if self.ape:
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+            patch_size = to_2tuple(patch_size)
+            patches_resolution = [
+                pretrain_img_size[0] // patch_size[0],
+                pretrain_img_size[1] // patch_size[1],
+            ]
+
+            self.absolute_pos_embed = nn.Parameter(
+                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
+            )
+            trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+        ]  # stochastic depth decay rule
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(
+                dim=int(embed_dim * 2 ** i_layer),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint,
+            )
+            self.layers.append(layer)
+
+        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+        self.num_features = num_features
+
+        # add a norm layer for each output
+        for i_layer in out_indices:
+            layer = norm_layer(num_features[i_layer])
+            layer_name = f"norm{i_layer}"
+            self.add_module(layer_name, layer)
+
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        if self.frozen_stages >= 1 and self.ape:
+            self.absolute_pos_embed.requires_grad = False
+
+        if self.frozen_stages >= 2:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages - 1):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=0.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+        self.apply(_init_weights)
+
+    def forward(self, x):
+        """Forward function."""
+        x = self.patch_embed(x)
+
+        Wh, Ww = x.size(2), x.size(3)
+        if self.ape:
+            # interpolate the position embedding to the corresponding size
+            absolute_pos_embed = F.interpolate(
+                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
+            )
+            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
+        else:
+            x = x.flatten(2).transpose(1, 2)
+        x = self.pos_drop(x)
+
+        outs = {}
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+            if i in self.out_indices:
+                norm_layer = getattr(self, f"norm{i}")
+                x_out = norm_layer(x_out)
+
+                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+                outs["res{}".format(i + 2)] = out
+
+        return outs
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keep layers freezed."""
+        super(SwinTransformer, self).train(mode)
+        self._freeze_stages()
+
+
+@BACKBONE_REGISTRY.register()
+class D2SwinTransformer(SwinTransformer, Backbone):
+    def __init__(self, cfg, input_shape):
+
+        pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
+        patch_size = cfg.MODEL.SWIN.PATCH_SIZE
+        in_chans = 3
+        embed_dim = cfg.MODEL.SWIN.EMBED_DIM
+        depths = cfg.MODEL.SWIN.DEPTHS
+        num_heads = cfg.MODEL.SWIN.NUM_HEADS
+        window_size = cfg.MODEL.SWIN.WINDOW_SIZE
+        mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
+        qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
+        qk_scale = cfg.MODEL.SWIN.QK_SCALE
+        drop_rate = cfg.MODEL.SWIN.DROP_RATE
+        attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
+        drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
+        norm_layer = nn.LayerNorm
+        ape = cfg.MODEL.SWIN.APE
+        patch_norm = cfg.MODEL.SWIN.PATCH_NORM
+        use_checkpoint = cfg.MODEL.SWIN.USE_CHECKPOINT
+
+        super().__init__(
+            pretrain_img_size,
+            patch_size,
+            in_chans,
+            embed_dim,
+            depths,
+            num_heads,
+            window_size,
+            mlp_ratio,
+            qkv_bias,
+            qk_scale,
+            drop_rate,
+            attn_drop_rate,
+            drop_path_rate,
+            norm_layer,
+            ape,
+            patch_norm,
+            use_checkpoint=use_checkpoint,
+        )
+
+        self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
+
+        self._out_feature_strides = {
+            "res2": 4,
+            "res3": 8,
+            "res4": 16,
+            "res5": 32,
+        }
+        self._out_feature_channels = {
+            "res2": self.num_features[0],
+            "res3": self.num_features[1],
+            "res4": self.num_features[2],
+            "res5": self.num_features[3],
+        }
+
+    def forward(self, x):
+        """
+        Args:
+            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+        Returns:
+            dict[str->Tensor]: names and the corresponding features
+        """
+        assert (
+            x.dim() == 4
+        ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+        outputs = {}
+        y = super().forward(x)
+        for k in y.keys():
+            if k in self._out_features:
+                outputs[k] = y[k]
+        return outputs
+
+    def output_shape(self):
+        return {
+            name: ShapeSpec(
+                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+            )
+            for name in self._out_features
+        }
+
+    @property
+    def size_divisibility(self):
+        return 32
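+
+
+if __name__ == "__main__":
+    # Hedged smoke test (an addition, not part of the upstream file): build the
+    # plain SwinTransformer with its Swin-T defaults and print the stride-4..32
+    # feature shapes for a dummy 224x224 image.
+    model = SwinTransformer()
+    model.eval()
+    with torch.no_grad():
+        feats = model(torch.randn(1, 3, 224, 224))
+    for name, feat in feats.items():
+        print(name, tuple(feat.shape))
+    # Expected: res2 (1, 96, 56, 56), res3 (1, 192, 28, 28),
+    #           res4 (1, 384, 14, 14), res5 (1, 768, 7, 7)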
diff --git a/annotator/entityseg/mask2former/modeling/criterion.py b/annotator/entityseg/mask2former/modeling/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..878ae754d1a108084644bfaebb3409fa6849cf13
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/criterion.py
@@ -0,0 +1,263 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+"""
+MaskFormer criterion.
+"""
+import logging
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from detectron2.utils.comm import get_world_size
+from detectron2.projects.point_rend.point_features import (
+    get_uncertain_point_coords_with_randomness,
+    point_sample,
+)
+
+from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
+
+
+def dice_loss(
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        num_masks: float,
+    ):
+    """
+    Compute the DICE loss, similar to generalized IOU for masks
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+    """
+    inputs = inputs.sigmoid()
+    inputs = inputs.flatten(1)
+    numerator = 2 * (inputs * targets).sum(-1)
+    denominator = inputs.sum(-1) + targets.sum(-1)
+    loss = 1 - (numerator + 1) / (denominator + 1)
+    return loss.sum() / num_masks
+
+
+dice_loss_jit = torch.jit.script(
+    dice_loss
+)  # type: torch.jit.ScriptModule
+
+
+def sigmoid_ce_loss(
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        num_masks: float,
+    ):
+    """
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+    Returns:
+        Loss tensor
+    """
+    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+
+    return loss.mean(1).sum() / num_masks
+
+
+sigmoid_ce_loss_jit = torch.jit.script(
+    sigmoid_ce_loss
+)  # type: torch.jit.ScriptModule
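+
+# Tiny numeric sanity check (comments only): for a single confident mask with
+# logits +10 wherever the target is 1 and -10 elsewhere, the sigmoid is ~1/~0,
+# so the dice numerator and denominator both approach 2 * targets.sum() and
+# dice_loss -> ~0; sigmoid_ce_loss likewise approaches 0 (with num_masks=1).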
+
+
+def calculate_uncertainty(logits):
+    """
+    We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
+        foreground class in `classes`.
+    Args:
+        logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
+            class-agnostic, where R is the total number of predicted masks in all images and C is
+            the number of foreground classes. The values are logits.
+    Returns:
+        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
+            the most uncertain locations having the highest uncertainty score.
+    """
+    assert logits.shape[1] == 1
+    gt_class_logits = logits.clone()
+    return -(torch.abs(gt_class_logits))
+
+
+class SetCriterion(nn.Module):
+    """This class computes the loss for DETR.
+    The process happens in two steps:
+        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+    """
+
+    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
+                 num_points, oversample_ratio, importance_sample_ratio):
+        """Create the criterion.
+        Parameters:
+            num_classes: number of object categories, omitting the special no-object category
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            eos_coef: relative classification weight applied to the no-object category
+            losses: list of all the losses to be applied. See get_loss for list of available losses.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.eos_coef = eos_coef
+        self.losses = losses
+        empty_weight = torch.ones(self.num_classes + 1)
+        empty_weight[-1] = self.eos_coef
+        self.register_buffer("empty_weight", empty_weight)
+
+        # pointwise mask loss parameters
+        self.num_points = num_points
+        self.oversample_ratio = oversample_ratio
+        self.importance_sample_ratio = importance_sample_ratio
+
+    def loss_labels(self, outputs, targets, indices, num_masks):
+        """Classification loss (NLL)
+        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+        """
+        assert "pred_logits" in outputs
+        src_logits = outputs["pred_logits"].float()
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+        target_classes = torch.full(
+            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+        )
+        target_classes[idx] = target_classes_o
+
+        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+        losses = {"loss_ce": loss_ce}
+        return losses
+    
+    def loss_masks(self, outputs, targets, indices, num_masks):
+        """Compute the losses related to the masks: the focal loss and the dice loss.
+        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+        """
+        assert "pred_masks" in outputs
+
+        src_idx = self._get_src_permutation_idx(indices)
+        tgt_idx = self._get_tgt_permutation_idx(indices)
+        src_masks = outputs["pred_masks"]
+        src_masks = src_masks[src_idx]
+        masks = [t["masks"] for t in targets]
+        # TODO use valid to mask invalid areas due to padding in loss
+        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+        target_masks = target_masks.to(src_masks)
+        target_masks = target_masks[tgt_idx]
+
+        # No need to upsample predictions as we are using normalized coordinates :)
+        # N x 1 x H x W
+        src_masks = src_masks[:, None]
+        target_masks = target_masks[:, None]
+
+        with torch.no_grad():
+            # sample point_coords
+            point_coords = get_uncertain_point_coords_with_randomness(
+                src_masks,
+                lambda logits: calculate_uncertainty(logits),
+                self.num_points,
+                self.oversample_ratio,
+                self.importance_sample_ratio,
+            )
+            # get gt labels
+            point_labels = point_sample(
+                target_masks,
+                point_coords,
+                align_corners=False,
+            ).squeeze(1)
+
+        point_logits = point_sample(
+            src_masks,
+            point_coords,
+            align_corners=False,
+        ).squeeze(1)
+
+        losses = {
+            "loss_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
+            "loss_dice": dice_loss_jit(point_logits, point_labels, num_masks),
+        }
+
+        del src_masks
+        del target_masks
+        return losses
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_masks):
+        loss_map = {
+            'labels': self.loss_labels,
+            'masks': self.loss_masks,
+        }
+        assert loss in loss_map, f"do you really want to compute {loss} loss?"
+        return loss_map[loss](outputs, targets, indices, num_masks)
+
+    def forward(self, outputs, targets):
+        """This performs the loss computation.
+        Parameters:
+             outputs: dict of tensors, see the output specification of the model for the format
+             targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depend on the losses applied; see each loss' doc
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target masks across all nodes, for normalization purposes
+        num_masks = sum(len(t["labels"]) for t in targets)
+        num_masks = torch.as_tensor(
+            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
+        )
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_masks)
+        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if "aux_outputs" in outputs:
+            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
+                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        return losses
+
+    def __repr__(self):
+        head = "Criterion " + self.__class__.__name__
+        body = [
+            "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
+            "losses: {}".format(self.losses),
+            "weight_dict: {}".format(self.weight_dict),
+            "num_classes: {}".format(self.num_classes),
+            "eos_coef: {}".format(self.eos_coef),
+            "num_points: {}".format(self.num_points),
+            "oversample_ratio: {}".format(self.oversample_ratio),
+            "importance_sample_ratio: {}".format(self.importance_sample_ratio),
+        ]
+        _repr_indent = 4
+        lines = [head] + [" " * _repr_indent + line for line in body]
+        return "\n".join(lines)
diff --git a/annotator/entityseg/mask2former/modeling/criterion_view.py b/annotator/entityseg/mask2former/modeling/criterion_view.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c64908398e59ee069f802c426bc2d9c236d2c5
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/criterion_view.py
@@ -0,0 +1,288 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+"""
+MaskFormer criterion.
+"""
+import logging
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from detectron2.utils.comm import get_world_size
+from detectron2.projects.point_rend.point_features import (
+    get_uncertain_point_coords_with_randomness,
+    point_sample,
+)
+
+from mask2former.utils.misc import is_dist_avail_and_initialized
+
+
+def dice_loss(
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        num_masks: float,
+    ):
+    """
+    Compute the DICE loss, similar to generalized IOU for masks
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+    """
+    inputs = inputs.sigmoid()
+    inputs = inputs.flatten(1)
+    numerator = 2 * (inputs * targets).sum(-1)
+    denominator = inputs.sum(-1) + targets.sum(-1)
+    loss = 1 - (numerator + 1) / (denominator + 1)
+    return loss.sum() / num_masks
+
+
+dice_loss_jit = torch.jit.script(
+    dice_loss
+)  # type: torch.jit.ScriptModule
+
+
+def sigmoid_ce_loss(
+        inputs: torch.Tensor,
+        targets: torch.Tensor,
+        num_masks: float,
+    ):
+    """
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+    Returns:
+        Loss tensor
+    """
+    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+
+    return loss.mean(1).sum() / num_masks
+
+
+sigmoid_ce_loss_jit = torch.jit.script(
+    sigmoid_ce_loss
+)  # type: torch.jit.ScriptModule
+
+
+def calculate_uncertainty(logits):
+    """
+    We estimate uncertainty as L1 distance between 0.0 and the logit prediction in 'logits' for the
+        foreground class in `classes`.
+    Args:
+        logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
+            class-agnostic, where R is the total number of predicted masks in all images and C is
+            the number of foreground classes. The values are logits.
+    Returns:
+        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
+            the most uncertain locations having the highest uncertainty score.
+    """
+    assert logits.shape[1] == 1
+    gt_class_logits = logits.clone()
+    return -(torch.abs(gt_class_logits))
+
+
+class ViewSetCriterion(nn.Module):
+    """This class computes the loss for DETR.
+    The process happens in two steps:
+        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+    """
+
+    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
+                 num_points, oversample_ratio, importance_sample_ratio):
+        """Create the criterion.
+        Parameters:
+            num_classes: number of object categories, omitting the special no-object category
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            eos_coef: relative classification weight applied to the no-object category
+            losses: list of all the losses to be applied. See get_loss for list of available losses.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.eos_coef = eos_coef
+        self.losses = losses
+        empty_weight = torch.ones(self.num_classes + 1)
+        empty_weight[-1] = self.eos_coef
+        self.register_buffer("empty_weight", empty_weight)
+
+        # pointwise mask loss parameters
+        self.num_points = num_points
+        self.oversample_ratio = oversample_ratio
+        self.importance_sample_ratio = importance_sample_ratio
+
+    def loss_labels(self, outputs, targets, indices, num_masks):
+        """Classification loss (NLL)
+        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+        """
+        assert "pred_logits" in outputs
+        src_logits = outputs["pred_logits"].float()
+        ## src_logits: torch.Size([2, 100, 41])
+
+        idx = self._get_src_permutation_idx(indices)
+        ## idx: (tensor([0, 0, 1, 1]), tensor([17, 84, 17, 76]))
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+        ### target_class_o: tensor([ 0, 26,  0, 11], device='cuda:0')
+        target_classes = torch.full(
+            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
+        )
+        ## target_classes: torch.Size([2, 100]), filled with 40, the background class
+        target_classes[idx] = target_classes_o
+        ## cross_entropy inputs: src_logits.transpose(1, 2) is torch.Size([2, 41, 100]),
+        ## target_classes is torch.Size([2, 100]); self.empty_weight is 1.0 for every
+        ## foreground class and eos_coef (0.1 in this trace) for the background class
+        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+        losses = {"loss_ce": loss_ce}
+        return losses
+    
+    def loss_masks(self, outputs, targets, indices, num_masks):
+        """Compute the losses related to the masks: the focal loss and the dice loss.
+        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+        """
+        assert "pred_masks" in outputs
+
+        src_idx = self._get_src_permutation_idx(indices)
+        ### src_idx: (tensor([0, 0, 1, 1]), tensor([34, 95, 32, 65]))
+        src_masks = outputs["pred_masks"]
+        ## src_masks: torch.Size([2, 100, 2, 120, 216])
+        src_masks = src_masks[src_idx]
+        ## src_masks: torch.Size([4, 2, 120, 216])
+
+        # Modified to handle video
+        target_masks = torch.cat([t['masks'][i] for t, (_, i) in zip(targets, indices)]).to(src_masks)
+        ### target_masks: torch.Size([4, 2, 480, 864])
+
+        # No need to upsample predictions as we are using normalized coordinates :)
+        # NT x 1 x H x W
+        src_masks = src_masks.flatten(0, 1)[:, None]
+        ## src_masks: torch.Size([8, 1, 120, 216])
+        target_masks = target_masks.flatten(0, 1)[:, None]
+        ## target_masks: torch.Size([8, 1, 480, 864])
+
+        with torch.no_grad():
+            # sample point_coords
+            point_coords = get_uncertain_point_coords_with_randomness(
+                src_masks,
+                lambda logits: calculate_uncertainty(logits),
+                self.num_points,
+                self.oversample_ratio,
+                self.importance_sample_ratio,
+            )
+            # get gt labels
+            point_labels = point_sample(
+                target_masks,
+                point_coords,
+                align_corners=False,
+            ).squeeze(1)
+
+        point_logits = point_sample(
+            src_masks,
+            point_coords,
+            align_corners=False,
+        ).squeeze(1)
+
+        losses = {
+            "loss_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks),
+            "loss_dice": dice_loss_jit(point_logits, point_labels, num_masks),
+        }
+
+        del src_masks
+        del target_masks
+        return losses
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_masks):
+        loss_map = {
+            'labels': self.loss_labels,
+            'masks': self.loss_masks,
+        }
+        assert loss in loss_map, f"do you really want to compute {loss} loss?"
+        return loss_map[loss](outputs, targets, indices, num_masks)
+
+    def forward(self, outputs, targets, return_indices=False):
+        """This performs the loss computation.
+        Parameters:
+             outputs: dict of tensors, see the output specification of the model for the format
+             targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depend on the losses applied; see each loss' doc
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+        indices_l = []
+        indices_l.append(indices)
+
+        # Compute the average number of target masks across all nodes, for normalization purposes
+        num_masks = sum(len(t["labels"]) for t in targets)
+        num_masks = torch.as_tensor(
+            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
+        )
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_masks)
+        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if "aux_outputs" in outputs:
+            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
+                indices = self.matcher(aux_outputs, targets)
+                indices_l.append(indices)
+                for loss in self.losses:
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
+                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+        # Rotate so the matchings are ordered by decoder layer: the auxiliary
+        # layers first, then the final layer (which was matched first above).
+        indices_l.append(indices_l[0])
+        indices_l = indices_l[1:]
+
+        if return_indices:
+            return losses, indices_l
+        else:
+            return losses
+
+    def __repr__(self):
+        head = "Criterion " + self.__class__.__name__
+        body = [
+            "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)),
+            "losses: {}".format(self.losses),
+            "weight_dict: {}".format(self.weight_dict),
+            "num_classes: {}".format(self.num_classes),
+            "eos_coef: {}".format(self.eos_coef),
+            "num_points: {}".format(self.num_points),
+            "oversample_ratio: {}".format(self.oversample_ratio),
+            "importance_sample_ratio: {}".format(self.importance_sample_ratio),
+        ]
+        _repr_indent = 4
+        lines = [head] + [" " * _repr_indent + line for line in body]
+        return "\n".join(lines)
diff --git a/annotator/entityseg/mask2former/modeling/matcher.py b/annotator/entityseg/mask2former/modeling/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..dafdbbbaa7eb18dc1dc7c5a2b97e0dcb248b7b97
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/matcher.py
@@ -0,0 +1,189 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
+"""
+Modules to compute the matching cost and solve the corresponding LSAP.
+"""
+import torch
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+from torch.cuda.amp import autocast
+
+from detectron2.projects.point_rend.point_features import point_sample
+
+
+def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
+    """
+    Compute the DICE loss, similar to generalized IOU for masks
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+    """
+    inputs = inputs.sigmoid()
+    inputs = inputs.flatten(1)
+    numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
+    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
+    loss = 1 - (numerator + 1) / (denominator + 1)
+    return loss
+
+
+batch_dice_loss_jit = torch.jit.script(
+    batch_dice_loss
+)  # type: torch.jit.ScriptModule
+
+
+def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
+    """
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+    Returns:
+        Loss tensor
+    """
+    hw = inputs.shape[1]
+
+    pos = F.binary_cross_entropy_with_logits(
+        inputs, torch.ones_like(inputs), reduction="none"
+    )
+    neg = F.binary_cross_entropy_with_logits(
+        inputs, torch.zeros_like(inputs), reduction="none"
+    )
+
+    loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
+        "nc,mc->nm", neg, (1 - targets)
+    )
+
+    return loss / hw
+
+
+batch_sigmoid_ce_loss_jit = torch.jit.script(
+    batch_sigmoid_ce_loss
+)  # type: torch.jit.ScriptModule
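+
+
+def _pairwise_loss_sanity_check(num_preds: int = 3, num_gts: int = 2, num_points: int = 16):
+    """Illustrative sketch only (not used by the matcher): verifies that the
+    batched einsum formulations above agree with an explicit per-pair loop.
+    The shapes and tolerance are arbitrary assumptions for the demo."""
+    inputs = torch.randn(num_preds, num_points)
+    targets = (torch.rand(num_gts, num_points) > 0.5).float()
+
+    ce = batch_sigmoid_ce_loss(inputs, targets)
+    dice = batch_dice_loss(inputs, targets)
+
+    for n in range(num_preds):
+        for m in range(num_gts):
+            # per-pair binary cross-entropy, averaged over the sampled points
+            ref_ce = F.binary_cross_entropy_with_logits(inputs[n], targets[m], reduction="mean")
+            assert torch.allclose(ce[n, m], ref_ce, atol=1e-5)
+            # per-pair dice cost, with the same +1 smoothing as batch_dice_loss
+            p = inputs[n].sigmoid()
+            ref_dice = 1 - (2 * (p * targets[m]).sum() + 1) / (p.sum() + targets[m].sum() + 1)
+            assert torch.allclose(dice[n, m], ref_dice, atol=1e-5)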
+
+
+class HungarianMatcher(nn.Module):
+    """This class computes an assignment between the targets and the predictions of the network
+
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+    while the others are un-matched (and thus treated as non-objects).
+    """
+
+    def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0):
+        """Creates the matcher
+
+        Params:
+            cost_class: This is the relative weight of the classification error in the matching cost
+            cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
+            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
+        """
+        super().__init__()
+        self.cost_class = cost_class
+        self.cost_mask = cost_mask
+        self.cost_dice = cost_dice
+
+        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs can't be 0"
+
+        self.num_points = num_points
+
+    @torch.no_grad()
+    def memory_efficient_forward(self, outputs, targets):
+        """More memory-friendly matching"""
+        bs, num_queries = outputs["pred_logits"].shape[:2]
+
+        indices = []
+
+        # Iterate through batch size
+        for b in range(bs):
+
+            out_prob = outputs["pred_logits"][b].softmax(-1)  # [num_queries, num_classes]
+            tgt_ids = targets[b]["labels"]
+
+            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+            # but approximate it with 1 - proba[target class].
+            # The 1 is a constant that doesn't change the matching, so it can be omitted.
+            cost_class = -out_prob[:, tgt_ids]
+
+            out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
+            # gt masks are already padded when preparing target
+            tgt_mask = targets[b]["masks"].to(out_mask)
+
+            out_mask = out_mask[:, None]
+            tgt_mask = tgt_mask[:, None]
+            # all masks share the same set of points for efficient matching!
+            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device)
+            # get gt labels
+            tgt_mask = point_sample(
+                tgt_mask,
+                point_coords.repeat(tgt_mask.shape[0], 1, 1),
+                align_corners=False,
+            ).squeeze(1)
+
+            out_mask = point_sample(
+                out_mask,
+                point_coords.repeat(out_mask.shape[0], 1, 1),
+                align_corners=False,
+            ).squeeze(1)
+
+            with autocast(enabled=False):
+                out_mask = out_mask.float()
+                tgt_mask = tgt_mask.float()
+                # Compute the focal loss between masks
+                cost_mask = batch_sigmoid_ce_loss(out_mask, tgt_mask)
+
+                # Compute the dice loss between masks
+                cost_dice = batch_dice_loss(out_mask, tgt_mask)
+            
+            # Final cost matrix
+            C = (
+                self.cost_mask * cost_mask
+                + self.cost_class * cost_class
+                + self.cost_dice * cost_dice
+            )
+            C = C.reshape(num_queries, -1).cpu()
+
+            indices.append(linear_sum_assignment(C))
+
+        return [
+            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+            for i, j in indices
+        ]
+
+    @torch.no_grad()
+    def forward(self, outputs, targets):
+        """Performs the matching
+
+        Params:
+            outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
+
+            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
+
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        return self.memory_efficient_forward(outputs, targets)
+
+    def __repr__(self, _repr_indent=4):
+        head = "Matcher " + self.__class__.__name__
+        body = [
+            "cost_class: {}".format(self.cost_class),
+            "cost_mask: {}".format(self.cost_mask),
+            "cost_dice: {}".format(self.cost_dice),
+        ]
+        lines = [head] + [" " * _repr_indent + line for line in body]
+        return "\n".join(lines)
diff --git a/annotator/entityseg/mask2former/modeling/matcher_view.py b/annotator/entityseg/mask2former/modeling/matcher_view.py
new file mode 100644
index 0000000000000000000000000000000000000000..75d6c25c87f53e53a0d21a40fb5a9dad943a3f80
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/matcher_view.py
@@ -0,0 +1,194 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
+"""
+Modules to compute the matching cost and solve the corresponding LSAP.
+"""
+import torch
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+from torch.cuda.amp import autocast
+
+from detectron2.projects.point_rend.point_features import point_sample
+
+def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
+    """
+    Compute the DICE loss, similar to generalized IOU for masks
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+    """
+    inputs = inputs.sigmoid()
+    inputs = inputs.flatten(1)
+    numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
+    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
+    loss = 1 - (numerator + 1) / (denominator + 1)
+    return loss
+
+
+batch_dice_loss_jit = torch.jit.script(
+    batch_dice_loss
+)  # type: torch.jit.ScriptModule
+
+
+def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor):
+    """
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+    Returns:
+        Loss tensor
+    """
+    hw = inputs.shape[1]
+
+    pos = F.binary_cross_entropy_with_logits(
+        inputs, torch.ones_like(inputs), reduction="none"
+    )
+    neg = F.binary_cross_entropy_with_logits(
+        inputs, torch.zeros_like(inputs), reduction="none"
+    )
+
+    loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
+        "nc,mc->nm", neg, (1 - targets)
+    )
+
+    return loss / hw
+
+
+batch_sigmoid_ce_loss_jit = torch.jit.script(
+    batch_sigmoid_ce_loss
+)  # type: torch.jit.ScriptModule
+
+
+class ViewHungarianMatcher(nn.Module):
+    """This class computes an assignment between the targets and the predictions of the network
+
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+    while the others are un-matched (and thus treated as non-objects).
+    """
+
+    def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0):
+        """Creates the matcher
+
+        Params:
+            cost_class: This is the relative weight of the classification error in the matching cost
+            cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
+            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
+        """
+        super().__init__()
+        self.cost_class = cost_class
+        self.cost_mask = cost_mask
+        self.cost_dice = cost_dice
+
+        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs can't be 0"
+
+        self.num_points = num_points
+
+    @torch.no_grad()
+    def memory_efficient_forward(self, outputs, targets):
+        """More memory-friendly matching"""
+        ### outputs["pred_logits"]: torch.Size([2, 100, 41]), query是对两帧负责,所以没有frame的概念
+        ### outputs["pred_masks"]: torch.Size([2, 100, 2, 120, 160]), 第三维的2是两帧frame
+        bs, num_queries = outputs["pred_logits"].shape[:2]
+
+        indices = []
+
+        # Iterate through batch size
+        for b in range(bs):
+            out_prob = outputs["pred_logits"][b].softmax(-1)  # [num_queries, num_classes]
+            ## out_prob: [100, 41]; 100 queries, 40 classes + background class
+            tgt_ids = targets[b]["labels"]
+            ## tgt_ids: e.g. tensor([3, 10]), i.e. there are only two ground-truth instances
+
+            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+            # but approximate it with 1 - proba[target class].
+            # The 1 is a constant that doesn't change the matching, so it can be omitted.
+            cost_class = -out_prob[:, tgt_ids]
+
+            out_mask = outputs["pred_masks"][b]  # [num_queries, T, H_pred, W_pred]
+            ### out_mask: torch.Size([100, 2, 120, 160])
+            # gt masks are already padded when preparing target
+            tgt_mask = targets[b]["masks"].to(out_mask)  # [num_gts, T, H_pred, W_pred]
+            ## tgt_mask: torch.Size([2, 2, 480, 640])
+
+            # out_mask = out_mask[:, None]
+            # tgt_mask = tgt_mask[:, None]
+            # all masks share the same set of points for efficient matching!
+            point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device)
+            # get gt labels
+            tgt_mask = point_sample(
+                tgt_mask,
+                point_coords.repeat(tgt_mask.shape[0], 1, 1), ## repeated per gt: torch.Size([2, 12544, 2]); every frame is sampled at the same locations
+                align_corners=False,
+            ).flatten(1)
+
+            out_mask = point_sample(
+                out_mask,
+                point_coords.repeat(out_mask.shape[0], 1, 1),
+                align_corners=False,
+            ).flatten(1)
+
+            with autocast(enabled=False):
+                out_mask = out_mask.float()  ## out_mask: torch.Size([100, 25088])
+                tgt_mask = tgt_mask.float()  ## tgt_mask: torch.Size([2, 25088])
+                # Compute the focal loss between masks
+                cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) ## cost_mask: torch.Size([100, 2])
+
+                # Compute the dice loss between masks
+                cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) ## cost_dice: torch.Size([100, 2])
+            
+            # Final cost matrix
+            C = (
+                self.cost_mask * cost_mask
+                + self.cost_class * cost_class
+                + self.cost_dice * cost_dice
+            )
+            C = C.reshape(num_queries, -1).cpu()
+
+            indices.append(linear_sum_assignment(C))
+            ## e.g. [(array([17, 33]), array([1, 0])), ...]
+
+        return [
+            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
+            for i, j in indices
+        ]
+
+    @torch.no_grad()
+    def forward(self, outputs, targets):
+        """Performs the matching
+
+        Params:
+            outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
+
+            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
+
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        return self.memory_efficient_forward(outputs, targets)
+
+    def __repr__(self, _repr_indent=4):
+        head = "Matcher " + self.__class__.__name__
+        body = [
+            "cost_class: {}".format(self.cost_class),
+            "cost_mask: {}".format(self.cost_mask),
+            "cost_dice: {}".format(self.cost_dice),
+        ]
+        lines = [head] + [" " * _repr_indent + line for line in body]
+        return "\n".join(lines)
\ No newline at end of file
diff --git a/annotator/entityseg/mask2former/modeling/meta_arch/__init__.py b/annotator/entityseg/mask2former/modeling/meta_arch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/meta_arch/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/annotator/entityseg/mask2former/modeling/meta_arch/mask_former_head.py b/annotator/entityseg/mask2former/modeling/meta_arch/mask_former_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2994f1ecd23445a3d2b913ecbc36bc132ffebb73
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/meta_arch/mask_former_head.py
@@ -0,0 +1,133 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from copy import deepcopy
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer_decoder.maskformer_transformer_decoder import build_transformer_decoder
+from ..pixel_decoder.fpn import build_pixel_decoder
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class MaskFormerHead(nn.Module):
+
+    _version = 2
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+        if version is None or version < 2:
+            # Do not warn if train from scratch
+            scratch = True
+            logger = logging.getLogger(__name__)
+            for k in list(state_dict.keys()):
+                newk = k
+                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
+                    # newk = k.replace(prefix, prefix + "pixel_decoder.")
+                    newk = k.replace(prefix, prefix)
+                    # logger.debug(f"{k} ==> {newk}")
+                if newk != k:
+                    state_dict[newk] = state_dict[k]
+                    del state_dict[k]
+                    scratch = False
+
+            if not scratch:
+                logger.warning(
+                    f"Weight format of {self.__class__.__name__} have changed! "
+                    "Please upgrade your models. Applying automatic conversion now ..."
+                )
+
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        num_classes: int,
+        pixel_decoder: nn.Module,
+        loss_weight: float = 1.0,
+        ignore_value: int = -1,
+        # extra parameters
+        transformer_predictor: nn.Module,
+        transformer_in_feature: str,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            num_classes: number of classes to predict
+            pixel_decoder: the pixel decoder module
+            loss_weight: loss weight
+            ignore_value: category id to be ignored during training.
+            transformer_predictor: the transformer decoder that makes prediction
+            transformer_in_feature: input feature name to the transformer_predictor
+        """
+        super().__init__()
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]
+        feature_strides = [v.stride for k, v in input_shape]
+        feature_channels = [v.channels for k, v in input_shape]
+
+        self.ignore_value = ignore_value
+        self.common_stride = 4
+        self.loss_weight = loss_weight
+
+        self.pixel_decoder = pixel_decoder
+        self.predictor = transformer_predictor
+        self.transformer_in_feature = transformer_in_feature
+
+        self.num_classes = num_classes
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        # figure out in_channels to transformer predictor
+        if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder":
+            transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+        elif cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "pixel_embedding":
+            transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+        elif cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "multi_scale_pixel_decoder":  # for maskformer2
+            transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+        else:
+            transformer_predictor_in_channels = input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels
+
+        return {
+            "input_shape": {
+                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+            },
+            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
+            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
+            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
+            "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
+            "transformer_predictor": build_transformer_decoder(
+                cfg,
+                transformer_predictor_in_channels,
+                mask_classification=True,
+            ),
+        }
+
+    def forward(self, features, mask=None):
+        return self.layers(features, mask)
+
+    def layers(self, features, mask=None):
+        mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(features)
+        if self.transformer_in_feature == "multi_scale_pixel_decoder":
+            predictions = self.predictor(multi_scale_features, mask_features, mask)
+        else:
+            if self.transformer_in_feature == "transformer_encoder":
+                assert (
+                    transformer_encoder_features is not None
+                ), "Please use the TransformerEncoderPixelDecoder."
+                predictions = self.predictor(transformer_encoder_features, mask_features, mask)
+            elif self.transformer_in_feature == "pixel_embedding":
+                predictions = self.predictor(mask_features, mask_features, mask)
+            else:
+                predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask)
+        return predictions
diff --git a/annotator/entityseg/mask2former/modeling/meta_arch/per_pixel_baseline.py b/annotator/entityseg/mask2former/modeling/meta_arch/per_pixel_baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ce7573e0ff97e7fdeef0ea94928def6e263ab1d
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/meta_arch/per_pixel_baseline.py
@@ -0,0 +1,243 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer_decoder.maskformer_transformer_decoder import StandardTransformerDecoder
+from ..pixel_decoder.fpn import build_pixel_decoder
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class PerPixelBaselineHead(nn.Module):
+
+    _version = 2
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+        if version is None or version < 2:
+            # Do not warn if train from scratch
+            scratch = True
+            logger = logging.getLogger(__name__)
+            for k in list(state_dict.keys()):
+                newk = k
+                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
+                    newk = k.replace(prefix, prefix + "pixel_decoder.")
+                    # logger.warning(f"{k} ==> {newk}")
+                if newk != k:
+                    state_dict[newk] = state_dict[k]
+                    del state_dict[k]
+                    scratch = False
+
+            if not scratch:
+                logger.warning(
+                    f"Weight format of {self.__class__.__name__} have changed! "
+                    "Please upgrade your models. Applying automatic conversion now ..."
+                )
+
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        num_classes: int,
+        pixel_decoder: nn.Module,
+        loss_weight: float = 1.0,
+        ignore_value: int = -1,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            num_classes: number of classes to predict
+            pixel_decoder: the pixel decoder module
+            loss_weight: loss weight
+            ignore_value: category id to be ignored during training.
+        """
+        super().__init__()
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]
+        feature_strides = [v.stride for k, v in input_shape]
+        feature_channels = [v.channels for k, v in input_shape]
+
+        self.ignore_value = ignore_value
+        self.common_stride = 4
+        self.loss_weight = loss_weight
+
+        self.pixel_decoder = pixel_decoder
+        self.predictor = Conv2d(
+            self.pixel_decoder.mask_dim, num_classes, kernel_size=1, stride=1, padding=0
+        )
+        weight_init.c2_msra_fill(self.predictor)
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        return {
+            "input_shape": {
+                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+            },
+            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
+            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
+            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
+        }
+
+    def forward(self, features, targets=None):
+        """
+        Returns:
+            In training, returns (None, dict of losses)
+            In inference, returns (CxHxW logits, {})
+        """
+        x = self.layers(features)
+        if self.training:
+            return None, self.losses(x, targets)
+        else:
+            x = F.interpolate(
+                x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
+            )
+            return x, {}
+
+    def layers(self, features):
+        x, _, _ = self.pixel_decoder.forward_features(features)
+        x = self.predictor(x)
+        return x
+
+    def losses(self, predictions, targets):
+        predictions = predictions.float()  # https://github.com/pytorch/pytorch/issues/48163
+        predictions = F.interpolate(
+            predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False
+        )
+        loss = F.cross_entropy(
+            predictions, targets, reduction="mean", ignore_index=self.ignore_value
+        )
+        losses = {"loss_sem_seg": loss * self.loss_weight}
+        return losses
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class PerPixelBaselinePlusHead(PerPixelBaselineHead):
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+        if version is None or version < 2:
+            # Do not warn if train from scratch
+            scratch = True
+            logger = logging.getLogger(__name__)
+            for k in list(state_dict.keys()):
+                newk = k
+                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
+                    newk = k.replace(prefix, prefix + "pixel_decoder.")
+                    logger.debug(f"{k} ==> {newk}")
+                if newk != k:
+                    state_dict[newk] = state_dict[k]
+                    del state_dict[k]
+                    scratch = False
+
+            if not scratch:
+                logger.warning(
+                    f"Weight format of {self.__class__.__name__} have changed! "
+                    "Please upgrade your models. Applying automatic conversion now ..."
+                )
+
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        # extra parameters
+        transformer_predictor: nn.Module,
+        transformer_in_feature: str,
+        deep_supervision: bool,
+        # inherit parameters
+        num_classes: int,
+        pixel_decoder: nn.Module,
+        loss_weight: float = 1.0,
+        ignore_value: int = -1,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            transformer_predictor: the transformer decoder that makes prediction
+            transformer_in_feature: input feature name to the transformer_predictor
+            deep_supervision: whether or not to add supervision to the output of
+                every transformer decoder layer
+            num_classes: number of classes to predict
+            pixel_decoder: the pixel decoder module
+            loss_weight: loss weight
+            ignore_value: category id to be ignored during training.
+        """
+        super().__init__(
+            input_shape,
+            num_classes=num_classes,
+            pixel_decoder=pixel_decoder,
+            loss_weight=loss_weight,
+            ignore_value=ignore_value,
+        )
+
+        del self.predictor
+
+        self.predictor = transformer_predictor
+        self.transformer_in_feature = transformer_in_feature
+        self.deep_supervision = deep_supervision
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        ret = super().from_config(cfg, input_shape)
+        ret["transformer_in_feature"] = cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE
+        if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder":
+            in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+        else:
+            in_channels = input_shape[ret["transformer_in_feature"]].channels
+        ret["transformer_predictor"] = StandardTransformerDecoder(
+            cfg, in_channels, mask_classification=False
+        )
+        ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
+        return ret
+
+    def forward(self, features, targets=None):
+        """
+        Returns:
+            In training, returns (None, dict of losses)
+            In inference, returns (CxHxW logits, {})
+        """
+        x, aux_outputs = self.layers(features)
+        if self.training:
+            if self.deep_supervision:
+                losses = self.losses(x, targets)
+                for i, aux_output in enumerate(aux_outputs):
+                    losses["loss_sem_seg" + f"_{i}"] = self.losses(
+                        aux_output["pred_masks"], targets
+                    )["loss_sem_seg"]
+                return None, losses
+            else:
+                return None, self.losses(x, targets)
+        else:
+            x = F.interpolate(
+                x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
+            )
+            return x, {}
+
+    def layers(self, features):
+        mask_features, transformer_encoder_features, _ = self.pixel_decoder.forward_features(features)
+        if self.transformer_in_feature == "transformer_encoder":
+            assert (
+                transformer_encoder_features is not None
+            ), "Please use the TransformerEncoderPixelDecoder."
+            predictions = self.predictor(transformer_encoder_features, mask_features)
+        else:
+            predictions = self.predictor(features[self.transformer_in_feature], mask_features)
+        if self.deep_supervision:
+            return predictions["pred_masks"], predictions["aux_outputs"]
+        else:
+            return predictions["pred_masks"], None
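+
+
+if __name__ == "__main__":
+    # Hedged sketch (not part of the original module; assumes the package imports
+    # above resolve, e.g. when run via `python -m ...` from the repo root): reproduces
+    # the math of PerPixelBaselineHead.losses on dummy tensors without building the
+    # full head. The class count, strides and ignore value are assumptions for the demo.
+    import torch
+
+    num_classes, common_stride, ignore_value = 19, 4, 255
+    logits = torch.randn(2, num_classes, 32, 32)        # stride-4 per-pixel predictions
+    gt = torch.randint(0, num_classes, (2, 128, 128))   # full-resolution labels
+    gt[:, :8, :8] = ignore_value                        # some pixels marked as ignored
+
+    upsampled = F.interpolate(
+        logits.float(), scale_factor=common_stride, mode="bilinear", align_corners=False
+    )
+    loss = F.cross_entropy(upsampled, gt, reduction="mean", ignore_index=ignore_value)
+    print({"loss_sem_seg": loss * 1.0})                 # loss_weight assumed to be 1.0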
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/__init__.py b/annotator/entityseg/mask2former/modeling/pixel_decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/fpn.py b/annotator/entityseg/mask2former/modeling/pixel_decoder/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df65a178ce4a105d5c803ff5aa18aa56c44d374
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/fpn.py
@@ -0,0 +1,312 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import numpy as np
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+from torch.cuda.amp import autocast
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer_decoder.position_encoding import PositionEmbeddingSine
+from ..transformer_decoder.transformer import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn
+
+
+def build_pixel_decoder(cfg, input_shape):
+    """
+    Build a pixel decoder from `cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME`.
+    """
+    name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
+    model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
+    forward_features = getattr(model, "forward_features", None)
+    if not callable(forward_features):
+        raise ValueError(
+            "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
+            f"Please implement forward_features for {name} to only return mask features."
+        )
+    return model
+
+
+# This is a modified FPN decoder.
+@SEM_SEG_HEADS_REGISTRY.register()
+class BasePixelDecoder(nn.Module):
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        conv_dim: int,
+        mask_dim: int,
+        norm: Optional[Union[str, Callable]] = None,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            conv_dim: number of output channels for the intermediate conv layers.
+            mask_dim: number of output channels for the final conv layer.
+            norm (str or callable): normalization for all conv layers
+        """
+        super().__init__()
+
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
+        feature_channels = [v.channels for k, v in input_shape]
+
+        lateral_convs = []
+        output_convs = []
+
+        use_bias = norm == ""
+        for idx, in_channels in enumerate(feature_channels):
+            if idx == len(self.in_features) - 1:
+                output_norm = get_norm(norm, conv_dim)
+                output_conv = Conv2d(
+                    in_channels,
+                    conv_dim,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=use_bias,
+                    norm=output_norm,
+                    activation=F.relu,
+                )
+                weight_init.c2_xavier_fill(output_conv)
+                self.add_module("layer_{}".format(idx + 1), output_conv)
+
+                lateral_convs.append(None)
+                output_convs.append(output_conv)
+            else:
+                lateral_norm = get_norm(norm, conv_dim)
+                output_norm = get_norm(norm, conv_dim)
+
+                lateral_conv = Conv2d(
+                    in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
+                )
+                output_conv = Conv2d(
+                    conv_dim,
+                    conv_dim,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=use_bias,
+                    norm=output_norm,
+                    activation=F.relu,
+                )
+                weight_init.c2_xavier_fill(lateral_conv)
+                weight_init.c2_xavier_fill(output_conv)
+                self.add_module("adapter_{}".format(idx + 1), lateral_conv)
+                self.add_module("layer_{}".format(idx + 1), output_conv)
+
+                lateral_convs.append(lateral_conv)
+                output_convs.append(output_conv)
+        # Place convs into top-down order (from low to high resolution)
+        # to make the top-down computation in forward clearer.
+        self.lateral_convs = lateral_convs[::-1]
+        self.output_convs = output_convs[::-1]
+
+        self.mask_dim = mask_dim
+        self.mask_features = Conv2d(
+            conv_dim,
+            mask_dim,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+        weight_init.c2_xavier_fill(self.mask_features)
+
+        self.maskformer_num_feature_levels = 3  # always use 3 scales
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        ret = {}
+        ret["input_shape"] = {
+            k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+        }
+        ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+        ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
+        return ret
+
+    def forward_features(self, features):
+        multi_scale_features = []
+        num_cur_levels = 0
+        # Reverse feature maps into top-down order (from low to high resolution)
+        for idx, f in enumerate(self.in_features[::-1]):
+            x = features[f]
+            lateral_conv = self.lateral_convs[idx]
+            output_conv = self.output_convs[idx]
+            if lateral_conv is None:
+                y = output_conv(x)
+            else:
+                cur_fpn = lateral_conv(x)
+                # Following FPN implementation, we use nearest upsampling here
+                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
+                y = output_conv(y)
+            if num_cur_levels < self.maskformer_num_feature_levels:
+                multi_scale_features.append(y)
+                num_cur_levels += 1
+        return self.mask_features(y), None, multi_scale_features
+
+    def forward(self, features, targets=None):
+        logger = logging.getLogger(__name__)
+        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
+        return self.forward_features(features)
+
+
+class TransformerEncoderOnly(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=6,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+        self._reset_parameters()
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, src, mask, pos_embed):
+        # flatten NxCxHxW to HWxNxC
+        bs, c, h, w = src.shape
+        src = src.flatten(2).permute(2, 0, 1)
+        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        if mask is not None:
+            mask = mask.flatten(1)
+
+        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+        return memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+# This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map.
+@SEM_SEG_HEADS_REGISTRY.register()
+class TransformerEncoderPixelDecoder(BasePixelDecoder):
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        transformer_dropout: float,
+        transformer_nheads: int,
+        transformer_dim_feedforward: int,
+        transformer_enc_layers: int,
+        transformer_pre_norm: bool,
+        conv_dim: int,
+        mask_dim: int,
+        norm: Optional[Union[str, Callable]] = None,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            transformer_dropout: dropout probability in transformer
+            transformer_nheads: number of heads in transformer
+            transformer_dim_feedforward: dimension of feedforward network
+            transformer_enc_layers: number of transformer encoder layers
+            transformer_pre_norm: whether to use pre-layernorm or not
+            conv_dim: number of output channels for the intermediate conv layers.
+            mask_dim: number of output channels for the final conv layer.
+            norm (str or callable): normalization for all conv layers
+        """
+        super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)
+
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
+        feature_strides = [v.stride for k, v in input_shape]
+        feature_channels = [v.channels for k, v in input_shape]
+
+        in_channels = feature_channels[len(self.in_features) - 1]
+        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
+        weight_init.c2_xavier_fill(self.input_proj)
+        self.transformer = TransformerEncoderOnly(
+            d_model=conv_dim,
+            dropout=transformer_dropout,
+            nhead=transformer_nheads,
+            dim_feedforward=transformer_dim_feedforward,
+            num_encoder_layers=transformer_enc_layers,
+            normalize_before=transformer_pre_norm,
+        )
+        N_steps = conv_dim // 2
+        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+        # update layer
+        use_bias = norm == ""
+        output_norm = get_norm(norm, conv_dim)
+        output_conv = Conv2d(
+            conv_dim,
+            conv_dim,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=use_bias,
+            norm=output_norm,
+            activation=F.relu,
+        )
+        weight_init.c2_xavier_fill(output_conv)
+        delattr(self, "layer_{}".format(len(self.in_features)))
+        self.add_module("layer_{}".format(len(self.in_features)), output_conv)
+        self.output_convs[0] = output_conv
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        ret = super().from_config(cfg, input_shape)
+        ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
+        ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
+        ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
+        ret[
+            "transformer_enc_layers"
+        ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS  # a separate config
+        ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
+        return ret
+
+    def forward_features(self, features):
+        multi_scale_features = []
+        num_cur_levels = 0
+        # Reverse feature maps into top-down order (from low to high resolution)
+        for idx, f in enumerate(self.in_features[::-1]):
+            x = features[f]
+            lateral_conv = self.lateral_convs[idx]
+            output_conv = self.output_convs[idx]
+            if lateral_conv is None:
+                transformer = self.input_proj(x)
+                pos = self.pe_layer(x)
+                transformer = self.transformer(transformer, None, pos)
+                y = output_conv(transformer)
+                # save intermediate feature as input to Transformer decoder
+                transformer_encoder_features = transformer
+            else:
+                cur_fpn = lateral_conv(x)
+                # Following FPN implementation, we use nearest upsampling here
+                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
+                y = output_conv(y)
+            if num_cur_levels < self.maskformer_num_feature_levels:
+                multi_scale_features.append(y)
+                num_cur_levels += 1
+        return self.mask_features(y), transformer_encoder_features, multi_scale_features
+
+    def forward(self, features, targets=None):
+        logger = logging.getLogger(__name__)
+        logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.")
+        return self.forward_features(features)
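+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the original module; assumes the package imports
+    # above resolve, e.g. when run via `python -m ...` from the repo root): builds
+    # BasePixelDecoder directly, bypassing the config system, on ResNet-like dummy
+    # features to show the top-down FPN path in forward_features. The channel/stride
+    # values and conv_dim/mask_dim/norm choices below are assumptions for the demo.
+    shapes = {
+        "res2": ShapeSpec(channels=256, stride=4),
+        "res3": ShapeSpec(channels=512, stride=8),
+        "res4": ShapeSpec(channels=1024, stride=16),
+        "res5": ShapeSpec(channels=2048, stride=32),
+    }
+    decoder = BasePixelDecoder(shapes, conv_dim=256, mask_dim=256, norm="GN")
+    features = {
+        name: torch.randn(1, spec.channels, 256 // spec.stride, 256 // spec.stride)
+        for name, spec in shapes.items()
+    }
+    mask_feats, _, multi_scale_feats = decoder.forward_features(features)
+    print(mask_feats.shape)                        # stride-4 mask features
+    print([f.shape for f in multi_scale_feats])    # the 3 lowest-resolution scales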
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/msdeformattn.py b/annotator/entityseg/mask2former/modeling/pixel_decoder/msdeformattn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ff1a81a3ed0c05464dad2143830bacac5951dfe
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/msdeformattn.py
@@ -0,0 +1,358 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import numpy as np
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+from torch.cuda.amp import autocast
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer_decoder.position_encoding import PositionEmbeddingSine
+from ..transformer_decoder.transformer import _get_clones, _get_activation_fn
+from .ops.modules import MSDeformAttn
+
+
+# MSDeformAttn Transformer encoder in deformable detr
+class MSDeformAttnTransformerEncoderOnly(nn.Module):
+    def __init__(self, d_model=256, nhead=8,
+                 num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
+                 activation="relu",
+                 num_feature_levels=4, enc_n_points=4,
+        ):
+        super().__init__()
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+        encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward,
+                                                            dropout, activation,
+                                                            num_feature_levels, nhead, enc_n_points)
+        self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers)
+
+        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+        for m in self.modules():
+            if isinstance(m, MSDeformAttn):
+                m._reset_parameters()
+        normal_(self.level_embed)
+
+    def get_valid_ratio(self, mask):
+        _, H, W = mask.shape
+        valid_H = torch.sum(~mask[:, :, 0], 1)
+        valid_W = torch.sum(~mask[:, 0, :], 1)
+        valid_ratio_h = valid_H.float() / H
+        valid_ratio_w = valid_W.float() / W
+        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+        return valid_ratio
+
+    def forward(self, srcs, pos_embeds):
+        masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs]
+        # prepare input for encoder
+        src_flatten = []
+        mask_flatten = []
+        lvl_pos_embed_flatten = []
+        spatial_shapes = []
+        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+            bs, c, h, w = src.shape
+            spatial_shape = (h, w)
+            spatial_shapes.append(spatial_shape)
+            src = src.flatten(2).transpose(1, 2)
+            mask = mask.flatten(1)
+            pos_embed = pos_embed.flatten(2).transpose(1, 2)
+            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+            lvl_pos_embed_flatten.append(lvl_pos_embed)
+            src_flatten.append(src)
+            mask_flatten.append(mask)
+        src_flatten = torch.cat(src_flatten, 1)
+        mask_flatten = torch.cat(mask_flatten, 1)
+        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
+        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
+        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+        # encoder
+        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
+
+        return memory, spatial_shapes, level_start_index
+
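+
+def _demo_multi_scale_flattening():
+    """Illustrative sketch only (not used by the model): shows how the encoder
+    flattens multi-scale feature maps and derives `level_start_index`, mirroring
+    the bookkeeping in MSDeformAttnTransformerEncoderOnly.forward. The feature
+    sizes below are arbitrary assumptions."""
+    sizes = [(64, 64), (32, 32), (16, 16), (8, 8)]   # e.g. strides 8/16/32/64 of a 512px image
+    spatial_shapes = torch.as_tensor(sizes, dtype=torch.long)
+    # tokens per level: 4096, 1024, 256, 64 -> 5440 flattened tokens in total
+    level_start_index = torch.cat(
+        (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
+    )
+    return spatial_shapes, level_start_index         # starts at 0, 4096, 5120, 5376
+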
+
+class MSDeformAttnTransformerEncoderLayer(nn.Module):
+    def __init__(self,
+                 d_model=256, d_ffn=1024,
+                 dropout=0.1, activation="relu",
+                 n_levels=4, n_heads=8, n_points=4):
+        super().__init__()
+
+        # self attention
+        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+        self.dropout1 = nn.Dropout(dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+
+        # ffn
+        self.linear1 = nn.Linear(d_model, d_ffn)
+        self.activation = _get_activation_fn(activation)
+        self.dropout2 = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(d_ffn, d_model)
+        self.dropout3 = nn.Dropout(dropout)
+        self.norm2 = nn.LayerNorm(d_model)
+
+    @staticmethod
+    def with_pos_embed(tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward_ffn(self, src):
+        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+        src = src + self.dropout3(src2)
+        src = self.norm2(src)
+        return src
+
+    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
+        # self attention
+        src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+
+        # ffn
+        src = self.forward_ffn(src)
+
+        return src
+
+
+class MSDeformAttnTransformerEncoder(nn.Module):
+    def __init__(self, encoder_layer, num_layers):
+        super().__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+
+    @staticmethod
+    def get_reference_points(spatial_shapes, valid_ratios, device):
+        reference_points_list = []
+        for lvl, (H_, W_) in enumerate(spatial_shapes):
+
+            ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+                                          torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
+            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+            ref = torch.stack((ref_x, ref_y), -1)
+            reference_points_list.append(ref)
+        reference_points = torch.cat(reference_points_list, 1)
+        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+        return reference_points
+
+    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
+        output = src
+        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
+        for _, layer in enumerate(self.layers):
+            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
+
+        return output
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class MSDeformAttnPixelDecoder(nn.Module):
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        transformer_dropout: float,
+        transformer_nheads: int,
+        transformer_dim_feedforward: int,
+        transformer_enc_layers: int,
+        conv_dim: int,
+        mask_dim: int,
+        norm: Optional[Union[str, Callable]] = None,
+        # deformable transformer encoder args
+        transformer_in_features: List[str],
+        common_stride: int,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            transformer_dropout: dropout probability in transformer
+            transformer_nheads: number of heads in transformer
+            transformer_dim_feedforward: dimension of feedforward network
+            transformer_enc_layers: number of transformer encoder layers
+            conv_dim: number of output channels for the intermediate conv layers.
+            mask_dim: number of output channels for the final conv layer.
+            norm (str or callable): normalization for all conv layers
+        """
+        super().__init__()
+        transformer_input_shape = {
+            k: v for k, v in input_shape.items() if k in transformer_in_features
+        }
+
+        # this is the input shape of pixel decoder
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
+        self.feature_strides = [v.stride for k, v in input_shape]
+        self.feature_channels = [v.channels for k, v in input_shape]
+        
+        # this is the input shape of the transformer encoder (it may use fewer features than the pixel decoder)
+        transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride)
+        self.transformer_in_features = [k for k, v in transformer_input_shape]  # starting from "res2" to "res5"
+        transformer_in_channels = [v.channels for k, v in transformer_input_shape]
+        self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape]  # to decide extra FPN layers
+
+        self.transformer_num_feature_levels = len(self.transformer_in_features)
+        if self.transformer_num_feature_levels > 1:
+            input_proj_list = []
+            # from low resolution to high resolution (res5 -> res2)
+            for in_channels in transformer_in_channels[::-1]:
+                input_proj_list.append(nn.Sequential(
+                    nn.Conv2d(in_channels, conv_dim, kernel_size=1),
+                    nn.GroupNorm(32, conv_dim),
+                ))
+            self.input_proj = nn.ModuleList(input_proj_list)
+        else:
+            self.input_proj = nn.ModuleList([
+                nn.Sequential(
+                    nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1),
+                    nn.GroupNorm(32, conv_dim),
+                )])
+
+        for proj in self.input_proj:
+            nn.init.xavier_uniform_(proj[0].weight, gain=1)
+            nn.init.constant_(proj[0].bias, 0)
+
+        self.transformer = MSDeformAttnTransformerEncoderOnly(
+            d_model=conv_dim,
+            dropout=transformer_dropout,
+            nhead=transformer_nheads,
+            dim_feedforward=transformer_dim_feedforward,
+            num_encoder_layers=transformer_enc_layers,
+            num_feature_levels=self.transformer_num_feature_levels,
+        )
+        N_steps = conv_dim // 2
+        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+        self.mask_dim = mask_dim
+        # use a single 1x1 conv to produce the mask features (conv_dim -> mask_dim)
+        self.mask_features = Conv2d(
+            conv_dim,
+            mask_dim,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+        weight_init.c2_xavier_fill(self.mask_features)
+        
+        self.maskformer_num_feature_levels = 3  # always use 3 scales
+        self.common_stride = common_stride
+
+        # extra fpn levels
+        stride = min(self.transformer_feature_strides)
+        self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))
+
+        lateral_convs = []
+        output_convs = []
+
+        use_bias = norm == ""
+        for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]):
+            lateral_norm = get_norm(norm, conv_dim)
+            output_norm = get_norm(norm, conv_dim)
+
+            lateral_conv = Conv2d(
+                in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
+            )
+            output_conv = Conv2d(
+                conv_dim,
+                conv_dim,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=use_bias,
+                norm=output_norm,
+                activation=F.relu,
+            )
+            weight_init.c2_xavier_fill(lateral_conv)
+            weight_init.c2_xavier_fill(output_conv)
+            self.add_module("adapter_{}".format(idx + 1), lateral_conv)
+            self.add_module("layer_{}".format(idx + 1), output_conv)
+
+            lateral_convs.append(lateral_conv)
+            output_convs.append(output_conv)
+        # Place convs into top-down order (from low to high resolution)
+        # to make the top-down computation in forward clearer.
+        self.lateral_convs = lateral_convs[::-1]
+        self.output_convs = output_convs[::-1]
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        ret = {}
+        ret["input_shape"] = {
+            k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+        }
+        ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+        ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
+        ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
+        ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
+        # ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
+        ret["transformer_dim_feedforward"] = 1024  # use 1024 for deformable transformer encoder
+        ret["transformer_enc_layers"] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS  # a separate config
+        ret["transformer_in_features"] = cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES
+        ret["common_stride"] = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE
+        return ret
+
+    @autocast(enabled=False)
+    def forward_features(self, features):
+        srcs = []
+        pos = []
+        # Reverse feature maps into top-down order (from low to high resolution)
+        for idx, f in enumerate(self.transformer_in_features[::-1]):
+            x = features[f].float()  # deformable detr does not support half precision
+            srcs.append(self.input_proj[idx](x))
+            pos.append(self.pe_layer(x))
+
+        y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
+        bs = y.shape[0]
+
+        split_size_or_sections = [None] * self.transformer_num_feature_levels
+        for i in range(self.transformer_num_feature_levels):
+            if i < self.transformer_num_feature_levels - 1:
+                split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
+            else:
+                split_size_or_sections[i] = y.shape[1] - level_start_index[i]
+        y = torch.split(y, split_size_or_sections, dim=1)
+
+        out = []
+        multi_scale_features = []
+        num_cur_levels = 0
+        for i, z in enumerate(y):
+            out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
+
+        # append `out` with extra FPN levels
+        # Reverse feature maps into top-down order (from low to high resolution)
+        for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]):
+            x = features[f].float()
+            lateral_conv = self.lateral_convs[idx]
+            output_conv = self.output_convs[idx]
+            cur_fpn = lateral_conv(x)
+            # Unlike the default FPN (nearest upsampling), bilinear upsampling is used here
+            y = cur_fpn + F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False)
+            y = output_conv(y)
+            out.append(y)
+
+        for o in out:
+            if num_cur_levels < self.maskformer_num_feature_levels:
+                multi_scale_features.append(o)
+                num_cur_levels += 1
+
+        return self.mask_features(out[-1]), out[0], multi_scale_features
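+
+
+# A minimal editorial sketch (not part of the original file) illustrating how the
+# number of extra FPN levels above is derived. The stride values and common_stride
+# below are illustrative assumptions, not values taken from any particular config.
+def _example_num_fpn_levels():
+    strides = [8, 16, 32]  # hypothetical transformer feature strides (e.g. res3-res5)
+    common_stride = 4      # hypothetical cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE
+    # only features finer than min(strides) receive an extra lateral/output conv pair
+    num_fpn_levels = int(np.log2(min(strides)) - np.log2(common_stride))
+    assert num_fpn_levels == 1  # i.e. a single extra FPN level (e.g. "res2")
+    return num_fpn_levels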
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/functions/__init__.py b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b06b5ac538b63bdb9a6c82e4635b95bb5491d5b
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/functions/__init__.py
@@ -0,0 +1,13 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from .ms_deform_attn_func import MSDeformAttnFunction
+
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/functions/ms_deform_attn_func.py b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/functions/ms_deform_attn_func.py
new file mode 100644
index 0000000000000000000000000000000000000000..94a36ab85b7c5f9ecee342db91a5d5731740740f
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/functions/ms_deform_attn_func.py
@@ -0,0 +1,72 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+try:
+    import MultiScaleDeformableAttention as MSDA
+except ModuleNotFoundError as e:
+    info_string = (
+        "\n\nPlease compile the MultiScaleDeformableAttention CUDA op with the following commands:\n"
+        "\t`cd mask2former/modeling/pixel_decoder/ops`\n"
+        "\t`sh make.sh`\n"
+    )
+    raise ModuleNotFoundError(info_string) from e
+
+
+class MSDeformAttnFunction(Function):
+    @staticmethod
+    def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
+        ctx.im2col_step = im2col_step
+        output = MSDA.ms_deform_attn_forward(
+            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
+        ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
+        grad_value, grad_sampling_loc, grad_attn_weight = \
+            MSDA.ms_deform_attn_backward(
+                value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
+
+        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+
+
+def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
+    # Pure-PyTorch reference implementation, intended for debugging and testing only;
+    # use the CUDA op for speed.
+    N_, S_, M_, D_ = value.shape
+    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+    sampling_grids = 2 * sampling_locations - 1
+    sampling_value_list = []
+    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
+        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+        # N_*M_, D_, Lq_, P_
+        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
+                                          mode='bilinear', padding_mode='zeros', align_corners=False)
+        sampling_value_list.append(sampling_value_l_)
+    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
+    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
+    return output.transpose(1, 2).contiguous()
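+
+
+# A minimal shape-check sketch (an editorial addition, not part of the original op).
+# The sizes below are illustrative assumptions, chosen only to verify that the
+# pure-PyTorch fallback returns a tensor of shape (N, Len_q, n_heads * head_dim).
+# Running this file as a script still assumes the CUDA op was built, since the
+# module imports it at load time.
+if __name__ == "__main__":
+    N, M, D, Lq, P = 2, 8, 32, 10, 4
+    shapes = [(8, 8), (4, 4)]                      # (H, W) of each feature level
+    S = sum(h * w for h, w in shapes)              # total number of value tokens
+    value = torch.rand(N, S, M, D)
+    loc = torch.rand(N, Lq, M, len(shapes), P, 2)  # normalized locations in [0, 1]
+    attn = torch.softmax(torch.rand(N, Lq, M, len(shapes) * P), -1)
+    attn = attn.view(N, Lq, M, len(shapes), P)
+    out = ms_deform_attn_core_pytorch(value, shapes, loc, attn)
+    assert out.shape == (N, Lq, M * D)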
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/make.sh b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/make.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7b38cdbf48f3571d986a33e7563b517952b51bb2
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/make.sh
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+python setup.py build install
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/modules/__init__.py b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fdbf03359958f3d67ab00f879bf6b61a6c8f06a
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/modules/__init__.py
@@ -0,0 +1,12 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from .ms_deform_attn import MSDeformAttn
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/modules/ms_deform_attn.py b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/modules/ms_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7b4c42ea504a0859ccadd72646919c941e72f73
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/modules/ms_deform_attn.py
@@ -0,0 +1,125 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import warnings
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_, constant_
+
+from ..functions import MSDeformAttnFunction
+from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch
+
+
+def _is_power_of_2(n):
+    if (not isinstance(n, int)) or (n < 0):
+        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+    return (n & (n-1) == 0) and n != 0
+
+
+class MSDeformAttn(nn.Module):
+    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
+        """
+        Multi-Scale Deformable Attention Module
+        :param d_model      hidden dimension
+        :param n_levels     number of feature levels
+        :param n_heads      number of attention heads
+        :param n_points     number of sampling points per attention head per feature level
+        """
+        super().__init__()
+        if d_model % n_heads != 0:
+            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
+        _d_per_head = d_model // n_heads
+        # the CUDA implementation is more efficient when the per-head dimension is a power of 2
+        if not _is_power_of_2(_d_per_head):
+            warnings.warn("Set d_model in MSDeformAttn so that the dimension of each attention head "
+                          "is a power of 2, which is more efficient in our CUDA implementation.")
+
+        self.im2col_step = 128
+
+        self.d_model = d_model
+        self.n_levels = n_levels
+        self.n_heads = n_heads
+        self.n_points = n_points
+
+        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+        self.value_proj = nn.Linear(d_model, d_model)
+        self.output_proj = nn.Linear(d_model, d_model)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        constant_(self.sampling_offsets.weight.data, 0.)
+        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
+        for i in range(self.n_points):
+            grid_init[:, :, i, :] *= i + 1
+        with torch.no_grad():
+            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+        constant_(self.attention_weights.weight.data, 0.)
+        constant_(self.attention_weights.bias.data, 0.)
+        xavier_uniform_(self.value_proj.weight.data)
+        constant_(self.value_proj.bias.data, 0.)
+        xavier_uniform_(self.output_proj.weight.data)
+        constant_(self.output_proj.bias.data, 0.)
+
+    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
+        """
+        :param query                       (N, Length_{query}, C)
+        :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
+                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+        :param input_flatten               (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+        :param input_padding_mask          (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
+
+        :return output                     (N, Length_{query}, C)
+        """
+        N, Len_q, _ = query.shape
+        N, Len_in, _ = input_flatten.shape
+        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+        value = self.value_proj(input_flatten)
+        if input_padding_mask is not None:
+            value = value.masked_fill(input_padding_mask[..., None], float(0))
+        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
+        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
+        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
+        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
+        # N, Len_q, n_heads, n_levels, n_points, 2
+        if reference_points.shape[-1] == 2:
+            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
+            sampling_locations = reference_points[:, :, None, :, None, :] \
+                                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+        elif reference_points.shape[-1] == 4:
+            sampling_locations = reference_points[:, :, None, :, None, :2] \
+                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+        else:
+            raise ValueError(
+                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
+        try:
+            output = MSDeformAttnFunction.apply(
+                value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
+        except Exception:
+            # fall back to the pure-PyTorch implementation (e.g. when running on CPU)
+            output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
+        # # For FLOPs calculation only
+        # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
+        output = self.output_proj(output)
+        return output
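+
+
+# A hedged usage sketch (an editorial addition, not part of the original module).
+# It assumes the MultiScaleDeformableAttention CUDA op has been built via make.sh
+# and that a CUDA device is available; all tensor sizes are illustrative only.
+if __name__ == "__main__" and torch.cuda.is_available():
+    device = torch.device("cuda")
+    N, C, heads, levels, points, Len_q = 2, 256, 8, 2, 4, 10
+    shapes = torch.as_tensor([[8, 8], [4, 4]], dtype=torch.long, device=device)
+    areas = shapes[:, 0] * shapes[:, 1]
+    level_start_index = torch.cat((areas.new_zeros(1), areas.cumsum(0)[:-1]))
+    attn = MSDeformAttn(d_model=C, n_levels=levels, n_heads=heads, n_points=points).to(device)
+    query = torch.rand(N, Len_q, C, device=device)
+    input_flatten = torch.rand(N, int(areas.sum()), C, device=device)
+    reference_points = torch.rand(N, Len_q, levels, 2, device=device)  # normalized (x, y)
+    out = attn(query, reference_points, input_flatten, shapes, level_start_index)
+    assert out.shape == (N, Len_q, C)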
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/setup.py b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b57ad313ac8f9b6586892142da8ba943e516cec
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/setup.py
@@ -0,0 +1,78 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+def get_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, "src")
+
+    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+    sources = main_file + source_cpu
+    extension = CppExtension
+    extra_compile_args = {"cxx": []}
+    define_macros = []
+
+    # Build the CUDA extension when FORCE_CUDA is set or torch reports an available CUDA device, provided CUDA_HOME is set.
+    if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
+        extension = CUDAExtension
+        sources += source_cuda
+        define_macros += [("WITH_CUDA", None)]
+        extra_compile_args["nvcc"] = [
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+        ]
+    else:
+        if CUDA_HOME is None:
+            raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
+        else:
+            raise NotImplementedError('No CUDA runtime was found. Set FORCE_CUDA=1 to force a CUDA build, or check torch.cuda.is_available().')
+
+    sources = [os.path.join(extensions_dir, s) for s in sources]
+    include_dirs = [extensions_dir]
+    ext_modules = [
+        extension(
+            "MultiScaleDeformableAttention",
+            sources,
+            include_dirs=include_dirs,
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+        )
+    ]
+    return ext_modules
+
+setup(
+    name="MultiScaleDeformableAttention",
+    version="1.0",
+    author="Weijie Su",
+    url="https://github.com/fundamentalvision/Deformable-DETR",
+    description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
+    packages=find_packages(exclude=("configs", "tests",)),
+    ext_modules=get_extensions(),
+    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..48757e2b0156b2c1513b615d2a17e5aee5172ae7
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,46 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+    const at::Tensor &value, 
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    AT_ERROR("Not implement on cpu");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+    const at::Tensor &value, 
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+    AT_ERROR("Not implement on cpu");
+}
+
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..51bb27e9ee828f967e8aa854c2d55574040c6d7e
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,38 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+    const at::Tensor &value, 
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step);
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+    const at::Tensor &value, 
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step);
+
+
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0c465dab3d636dfd6a44523c63f148b6e15084d9
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,158 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#include <vector>
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+    const at::Tensor &value, 
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+    const int batch = value.size(0);
+    const int spatial_size = value.size(1);
+    const int num_heads = value.size(2);
+    const int channels = value.size(3);
+
+    const int num_levels = spatial_shapes.size(0);
+
+    const int num_query = sampling_loc.size(1);
+    const int num_point = sampling_loc.size(4);
+
+    const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0, "batch size (%d) must be divisible by im2col_step (%d)", batch, im2col_step_);
+    
+    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+    const int batch_n = im2col_step_;
+    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+    auto per_value_size = spatial_size * num_heads * channels;
+    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+    for (int n = 0; n < batch/im2col_step_; ++n)
+    {
+        auto columns = output_n.select(0, n);
+        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data<int64_t>(),
+                level_start_index.data<int64_t>(),
+                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+                columns.data<scalar_t>());
+
+        }));
+    }
+
+    output = output.view({batch, num_query, num_heads*channels});
+
+    return output;
+}
+
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+    const at::Tensor &value, 
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+
+    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+    const int batch = value.size(0);
+    const int spatial_size = value.size(1);
+    const int num_heads = value.size(2);
+    const int channels = value.size(3);
+
+    const int num_levels = spatial_shapes.size(0);
+
+    const int num_query = sampling_loc.size(1);
+    const int num_point = sampling_loc.size(4);
+
+    const int im2col_step_ = std::min(batch, im2col_step);
+
+    AT_ASSERTM(batch % im2col_step_ == 0, "batch size (%d) must be divisible by im2col_step (%d)", batch, im2col_step_);
+
+    auto grad_value = at::zeros_like(value);
+    auto grad_sampling_loc = at::zeros_like(sampling_loc);
+    auto grad_attn_weight = at::zeros_like(attn_weight);
+
+    const int batch_n = im2col_step_;
+    auto per_value_size = spatial_size * num_heads * channels;
+    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+    
+    for (int n = 0; n < batch/im2col_step_; ++n)
+    {
+        auto grad_output_g = grad_output_n.select(0, n);
+        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+                                    grad_output_g.data<scalar_t>(),
+                                    value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+                                    spatial_shapes.data<int64_t>(),
+                                    level_start_index.data<int64_t>(),
+                                    sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                                    attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                                    batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+                                    grad_value.data<scalar_t>() +  n * im2col_step_ * per_value_size,
+                                    grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                                    grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+        }));
+    }
+
+    return {
+        grad_value, grad_sampling_loc, grad_attn_weight
+    };
+}
\ No newline at end of file
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..4f0658e8668a11f0e7d71deff9adac71884f2e87
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_attn_cuda.h
@@ -0,0 +1,35 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor ms_deform_attn_cuda_forward(
+    const at::Tensor &value, 
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step);
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+    const at::Tensor &value, 
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step);
+
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_im2col_cuda.cuh b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..c04e0d4ab97d25c1756fcd8d08dd1e5a6d280b7c
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1332 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THCAtomics.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n)                          \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \
+      i < (n);                                          \
+      i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+  return (N + num_threads - 1) / num_threads;
+}
+
+
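+// Bilinearly interpolates `bottom_data` (laid out as H x W x nheads x channels)
+// at the fractional location (h, w) for attention head `m` and channel `c`;
+// out-of-bounds corners contribute zero (zero padding).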
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, 
+                                                   const int &height, const int &width, const int &nheads, const int &channels,
+                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const scalar_t lh = h - h_low;
+  const scalar_t lw = w - w_low;
+  const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+  {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+  }
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+  {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+  }
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+  {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+  }
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+  {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+  }
+
+  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, 
+                                                   const int &height, const int &width, const int &nheads, const int &channels,
+                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+                                                   const scalar_t &top_grad,
+                                                   const scalar_t &attn_weight,
+                                                   scalar_t* &grad_value, 
+                                                   scalar_t* grad_sampling_loc,
+                                                   scalar_t* grad_attn_weight)
+{
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const scalar_t lh = h - h_low;
+  const scalar_t lw = w - w_low;
+  const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  const scalar_t top_grad_value = top_grad * attn_weight;
+  scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+  {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+    grad_h_weight -= hw * v1;
+    grad_w_weight -= hh * v1;
+    atomicAdd(grad_value+ptr1, w1*top_grad_value);
+  }
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+  {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+    grad_h_weight -= lw * v2;
+    grad_w_weight += hh * v2;
+    atomicAdd(grad_value+ptr2, w2*top_grad_value);
+  }
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+  {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+    grad_h_weight += hw * v3;
+    grad_w_weight -= lh * v3;
+    atomicAdd(grad_value+ptr3, w3*top_grad_value); 
+  }
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+  {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+    grad_h_weight += lw * v4;
+    grad_w_weight += lh * v4;
+    atomicAdd(grad_value+ptr4, w4*top_grad_value);
+  }
+
+  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  *grad_attn_weight = top_grad * val;
+  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, 
+                                                   const int &height, const int &width, const int &nheads, const int &channels,
+                                                   const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+                                                   const scalar_t &top_grad,
+                                                   const scalar_t &attn_weight,
+                                                   scalar_t* &grad_value, 
+                                                   scalar_t* grad_sampling_loc,
+                                                   scalar_t* grad_attn_weight)
+{
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const scalar_t lh = h - h_low;
+  const scalar_t lw = w - w_low;
+  const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  const scalar_t top_grad_value = top_grad * attn_weight;
+  scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+  scalar_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0)
+  {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+    grad_h_weight -= hw * v1;
+    grad_w_weight -= hh * v1;
+    atomicAdd(grad_value+ptr1, w1*top_grad_value);
+  }
+  scalar_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1)
+  {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+    grad_h_weight -= lw * v2;
+    grad_w_weight += hh * v2;
+    atomicAdd(grad_value+ptr2, w2*top_grad_value);
+  }
+  scalar_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0)
+  {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+    grad_h_weight += hw * v3;
+    grad_w_weight -= lh * v3;
+    atomicAdd(grad_value+ptr3, w3*top_grad_value); 
+  }
+  scalar_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1)
+  {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+    grad_h_weight += lw * v4;
+    grad_w_weight += lh * v4;
+    atomicAdd(grad_value+ptr4, w4*top_grad_value);
+  }
+
+  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  atomicAdd(grad_attn_weight, top_grad * val); 
+  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
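+// Forward im2col kernel: one thread per output element (batch, query, head, channel).
+// Each thread decodes its flat index, then accumulates, over all levels and sampling
+// points, the attention-weighted bilinear samples of the value tensor and writes the
+// result to data_col.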
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+                                                const scalar_t *data_value, 
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index, 
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size, 
+                                                const int spatial_size, 
+                                                const int num_heads,
+                                                const int channels, 
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *data_col)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp; 
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    scalar_t *data_col_ptr = data_col + index;
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+    scalar_t col = 0;
+    
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+        }
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+      }
+    }
+    *data_col_ptr = col;
+  }
+}
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index, 
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size, 
+                                                const int spatial_size, 
+                                                const int num_heads,
+                                                const int channels, 
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+    __shared__ scalar_t cache_grad_attn_weight[blockSize];
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp; 
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr, 
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+        
+        __syncthreads();
+        if (tid == 0)
+        {
+          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+          int sid=2;
+          for (unsigned int t = 1; t < blockSize; ++t)
+          {
+            _grad_w += cache_grad_sampling_loc[sid];
+            _grad_h += cache_grad_sampling_loc[sid + 1];
+            _grad_a += cache_grad_attn_weight[t];
+            sid += 2;
+          }
+
+          *grad_sampling_loc = _grad_w;
+          *(grad_sampling_loc + 1) = _grad_h;
+          *grad_attn_weight = _grad_a;
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index, 
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size, 
+                                                const int spatial_size, 
+                                                const int num_heads,
+                                                const int channels, 
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+    __shared__ scalar_t cache_grad_attn_weight[blockSize];
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp; 
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr, 
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+        
+        __syncthreads();
+
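+        // Pairwise tree reduction in shared memory: halve the number of active
+        // threads each step until the block's sums end up in element 0.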
+        for (unsigned int s=blockSize/2; s>0; s>>=1)
+        {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0)
+        { 
+          *grad_sampling_loc = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight = cache_grad_attn_weight[0];
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
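+// col2im backward kernel with dynamically sized shared memory; thread 0
+// performs a serial reduction over the cached per-thread gradients
+// (used by the launcher for channel counts below 64 without a fixed-size kernel).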
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index, 
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size, 
+                                                const int spatial_size, 
+                                                const int num_heads,
+                                                const int channels, 
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    extern __shared__ int _s[];
+    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp; 
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr, 
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+        
+        __syncthreads();
+        if (tid == 0)
+        {
+          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+          int sid=2;
+          for (unsigned int t = 1; t < blockDim.x; ++t)
+          {
+            _grad_w += cache_grad_sampling_loc[sid];
+            _grad_h += cache_grad_sampling_loc[sid + 1];
+            _grad_a += cache_grad_attn_weight[t];
+            sid += 2;
+          }
+
+          *grad_sampling_loc = _grad_w;
+          *(grad_sampling_loc + 1) = _grad_h;
+          *grad_attn_weight = _grad_a;
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
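+// col2im backward kernel with dynamically sized shared memory and a tree
+// reduction that also handles non-power-of-two block widths (see `spre` below).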
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index, 
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size, 
+                                                const int spatial_size, 
+                                                const int num_heads,
+                                                const int channels, 
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    extern __shared__ int _s[];
+    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp; 
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr, 
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+        
+        __syncthreads();
+
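+        // Tree reduction; `spre` is the width of the previous step, so when it
+        // is odd the leftover element at tid + 2*s is folded in as well.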
+        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+        {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+            if (tid + (s << 1) < spre)
+            {
+              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+            } 
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0)
+        {
+          *grad_sampling_loc = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight = cache_grad_attn_weight[0];
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
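+// Same tree reduction as above, but for the case where one sampling index's
+// channels span several thread blocks (channels > 1024 and divisible by 1024):
+// each block's partial sums are combined in global memory with atomicAdd.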
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index, 
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size, 
+                                                const int spatial_size, 
+                                                const int num_heads,
+                                                const int channels, 
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    extern __shared__ int _s[];
+    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp; 
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight+threadIdx.x)=0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr, 
+            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+        }
+        
+        __syncthreads();
+
+        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+        {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+            if (tid + (s << 1) < spre)
+            {
+              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+            }
+          }
+          __syncthreads();
+        }
+
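+        // Several blocks can contribute to the same sampling location and
+        // attention weight here, so the block-level sums are added atomically.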
+        if (tid == 0)
+        {
+          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
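+// Fallback col2im backward kernel with no shared-memory caching or block-level
+// reduction; all gradients are handled directly in global memory by
+// ms_deform_attn_col2im_bilinear_gm.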
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+                                                const scalar_t *grad_col,
+                                                const scalar_t *data_value,
+                                                const int64_t *data_spatial_shapes,
+                                                const int64_t *data_level_start_index, 
+                                                const scalar_t *data_sampling_loc,
+                                                const scalar_t *data_attn_weight,
+                                                const int batch_size, 
+                                                const int spatial_size, 
+                                                const int num_heads,
+                                                const int channels, 
+                                                const int num_levels,
+                                                const int num_query,
+                                                const int num_point,
+                                                scalar_t *grad_value,
+                                                scalar_t *grad_sampling_loc,
+                                                scalar_t *grad_attn_weight)
+{
+  CUDA_KERNEL_LOOP(index, n)
+  {
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp; 
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % num_query;
+    _temp /= num_query;
+    const int b_col = _temp;
+
+    const scalar_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_point;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+    for (int l_col=0; l_col < num_levels; ++l_col)
+    {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col=0; p_col < num_point; ++p_col)
+      {
+        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+        const scalar_t h_im = loc_h * spatial_h - 0.5;
+        const scalar_t w_im = loc_w * spatial_w - 0.5;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+        {
+          ms_deform_attn_col2im_bilinear_gm(
+            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+            top_grad, weight, grad_value_ptr, 
+            grad_sampling_loc, grad_attn_weight);
+        }
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+
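+// Host-side launcher for the forward pass: one thread per output element
+// (batch_size * num_query * num_heads * channels).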
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+                              const scalar_t* data_value,
+                              const int64_t* data_spatial_shapes, 
+                              const int64_t* data_level_start_index, 
+                              const scalar_t* data_sampling_loc,
+                              const scalar_t* data_attn_weight,
+                              const int batch_size,
+                              const int spatial_size, 
+                              const int num_heads, 
+                              const int channels, 
+                              const int num_levels, 
+                              const int num_query,
+                              const int num_point,
+                              scalar_t* data_col)
+{
+  const int num_kernels = batch_size * num_query * num_heads * channels;
+  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+  const int num_threads = CUDA_NUM_THREADS;
+  ms_deformable_im2col_gpu_kernel<scalar_t>
+      <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+          0, stream>>>(
+      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, 
+      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+  
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+  }
+
+}
+
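+// Host-side launcher for the backward pass; selects a reduction kernel based on
+// `channels`.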
+template <typename scalar_t>
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+                              const scalar_t* grad_col,
+                              const scalar_t* data_value,
+                              const int64_t * data_spatial_shapes,
+                              const int64_t * data_level_start_index,
+                              const scalar_t * data_sampling_loc,
+                              const scalar_t * data_attn_weight,
+                              const int batch_size, 
+                              const int spatial_size, 
+                              const int num_heads,
+                              const int channels, 
+                              const int num_levels,
+                              const int num_query,
+                              const int num_point, 
+                              scalar_t* grad_value,
+                              scalar_t* grad_sampling_loc,
+                              scalar_t* grad_attn_weight)
+{
+  const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+  const int num_kernels = batch_size * num_query * num_heads * channels;
+  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
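+  // Dispatch: power-of-two channel counts up to 1024 use the fixed-blockSize
+  // kernels above, other counts up to 1024 use the dynamic shared-memory
+  // kernels, and counts above 1024 fall back to the multi-block or
+  // global-memory variants.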
+  if (channels > 1024)
+  {
+    if ((channels & 1023) == 0)
+    {
+      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              num_threads*3*sizeof(scalar_t), stream>>>(
+                        num_kernels, 
+                        grad_col,
+                        data_value,
+                        data_spatial_shapes,
+                        data_level_start_index, 
+                        data_sampling_loc,
+                        data_attn_weight,
+                        batch_size, 
+                        spatial_size, 
+                        num_heads,
+                        channels, 
+                        num_levels,
+                        num_query,
+                        num_point,
+                        grad_value,
+                        grad_sampling_loc,
+                        grad_attn_weight);
+    }
+    else
+    {
+      ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+    }
+  }
+  else
+  {
+    switch(channels)
+    {
+      case 1:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      case 2:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      case 4:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      case 8:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      case 16:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      case 32:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      case 64:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      case 128:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      case 256:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      case 512:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      case 1024:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+            0, stream>>>(
+                      num_kernels, 
+                      grad_col,
+                      data_value,
+                      data_spatial_shapes,
+                      data_level_start_index, 
+                      data_sampling_loc,
+                      data_attn_weight,
+                      batch_size, 
+                      spatial_size, 
+                      num_heads,
+                      channels, 
+                      num_levels,
+                      num_query,
+                      num_point,
+                      grad_value,
+                      grad_sampling_loc,
+                      grad_attn_weight);
+        break;
+      default:
+        if (channels < 64)
+        {
+          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              num_threads*3*sizeof(scalar_t), stream>>>(
+                        num_kernels, 
+                        grad_col,
+                        data_value,
+                        data_spatial_shapes,
+                        data_level_start_index, 
+                        data_sampling_loc,
+                        data_attn_weight,
+                        batch_size, 
+                        spatial_size, 
+                        num_heads,
+                        channels, 
+                        num_levels,
+                        num_query,
+                        num_point,
+                        grad_value,
+                        grad_sampling_loc,
+                        grad_attn_weight);
+        }
+        else
+        {
+          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+              num_threads*3*sizeof(scalar_t), stream>>>(
+                        num_kernels, 
+                        grad_col,
+                        data_value,
+                        data_spatial_shapes,
+                        data_level_start_index, 
+                        data_sampling_loc,
+                        data_attn_weight,
+                        batch_size, 
+                        spatial_size, 
+                        num_heads,
+                        channels, 
+                        num_levels,
+                        num_query,
+                        num_point,
+                        grad_value,
+                        grad_sampling_loc,
+                        grad_attn_weight);
+        }
+    }
+  }
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+  }
+
+}
\ No newline at end of file
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/ms_deform_attn.h b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/ms_deform_attn.h
new file mode 100644
index 0000000000000000000000000000000000000000..2f80a1b294c55b37d13bb3558ff7aeadba3b37de
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/ms_deform_attn.h
@@ -0,0 +1,67 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
+at::Tensor
+ms_deform_attn_forward(
+    const at::Tensor &value, 
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const int im2col_step)
+{
+    if (value.type().is_cuda())
+    {
+#ifdef WITH_CUDA
+        return ms_deform_attn_cuda_forward(
+            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+        AT_ERROR("Not compiled with GPU support");
+#endif
+    }
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_backward(
+    const at::Tensor &value, 
+    const at::Tensor &spatial_shapes,
+    const at::Tensor &level_start_index,
+    const at::Tensor &sampling_loc,
+    const at::Tensor &attn_weight,
+    const at::Tensor &grad_output,
+    const int im2col_step)
+{
+    if (value.type().is_cuda())
+    {
+#ifdef WITH_CUDA
+        return ms_deform_attn_cuda_backward(
+            value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+        AT_ERROR("Not compiled with GPU support");
+#endif
+    }
+    AT_ERROR("Not implemented on the CPU");
+}
+
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/vision.cpp b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/vision.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4a08821e0121a77556aa7a263ec8ebfa928b13b6
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/src/vision.cpp
@@ -0,0 +1,21 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+/*!
+* Copyright (c) Facebook, Inc. and its affiliates.
+* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+*/
+
+#include "ms_deform_attn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
diff --git a/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/test.py b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e1b545459f6fd3235767e721eb5a1090ae14bef
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/pixel_decoder/ops/test.py
@@ -0,0 +1,92 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import time
+import torch
+import torch.nn as nn
+from torch.autograd import gradcheck
+
+from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+
+N, M, D = 1, 2, 2
+Lq, L, P = 2, 2, 2
+shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
+S = sum([(H*W).item() for H, W in shapes])
+
+
+torch.manual_seed(3)
+
+
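+# The checks below compare the CUDA extension against the pure-PyTorch reference
+# (ms_deform_attn_core_pytorch) and run torch.autograd.gradcheck over several
+# channel sizes.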
+@torch.no_grad()
+def check_forward_equal_with_pytorch_double():
+    value = torch.rand(N, S, M, D).cuda() * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+    im2col_step = 2
+    output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
+    output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
+    fwdok = torch.allclose(output_cuda, output_pytorch)
+    max_abs_err = (output_cuda - output_pytorch).abs().max()
+    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+    print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_float():
+    value = torch.rand(N, S, M, D).cuda() * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+    im2col_step = 2
+    output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
+    output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
+    fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+    max_abs_err = (output_cuda - output_pytorch).abs().max()
+    max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+    print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
+
+    value = torch.rand(N, S, M, channels).cuda() * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+    attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+    im2col_step = 2
+    func = MSDeformAttnFunction.apply
+
+    value.requires_grad = grad_value
+    sampling_locations.requires_grad = grad_sampling_loc
+    attention_weights.requires_grad = grad_attn_weight
+
+    gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
+
+    print(f'* {gradok} check_gradient_numerical(D={channels})')
+
+
+if __name__ == '__main__':
+    check_forward_equal_with_pytorch_double()
+    check_forward_equal_with_pytorch_float()
+
+    for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
+        check_gradient_numerical(channels, True, True, True)
+
+
+
diff --git a/annotator/entityseg/mask2former/modeling/transformer_decoder/__init__.py b/annotator/entityseg/mask2former/modeling/transformer_decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fc088ec5a2eea03f85b7c95de4c745391dbfaa5
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/transformer_decoder/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .maskformer_transformer_decoder import StandardTransformerDecoder
+from .mask2former_transformer_decoder import MultiScaleMaskedTransformerDecoder
+from .cropformer_transformer_decoder import CropSharedMultiScaleMaskedTransformerDecoder
+
diff --git a/annotator/entityseg/mask2former/modeling/transformer_decoder/cropformer_transformer_decoder.py b/annotator/entityseg/mask2former/modeling/transformer_decoder/cropformer_transformer_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a27fda6bd5309a1d860ec8affa930bb00d85819
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/transformer_decoder/cropformer_transformer_decoder.py
@@ -0,0 +1,595 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
+import logging
+import fvcore.nn.weight_init as weight_init
+from typing import Optional
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d
+
+from .position_encoding import PositionEmbeddingSine3D2D
+from .maskformer_transformer_decoder import TRANSFORMER_DECODER_REGISTRY
+
+import pdb
+
+class SelfAttentionLayer(nn.Module):
+
+    def __init__(self, d_model, nhead, dropout=0.0,
+                 activation="relu", normalize_before=False):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+
+        self.norm = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+        self._reset_parameters()
+    
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(self, tgt,
+                     tgt_mask: Optional[Tensor] = None,
+                     tgt_key_padding_mask: Optional[Tensor] = None,
+                     query_pos: Optional[Tensor] = None):
+        q = k = self.with_pos_embed(tgt, query_pos)
+        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
+                              key_padding_mask=tgt_key_padding_mask)[0]
+        tgt = tgt + self.dropout(tgt2)
+        tgt = self.norm(tgt)
+
+        return tgt
+
+    def forward_pre(self, tgt,
+                    tgt_mask: Optional[Tensor] = None,
+                    tgt_key_padding_mask: Optional[Tensor] = None,
+                    query_pos: Optional[Tensor] = None):
+        tgt2 = self.norm(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+                              key_padding_mask=tgt_key_padding_mask)[0]
+        tgt = tgt + self.dropout(tgt2)
+        
+        return tgt
+
+    def forward(self, tgt,
+                tgt_mask: Optional[Tensor] = None,
+                tgt_key_padding_mask: Optional[Tensor] = None,
+                query_pos: Optional[Tensor] = None):
+        if self.normalize_before:
+            return self.forward_pre(tgt, tgt_mask,
+                                    tgt_key_padding_mask, query_pos)
+        return self.forward_post(tgt, tgt_mask,
+                                 tgt_key_padding_mask, query_pos)
+
+
+class CrossAttentionLayer(nn.Module):
+
+    def __init__(self, d_model, nhead, dropout=0.0,
+                 activation="relu", normalize_before=False):
+        super().__init__()
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+
+        self.norm = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+        self._reset_parameters()
+    
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(self, tgt, memory,
+                     memory_mask: Optional[Tensor] = None,
+                     memory_key_padding_mask: Optional[Tensor] = None,
+                     pos: Optional[Tensor] = None,
+                     query_pos: Optional[Tensor] = None):
+        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
+                                   key=self.with_pos_embed(memory, pos),
+                                   value=memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask)[0]
+        tgt = tgt + self.dropout(tgt2)
+        tgt = self.norm(tgt)
+        
+        return tgt
+
+    def forward_pre(self, tgt, memory,
+                    memory_mask: Optional[Tensor] = None,
+                    memory_key_padding_mask: Optional[Tensor] = None,
+                    pos: Optional[Tensor] = None,
+                    query_pos: Optional[Tensor] = None):
+        tgt2 = self.norm(tgt)
+        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+                                   key=self.with_pos_embed(memory, pos),
+                                   value=memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask)[0]
+        tgt = tgt + self.dropout(tgt2)
+
+        return tgt
+
+    def forward(self, tgt, memory,
+                memory_mask: Optional[Tensor] = None,
+                memory_key_padding_mask: Optional[Tensor] = None,
+                pos: Optional[Tensor] = None,
+                query_pos: Optional[Tensor] = None):
+        if self.normalize_before:
+            return self.forward_pre(tgt, memory, memory_mask,
+                                    memory_key_padding_mask, pos, query_pos)
+        return self.forward_post(tgt, memory, memory_mask,
+                                 memory_key_padding_mask, pos, query_pos)
+
+
+class FFNLayer(nn.Module):
+
+    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
+                 activation="relu", normalize_before=False):
+        super().__init__()
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm = nn.LayerNorm(d_model)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+        self._reset_parameters()
+    
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(self, tgt):
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout(tgt2)
+        tgt = self.norm(tgt)
+        return tgt
+
+    def forward_pre(self, tgt):
+        tgt2 = self.norm(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout(tgt2)
+        return tgt
+
+    def forward(self, tgt):
+        if self.normalize_before:
+            return self.forward_pre(tgt)
+        return self.forward_post(tgt)
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+
+class MLP(nn.Module):
+    """ Very simple multi-layer perceptron (also called FFN)"""
+
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+class Make3dQueries(nn.Module):
+    _version = 2
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+        self.enc_crosattn_3d = nn.ModuleList()
+        self.enc_selfattn_3d = nn.ModuleList()
+        self.enc_ffn_3d      = nn.ModuleList()
+        self.num_layers_3d = cfg.ENTITY.FUSE_NUM_LAYERS
+        for _ in range(self.num_layers_3d):
+            self.enc_crosattn_3d.append(
+                CrossAttentionLayer(
+                    d_model=cfg.ENTITY.FUSE_ENC_HIDDIEN_DIM,
+                    nhead=cfg.ENTITY.FUSE_ENC_NHEADS,
+                    dropout=0.0,
+                    normalize_before=cfg.ENTITY.FUSE_ENC_PRE_NORM)
+            )
+            self.enc_selfattn_3d.append(
+                SelfAttentionLayer(
+                    d_model=cfg.ENTITY.FUSE_ENC_HIDDIEN_DIM,
+                    nhead=cfg.ENTITY.FUSE_ENC_NHEADS,
+                    dropout=0.0,
+                    normalize_before=cfg.ENTITY.FUSE_ENC_PRE_NORM)
+                    )
+            self.enc_ffn_3d.append(
+                FFNLayer(
+                    d_model=cfg.ENTITY.FUSE_ENC_HIDDIEN_DIM,
+                    dim_feedforward=cfg.ENTITY.FUSE_ENC_DIM_FEEDFORWARD,
+                    dropout=0.0,
+                    normalize_before=cfg.ENTITY.FUSE_ENC_PRE_NORM,
+                    )
+            )
+    
+    def forward(self, output_2d, query_embed_2d, query_embed_3d):
+        Q, BT, C = query_embed_2d.shape
+        Q, B, C  = query_embed_3d.shape
+        T = int(BT / B)
+
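+        # output_2d arrives as (Q, B*T, C); take the first of the T per-sample
+        # views as the initial fused (3D) queries, then reshape the per-view
+        # queries to (Q*T, B, C) so they serve as keys/values in the
+        # cross-attention layers below.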
+        output_3d = output_2d[:,0::T,:]
+        ### (Q, B, T, C)
+        output_2d = output_2d.unflatten(1, (B, T)).permute((0,2,1,3)).flatten(0,1)
+        query_embed_2d = query_embed_2d.unflatten(1, (B, T)).permute((0,2,1,3)).flatten(0,1)
+
+        for i in range(self.num_layers_3d):
+            output_3d = self.enc_crosattn_3d[i](output_3d, output_2d, pos=query_embed_2d, query_pos=query_embed_3d)
+            output_3d = self.enc_selfattn_3d[i](output_3d)
+            output_3d = self.enc_ffn_3d[i](output_3d)
+        
+        return output_3d
+
+
+@TRANSFORMER_DECODER_REGISTRY.register()
+class CropSharedMultiScaleMaskedTransformerDecoder(nn.Module):
+    _version = 2
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+        if version is None or version < 2:
+            # Do not warn if training from scratch
+            scratch = True
+            logger = logging.getLogger(__name__)
+            for k in list(state_dict.keys()):
+                newk = k
+                if "static_query" in k:
+                    newk = k.replace("static_query", "query_feat")
+                if newk != k:
+                    state_dict[newk] = state_dict[k]
+                    del state_dict[k]
+                    scratch = False
+
+            if not scratch:
+                logger.warning(
+                    f"Weight format of {self.__class__.__name__} have changed! "
+                    "Please upgrade your models. Applying automatic conversion now ..."
+                )
+
+    @configurable
+    def __init__(
+        self,
+        cfg,
+        in_channels,
+        mask_classification=True,
+        *,
+        num_classes: int,
+        hidden_dim: int,
+        num_queries: int,
+        nheads: int,
+        dim_feedforward: int,
+        dec_layers: int,
+        pre_norm: bool,
+        mask_dim: int,
+        enforce_input_project: bool,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            in_channels: channels of the input features
+            mask_classification: whether to add mask classifier or not
+            num_classes: number of classes
+            hidden_dim: Transformer feature dimension
+            num_queries: number of queries
+            nheads: number of heads
+            dim_feedforward: feature dimension in feedforward network
+            dec_layers: number of Transformer decoder layers
+            pre_norm: whether to use pre-LayerNorm or not
+            mask_dim: mask feature dimension
+            enforce_input_project: add input project 1x1 conv even if input
+                channels and hidden dim are identical
+        """
+        super().__init__()
+
+        assert mask_classification, "Only support mask classification model"
+        self.cfg = cfg
+
+        self.mask_classification = mask_classification
+        # positional encoding
+        N_steps = hidden_dim // 2
+        
+        self.pe_layer = PositionEmbeddingSine3D2D(N_steps, normalize=True)
+        
+        # define Transformer decoder here
+        self.num_heads = nheads
+        self.num_layers = dec_layers
+        self.transformer_self_attention_layers = nn.ModuleList()
+        self.transformer_cross_attention_layers = nn.ModuleList()
+        self.transformer_ffn_layers = nn.ModuleList()
+
+        for _ in range(self.num_layers):
+            self.transformer_self_attention_layers.append(
+                SelfAttentionLayer(
+                    d_model=hidden_dim,
+                    nhead=nheads,
+                    dropout=0.0,
+                    normalize_before=pre_norm,
+                )
+            )
+
+            self.transformer_cross_attention_layers.append(
+                CrossAttentionLayer(
+                    d_model=hidden_dim,
+                    nhead=nheads,
+                    dropout=0.0,
+                    normalize_before=pre_norm,
+                )
+            )
+
+            self.transformer_ffn_layers.append(
+                FFNLayer(
+                    d_model=hidden_dim,
+                    dim_feedforward=dim_feedforward,
+                    dropout=0.0,
+                    normalize_before=pre_norm,
+                )
+            )
+
+        self.make_3d = Make3dQueries(cfg)
+        self.decoder_norm = nn.LayerNorm(hidden_dim)
+
+        self.num_queries = num_queries
+        # learnable query features
+        self.query_feat = nn.Embedding(num_queries, hidden_dim)
+        # learnable query p.e.
+        self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+        # level embedding (we always use 3 scales)
+        self.num_feature_levels = 3
+        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
+        self.input_proj = nn.ModuleList()
+        for _ in range(self.num_feature_levels):
+            if in_channels != hidden_dim or enforce_input_project:
+                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
+                weight_init.c2_xavier_fill(self.input_proj[-1])
+            else:
+                self.input_proj.append(nn.Sequential())
+
+        # output FFNs
+        if self.mask_classification:
+            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
+        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+    @classmethod
+    def from_config(cls, cfg, in_channels, mask_classification):
+        ret = {}
+        ret["cfg"] = cfg
+        ret["in_channels"] = in_channels
+        ret["mask_classification"] = mask_classification
+        
+        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
+        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
+        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
+        # Transformer parameters:
+        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
+        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
+
+        # NOTE: because we add learnable query features which requires supervision,
+        # we add minus 1 to decoder layers to be consistent with our loss
+        # implementation: that is, number of auxiliary losses is always
+        # equal to number of decoder layers. With learnable query features, the number of
+        # auxiliary losses equals number of decoders plus 1.
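+        # e.g. DEC_LAYERS = 10 -> 9 decoder layers here; together with the prediction made on the
+        # learnable queries before the first layer, this yields 10 supervised outputs.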
+        assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
+        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1
+        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
+        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
+
+        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+
+        return ret
+
+    def forward(self, x, mask_features, mask=None):
+        # x is a list of multi-scale features
+        assert len(x) == self.num_feature_levels
+
+        bt, c_m, h_m, w_m = mask_features.shape
+        bs = bt // (self.cfg.ENTITY.CROP_SAMPLE_NUM_TRAIN+1) if self.training else 1
+        # bs = bt // self.num_views if self.training else 1
+        t_m = bt // bs
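+        # The 2D path treats every view as its own batch element; the 3D path groups the T views per sample.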
+        mask_features_2d = mask_features
+        mask_features_3d = mask_features.view(bs, t_m, c_m, h_m, w_m)
+
+        src_2d, src_3d = [], []
+        pos_2d, pos_3d = [], []
+        size_list = []
+
+        # disable mask, it does not affect performance
+        del mask
+
+        for i in range(self.num_feature_levels):
+            size_list.append(x[i].shape[-2:])
+            pos_2d_, pos_3d_ = self.pe_layer(x[i].view(bs, t_m, -1, size_list[-1][0], size_list[-1][1]))
+            
+            pos_3d.append(pos_3d_.flatten(3))
+            src_3d.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
+
+            pos_2d.append(pos_2d_.flatten(2))
+            src_2d.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
+
+            # NTxCxHW => NxTxCxHW => (TxHW)xNxC
+            _, c, hw = src_3d[-1].shape
+            pos_3d[-1] = pos_3d[-1].view(bs, t_m, c, hw).permute(1, 3, 0, 2).flatten(0, 1)
+            src_3d[-1] = src_3d[-1].view(bs, t_m, c, hw).permute(1, 3, 0, 2).flatten(0, 1)
+
+            pos_2d[-1] = pos_2d[-1].permute(2,0,1)
+            src_2d[-1] = src_2d[-1].permute(2,0,1)
+
+        # QxNxC
+        query_embed_2d = self.query_embed.weight.unsqueeze(1).repeat(1, bt, 1)
+        output_2d = self.query_feat.weight.unsqueeze(1).repeat(1, bt, 1)
+
+        predictions_class_2d = []
+        predictions_mask_2d  = []
+
+        # prediction heads on learnable query features
+        outputs_class_2d, outputs_mask_2d, attn_mask_2d, embedding_2d = self.forward_prediction_heads(output_2d, mask_features_2d, output_type="2d", attn_mask_target_size=size_list[0])
+        predictions_class_2d.append(outputs_class_2d)
+        predictions_mask_2d.append(outputs_mask_2d)
+        
+        for i in range(self.num_layers):
+            level_index = i % self.num_feature_levels
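+            # A query whose predicted mask is empty at this scale would block every key; clear its
+            # attention-mask row so cross-attention still has something to attend to (keeps the softmax finite).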
+            attn_mask_2d[torch.where(attn_mask_2d.sum(-1) == attn_mask_2d.shape[-1])] = False
+            # attention: cross-attention first
+            output_2d = self.transformer_cross_attention_layers[i](
+                output_2d, src_2d[level_index],
+                memory_mask=attn_mask_2d,
+                memory_key_padding_mask=None,  # here we do not apply masking on padded region
+                pos=pos_2d[level_index], query_pos=query_embed_2d
+            )
+
+            output_2d = self.transformer_self_attention_layers[i](
+                output_2d, tgt_mask=None,
+                tgt_key_padding_mask=None,
+                query_pos=query_embed_2d
+            )
+            
+            # FFN
+            output_2d = self.transformer_ffn_layers[i](
+                output_2d
+            )
+
+            outputs_class_2d, outputs_mask_2d, attn_mask_2d, embedding_2d = self.forward_prediction_heads(output_2d, mask_features_2d, output_type="2d", attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
+            predictions_class_2d.append(outputs_class_2d)
+            predictions_mask_2d.append(outputs_mask_2d)
+
+        assert len(predictions_class_2d) == self.num_layers + 1
+
+        out_2d = {
+            'pred_logits': predictions_class_2d[-1],
+            'pred_masks': predictions_mask_2d[-1],
+            'aux_outputs': self._set_aux_loss(
+                predictions_class_2d if self.mask_classification else None, predictions_mask_2d
+            )
+        }
+
+        predictions_class_3d = []
+        predictions_mask_3d  = []
+
+        query_embed_3d = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
+
+        output_3d = self.make_3d(output_2d, query_embed_2d, query_embed_3d)
+        
+        # self.fused
+        outputs_class_3d, outputs_mask_3d, attn_mask_3d, embedding_3d = self.forward_prediction_heads(output_3d, mask_features_3d, output_type="3d", attn_mask_target_size=size_list[0])
+        predictions_class_3d.append(outputs_class_3d)
+        predictions_mask_3d.append(outputs_mask_3d)
+
+        for i in range(self.num_layers):
+            level_index = i % self.num_feature_levels
+            attn_mask_3d[torch.where(attn_mask_3d.sum(-1) == attn_mask_3d.shape[-1])] = False
+            ################# 3d (unified) #############
+            # attention: cross-attention first
+            output_3d = self.transformer_cross_attention_layers[i](
+                output_3d, src_3d[level_index],
+                memory_mask=attn_mask_3d,
+                memory_key_padding_mask=None,  # here we do not apply masking on padded region
+                pos=pos_3d[level_index], query_pos=query_embed_3d
+            )
+
+            output_3d = self.transformer_self_attention_layers[i](
+                output_3d, tgt_mask=None,
+                tgt_key_padding_mask=None,
+                query_pos=query_embed_3d
+            )
+            
+            output_3d = self.transformer_ffn_layers[i](
+                output_3d
+            )
+
+            outputs_class_3d, outputs_mask_3d, attn_mask_3d, embedding_3d = self.forward_prediction_heads(output_3d, mask_features_3d, output_type="3d", attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
+            predictions_class_3d.append(outputs_class_3d)
+            predictions_mask_3d.append(outputs_mask_3d)
+
+        # assert len(predictions_class_3d) == self.num_layers + 1
+
+        out_3d = {
+            'pred_logits': predictions_class_3d[-1],
+            'pred_masks': predictions_mask_3d[-1],
+            'aux_outputs': self._set_aux_loss(
+                predictions_class_3d if self.mask_classification else None, predictions_mask_3d
+            ),
+        }
+
+        return out_2d, out_3d
+
+    def forward_prediction_heads(self, output, mask_features, output_type, attn_mask_target_size):
+        decoder_output = self.decoder_norm(output)
+        decoder_output = decoder_output.transpose(0, 1)
+        outputs_class  = self.class_embed(decoder_output)
+        mask_embed     = self.mask_embed(decoder_output)
+        if output_type == "3d":
+            outputs_mask = torch.einsum("bqc,btchw->bqthw", mask_embed, mask_features)
+            b, q, t, _, _ = outputs_mask.shape
+            # NOTE: prediction is of higher-resolution
+            # [B, Q, T, H, W] -> [B, Q, T*H*W] -> [B, h, Q, T*H*W] -> [B*h, Q, T*HW]
+            attn_mask = F.interpolate(outputs_mask.flatten(0, 1), size=attn_mask_target_size, mode="bilinear", align_corners=False).view(
+                b, q, t, attn_mask_target_size[0], attn_mask_target_size[1])
+            # must use bool type
+            # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
+            attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
+            attn_mask = attn_mask.detach()
+        elif output_type == "2d":
+            outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
+            # NOTE: prediction is of higher-resolution
+            # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
+            attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
+            # must use bool type
+            # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
+            attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
+            attn_mask = attn_mask.detach()
+        else:
+            raise "the output_type should be 2d or 3d"
+        
+        return outputs_class, outputs_mask, attn_mask, decoder_output
+
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        if self.mask_classification:
+            return [
+                {"pred_logits": a, "pred_masks": b}
+                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
+            ]
+        else:
+            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
\ No newline at end of file
diff --git a/annotator/entityseg/mask2former/modeling/transformer_decoder/mask2former_transformer_decoder.py b/annotator/entityseg/mask2former/modeling/transformer_decoder/mask2former_transformer_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..52594f62693e6bf48a4c140ba2fe7131a0317774
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/transformer_decoder/mask2former_transformer_decoder.py
@@ -0,0 +1,461 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
+import logging
+import fvcore.nn.weight_init as weight_init
+from typing import Optional
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d
+
+from .position_encoding import PositionEmbeddingSine
+from .maskformer_transformer_decoder import TRANSFORMER_DECODER_REGISTRY
+
+
+class SelfAttentionLayer(nn.Module):
+
+    def __init__(self, d_model, nhead, dropout=0.0,
+                 activation="relu", normalize_before=False):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+
+        self.norm = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+        self._reset_parameters()
+    
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(self, tgt,
+                     tgt_mask: Optional[Tensor] = None,
+                     tgt_key_padding_mask: Optional[Tensor] = None,
+                     query_pos: Optional[Tensor] = None):
+        q = k = self.with_pos_embed(tgt, query_pos)
+        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
+                              key_padding_mask=tgt_key_padding_mask)[0]
+        tgt = tgt + self.dropout(tgt2)
+        tgt = self.norm(tgt)
+
+        return tgt
+
+    def forward_pre(self, tgt,
+                    tgt_mask: Optional[Tensor] = None,
+                    tgt_key_padding_mask: Optional[Tensor] = None,
+                    query_pos: Optional[Tensor] = None):
+        tgt2 = self.norm(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+                              key_padding_mask=tgt_key_padding_mask)[0]
+        tgt = tgt + self.dropout(tgt2)
+        
+        return tgt
+
+    def forward(self, tgt,
+                tgt_mask: Optional[Tensor] = None,
+                tgt_key_padding_mask: Optional[Tensor] = None,
+                query_pos: Optional[Tensor] = None):
+        if self.normalize_before:
+            return self.forward_pre(tgt, tgt_mask,
+                                    tgt_key_padding_mask, query_pos)
+        return self.forward_post(tgt, tgt_mask,
+                                 tgt_key_padding_mask, query_pos)
+
+
+class CrossAttentionLayer(nn.Module):
+
+    def __init__(self, d_model, nhead, dropout=0.0,
+                 activation="relu", normalize_before=False):
+        super().__init__()
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+
+        self.norm = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+        self._reset_parameters()
+    
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(self, tgt, memory,
+                     memory_mask: Optional[Tensor] = None,
+                     memory_key_padding_mask: Optional[Tensor] = None,
+                     pos: Optional[Tensor] = None,
+                     query_pos: Optional[Tensor] = None):
+        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
+                                   key=self.with_pos_embed(memory, pos),
+                                   value=memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask)[0]
+        tgt = tgt + self.dropout(tgt2)
+        tgt = self.norm(tgt)
+        
+        return tgt
+
+    def forward_pre(self, tgt, memory,
+                    memory_mask: Optional[Tensor] = None,
+                    memory_key_padding_mask: Optional[Tensor] = None,
+                    pos: Optional[Tensor] = None,
+                    query_pos: Optional[Tensor] = None):
+        tgt2 = self.norm(tgt)
+        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
+                                   key=self.with_pos_embed(memory, pos),
+                                   value=memory, attn_mask=memory_mask,
+                                   key_padding_mask=memory_key_padding_mask)[0]
+        tgt = tgt + self.dropout(tgt2)
+
+        return tgt
+
+    def forward(self, tgt, memory,
+                memory_mask: Optional[Tensor] = None,
+                memory_key_padding_mask: Optional[Tensor] = None,
+                pos: Optional[Tensor] = None,
+                query_pos: Optional[Tensor] = None):
+        if self.normalize_before:
+            return self.forward_pre(tgt, memory, memory_mask,
+                                    memory_key_padding_mask, pos, query_pos)
+        return self.forward_post(tgt, memory, memory_mask,
+                                 memory_key_padding_mask, pos, query_pos)
+
+
+class FFNLayer(nn.Module):
+
+    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
+                 activation="relu", normalize_before=False):
+        super().__init__()
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm = nn.LayerNorm(d_model)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+        self._reset_parameters()
+    
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(self, tgt):
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout(tgt2)
+        tgt = self.norm(tgt)
+        return tgt
+
+    def forward_pre(self, tgt):
+        tgt2 = self.norm(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout(tgt2)
+        return tgt
+
+    def forward(self, tgt):
+        if self.normalize_before:
+            return self.forward_pre(tgt)
+        return self.forward_post(tgt)
+
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+
+class MLP(nn.Module):
+    """ Very simple multi-layer perceptron (also called FFN)"""
+
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+
+@TRANSFORMER_DECODER_REGISTRY.register()
+class MultiScaleMaskedTransformerDecoder(nn.Module):
+
+    _version = 2
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+        if version is None or version < 2:
+            # Do not warn if training from scratch
+            scratch = True
+            logger = logging.getLogger(__name__)
+            for k in list(state_dict.keys()):
+                newk = k
+                if "static_query" in k:
+                    newk = k.replace("static_query", "query_feat")
+                if newk != k:
+                    state_dict[newk] = state_dict[k]
+                    del state_dict[k]
+                    scratch = False
+
+            if not scratch:
+                logger.warning(
+                    f"Weight format of {self.__class__.__name__} have changed! "
+                    "Please upgrade your models. Applying automatic conversion now ..."
+                )
+
+    @configurable
+    def __init__(
+        self,
+        in_channels,
+        mask_classification=True,
+        *,
+        num_classes: int,
+        hidden_dim: int,
+        num_queries: int,
+        nheads: int,
+        dim_feedforward: int,
+        dec_layers: int,
+        pre_norm: bool,
+        mask_dim: int,
+        enforce_input_project: bool,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            in_channels: channels of the input features
+            mask_classification: whether to add mask classifier or not
+            num_classes: number of classes
+            hidden_dim: Transformer feature dimension
+            num_queries: number of queries
+            nheads: number of heads
+            dim_feedforward: feature dimension in feedforward network
+            dec_layers: number of Transformer decoder layers
+            pre_norm: whether to use pre-LayerNorm or not
+            mask_dim: mask feature dimension
+            enforce_input_project: add input project 1x1 conv even if input
+                channels and hidden dim are identical
+        """
+        super().__init__()
+
+        assert mask_classification, "Only support mask classification model"
+        self.mask_classification = mask_classification
+
+        # positional encoding
+        N_steps = hidden_dim // 2
+        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+        
+        # define Transformer decoder here
+        self.num_heads = nheads
+        self.num_layers = dec_layers
+        self.transformer_self_attention_layers = nn.ModuleList()
+        self.transformer_cross_attention_layers = nn.ModuleList()
+        self.transformer_ffn_layers = nn.ModuleList()
+
+        for _ in range(self.num_layers):
+            self.transformer_self_attention_layers.append(
+                SelfAttentionLayer(
+                    d_model=hidden_dim,
+                    nhead=nheads,
+                    dropout=0.0,
+                    normalize_before=pre_norm,
+                )
+            )
+
+            self.transformer_cross_attention_layers.append(
+                CrossAttentionLayer(
+                    d_model=hidden_dim,
+                    nhead=nheads,
+                    dropout=0.0,
+                    normalize_before=pre_norm,
+                )
+            )
+
+            self.transformer_ffn_layers.append(
+                FFNLayer(
+                    d_model=hidden_dim,
+                    dim_feedforward=dim_feedforward,
+                    dropout=0.0,
+                    normalize_before=pre_norm,
+                )
+            )
+
+        self.decoder_norm = nn.LayerNorm(hidden_dim)
+
+        self.num_queries = num_queries
+        # learnable query features
+        self.query_feat = nn.Embedding(num_queries, hidden_dim)
+        # learnable query p.e.
+        self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+        # level embedding (we always use 3 scales)
+        self.num_feature_levels = 3
+        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
+        self.input_proj = nn.ModuleList()
+        for _ in range(self.num_feature_levels):
+            if in_channels != hidden_dim or enforce_input_project:
+                self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1))
+                weight_init.c2_xavier_fill(self.input_proj[-1])
+            else:
+                self.input_proj.append(nn.Sequential())
+
+        # output FFNs
+        if self.mask_classification:
+            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
+        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+    @classmethod
+    def from_config(cls, cfg, in_channels, mask_classification):
+        ret = {}
+        ret["in_channels"] = in_channels
+        ret["mask_classification"] = mask_classification
+        
+        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
+        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
+        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
+        # Transformer parameters:
+        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
+        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
+
+        # NOTE: because we add learnable query features which requires supervision,
+        # we add minus 1 to decoder layers to be consistent with our loss
+        # implementation: that is, number of auxiliary losses is always
+        # equal to number of decoder layers. With learnable query features, the number of
+        # auxiliary losses equals number of decoders plus 1.
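+        # e.g. DEC_LAYERS = 10 -> 9 layers plus the initial prediction on the learnable queries = 10 outputs.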
+        assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
+        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1
+        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
+        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
+
+        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+
+        return ret
+
+    def forward(self, x, mask_features, mask=None):
+        # x is a list of multi-scale features
+        assert len(x) == self.num_feature_levels
+        src = []
+        pos = []
+        size_list = []
+
+        # disable mask, it does not affect performance
+        del mask
+
+        for i in range(self.num_feature_levels):
+            size_list.append(x[i].shape[-2:])
+            pos.append(self.pe_layer(x[i], None).flatten(2))
+            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])
+
+            # flatten NxCxHxW to HWxNxC
+            pos[-1] = pos[-1].permute(2, 0, 1)
+            src[-1] = src[-1].permute(2, 0, 1)
+
+        _, bs, _ = src[0].shape
+
+        # QxNxC
+        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
+        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)
+
+        predictions_class = []
+        predictions_mask = []
+
+        # prediction heads on learnable query features
+        outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0])
+        predictions_class.append(outputs_class)
+        predictions_mask.append(outputs_mask)
+
+        for i in range(self.num_layers):
+            level_index = i % self.num_feature_levels
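+            # Clear fully-masked rows so a query with an empty predicted mask can still attend at this scale.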
+            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+            # attention: cross-attention first
+            output = self.transformer_cross_attention_layers[i](
+                output, src[level_index],
+                memory_mask=attn_mask,
+                memory_key_padding_mask=None,  # here we do not apply masking on padded region
+                pos=pos[level_index], query_pos=query_embed
+            )
+
+            output = self.transformer_self_attention_layers[i](
+                output, tgt_mask=None,
+                tgt_key_padding_mask=None,
+                query_pos=query_embed
+            )
+            
+            # FFN
+            output = self.transformer_ffn_layers[i](
+                output
+            )
+
+            outputs_class, outputs_mask, attn_mask = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
+            predictions_class.append(outputs_class)
+            predictions_mask.append(outputs_mask)
+
+        assert len(predictions_class) == self.num_layers + 1
+
+        out = {
+            'pred_logits': predictions_class[-1],
+            'pred_masks': predictions_mask[-1],
+            'aux_outputs': self._set_aux_loss(
+                predictions_class if self.mask_classification else None, predictions_mask
+            )
+        }
+        return out
+
+    def forward_prediction_heads(self, output, mask_features, attn_mask_target_size):
+        decoder_output = self.decoder_norm(output)
+        decoder_output = decoder_output.transpose(0, 1)
+        outputs_class = self.class_embed(decoder_output)
+        mask_embed = self.mask_embed(decoder_output)
+        outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
+
+        # NOTE: prediction is of higher-resolution
+        # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
+        attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
+        # must use bool type
+        # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
+        attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool()
+        attn_mask = attn_mask.detach()
+
+        return outputs_class, outputs_mask, attn_mask
+
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        if self.mask_classification:
+            return [
+                {"pred_logits": a, "pred_masks": b}
+                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
+            ]
+        else:
+            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
diff --git a/annotator/entityseg/mask2former/modeling/transformer_decoder/maskformer_transformer_decoder.py b/annotator/entityseg/mask2former/modeling/transformer_decoder/maskformer_transformer_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..79f09fa43f2f5a33c3422a6bb999b20763ab8b5e
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/transformer_decoder/maskformer_transformer_decoder.py
@@ -0,0 +1,188 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
+import fvcore.nn.weight_init as weight_init
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d
+from detectron2.utils.registry import Registry
+
+from .position_encoding import PositionEmbeddingSine
+from .transformer import Transformer
+
+
+TRANSFORMER_DECODER_REGISTRY = Registry("TRANSFORMER_MODULE")
+TRANSFORMER_DECODER_REGISTRY.__doc__ = """
+Registry for transformer module in MaskFormer.
+"""
+
+
+def build_transformer_decoder(cfg, in_channels, mask_classification=True):
+    """
+    Build a instance embedding branch from `cfg.MODEL.INS_EMBED_HEAD.NAME`.
+    """
+    name = cfg.MODEL.MASK_FORMER.TRANSFORMER_DECODER_NAME
+    return TRANSFORMER_DECODER_REGISTRY.get(name)(cfg, in_channels, mask_classification)
+
+
+@TRANSFORMER_DECODER_REGISTRY.register()
+class StandardTransformerDecoder(nn.Module):
+    @configurable
+    def __init__(
+        self,
+        in_channels,
+        mask_classification=True,
+        *,
+        num_classes: int,
+        hidden_dim: int,
+        num_queries: int,
+        nheads: int,
+        dropout: float,
+        dim_feedforward: int,
+        enc_layers: int,
+        dec_layers: int,
+        pre_norm: bool,
+        deep_supervision: bool,
+        mask_dim: int,
+        enforce_input_project: bool,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            in_channels: channels of the input features
+            mask_classification: whether to add mask classifier or not
+            num_classes: number of classes
+            hidden_dim: Transformer feature dimension
+            num_queries: number of queries
+            nheads: number of heads
+            dropout: dropout in Transformer
+            dim_feedforward: feature dimension in feedforward network
+            enc_layers: number of Transformer encoder layers
+            dec_layers: number of Transformer decoder layers
+            pre_norm: whether to use pre-LayerNorm or not
+            deep_supervision: whether to add supervision to every decoder layer
+            mask_dim: mask feature dimension
+            enforce_input_project: add input project 1x1 conv even if input
+                channels and hidden dim is identical
+        """
+        super().__init__()
+
+        self.mask_classification = mask_classification
+
+        # positional encoding
+        N_steps = hidden_dim // 2
+        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+        transformer = Transformer(
+            d_model=hidden_dim,
+            dropout=dropout,
+            nhead=nheads,
+            dim_feedforward=dim_feedforward,
+            num_encoder_layers=enc_layers,
+            num_decoder_layers=dec_layers,
+            normalize_before=pre_norm,
+            return_intermediate_dec=deep_supervision,
+        )
+
+        self.num_queries = num_queries
+        self.transformer = transformer
+        hidden_dim = transformer.d_model
+
+        self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+        if in_channels != hidden_dim or enforce_input_project:
+            self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
+            weight_init.c2_xavier_fill(self.input_proj)
+        else:
+            self.input_proj = nn.Sequential()
+        self.aux_loss = deep_supervision
+
+        # output FFNs
+        if self.mask_classification:
+            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
+        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+    @classmethod
+    def from_config(cls, cfg, in_channels, mask_classification):
+        ret = {}
+        ret["in_channels"] = in_channels
+        ret["mask_classification"] = mask_classification
+
+        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
+        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
+        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
+        # Transformer parameters:
+        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
+        ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
+        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
+        ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
+        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
+        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
+        ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
+        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
+
+        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+
+        return ret
+
+    def forward(self, x, mask_features, mask=None):
+        if mask is not None:
+            mask = F.interpolate(mask[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+        pos = self.pe_layer(x, mask)
+
+        src = x
+        hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
+
+        if self.mask_classification:
+            outputs_class = self.class_embed(hs)
+            out = {"pred_logits": outputs_class[-1]}
+        else:
+            out = {}
+
+        if self.aux_loss:
+            # [l, bs, queries, embed]
+            mask_embed = self.mask_embed(hs)
+            outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features)
+            out["pred_masks"] = outputs_seg_masks[-1]
+            out["aux_outputs"] = self._set_aux_loss(
+                outputs_class if self.mask_classification else None, outputs_seg_masks
+            )
+        else:
+            # FIXME h_boxes takes the last one computed, keep this in mind
+            # [bs, queries, embed]
+            mask_embed = self.mask_embed(hs[-1])
+            outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
+            out["pred_masks"] = outputs_seg_masks
+        return out
+
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        if self.mask_classification:
+            return [
+                {"pred_logits": a, "pred_masks": b}
+                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
+            ]
+        else:
+            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
+
+
+class MLP(nn.Module):
+    """Very simple multi-layer perceptron (also called FFN)"""
+
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
diff --git a/annotator/entityseg/mask2former/modeling/transformer_decoder/position_encoding.py b/annotator/entityseg/mask2former/modeling/transformer_decoder/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bd923728878b3f0099fdc8ec7b14f253086b6e4
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/transformer_decoder/position_encoding.py
@@ -0,0 +1,134 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
+"""
+Various positional encodings for the transformer.
+"""
+import math
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
+    """
+
+    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, x, mask=None):
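+        # x: (N, C, H, W); returns a (N, 2 * num_pos_feats, H, W) sinusoidal embedding.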
+        if mask is None:
+            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+        not_mask = ~mask
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+        dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode="trunc")) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos
+    
+    def __repr__(self, _repr_indent=4):
+        head = "Positional encoding " + self.__class__.__name__
+        body = [
+            "num_pos_feats: {}".format(self.num_pos_feats),
+            "temperature: {}".format(self.temperature),
+            "normalize: {}".format(self.normalize),
+            "scale: {}".format(self.scale),
+        ]
+        # _repr_indent = 4
+        lines = [head] + [" " * _repr_indent + line for line in body]
+        return "\n".join(lines)
+
+class PositionEmbeddingSine3D2D(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
+    """
+
+    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, x, mask=None):
+        ## b, t, c, h, w
+        assert x.dim()==5, f"{x.shape} should be a 5-dimensional Tensor, got {x.dim()}-dimensional Tensor instead"
+        if mask is None:
+            mask = torch.zeros((x.size(0), x.size(1), x.size(3), x.size(4)), device=x.device, dtype=torch.bool)
+        not_mask = ~mask
+        z_embed = not_mask.cumsum(1, dtype=torch.float32)
+        y_embed = not_mask.cumsum(2, dtype=torch.float32)
+        x_embed = not_mask.cumsum(3, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale 
+            y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+        dim_t = self.temperature ** (2 * (torch.div(dim_t, 2, rounding_mode="trunc")) / self.num_pos_feats)
+
+        dim_t_z = torch.arange((self.num_pos_feats * 2), dtype=torch.float32, device=x.device)
+        # dim_t_z = self.temperature ** (2 * (dim_t_z // 2) / (self.num_pos_feats * 2))
+        dim_t_z = self.temperature ** (2 * (torch.div(dim_t_z, 2, rounding_mode="trunc")) / (self.num_pos_feats*2))
+
+        pos_x = x_embed[:, :, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, :, None] / dim_t
+        pos_z = z_embed[:, :, :, :, None] / dim_t_z
+
+        pos_x = torch.stack(
+            (pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5
+        ).flatten(4)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5
+        ).flatten(4)
+        pos_z = torch.stack(
+            (pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5
+        ).flatten(4)
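+        # pos2d: (B*T, 2*num_pos_feats, H, W) per-frame embedding; pos3d: (B, T, 2*num_pos_feats, H, W)
+        # with the frame-index (z) term added on top.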
+        pos2d = torch.cat((pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3).flatten(0,1)
+        pos3d = (torch.cat((pos_y, pos_x), dim=4) + pos_z).permute(0, 1, 4, 2, 3)
+        return pos2d, pos3d
+    
+    def __repr__(self, _repr_indent=4):
+        head = "Positional encoding " + self.__class__.__name__
+        body = [
+            "num_pos_feats: {}".format(self.num_pos_feats),
+            "temperature: {}".format(self.temperature),
+            "normalize: {}".format(self.normalize),
+            "scale: {}".format(self.scale),
+        ]
+        # _repr_indent = 4
+        lines = [head] + [" " * _repr_indent + line for line in body]
+        return "\n".join(lines)
diff --git a/annotator/entityseg/mask2former/modeling/transformer_decoder/transformer.py b/annotator/entityseg/mask2former/modeling/transformer_decoder/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8caa0108f5e136a9739320ab69a3e1b6f40298
--- /dev/null
+++ b/annotator/entityseg/mask2former/modeling/transformer_decoder/transformer.py
@@ -0,0 +1,369 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
+"""
+Transformer class.
+
+Copy-paste from torch.nn.Transformer with modifications:
+    * positional encodings are passed in MHattention
+    * extra LN at the end of encoder is removed
+    * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=6,
+        num_decoder_layers=6,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+        return_intermediate_dec=False,
+    ):
+        super().__init__()
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+        decoder_layer = TransformerDecoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        decoder_norm = nn.LayerNorm(d_model)
+        self.decoder = TransformerDecoder(
+            decoder_layer,
+            num_decoder_layers,
+            decoder_norm,
+            return_intermediate=return_intermediate_dec,
+        )
+
+        self._reset_parameters()
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, src, mask, query_embed, pos_embed):
+        # flatten NxCxHxW to HWxNxC
+        bs, c, h, w = src.shape
+        src = src.flatten(2).permute(2, 0, 1)
+        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+        if mask is not None:
+            mask = mask.flatten(1)
+
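+        # DETR-style decoding: target queries start as zeros; the learnable content enters via query_pos.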
+        tgt = torch.zeros_like(query_embed)
+        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+        hs = self.decoder(
+            tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
+        )
+        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, encoder_layer, num_layers, norm=None):
+        super().__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(
+        self,
+        src,
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        output = src
+
+        for layer in self.layers:
+            output = layer(
+                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
+            )
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+        super().__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+        self.return_intermediate = return_intermediate
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        output = tgt
+
+        intermediate = []
+
+        for layer in self.layers:
+            output = layer(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                pos=pos,
+                query_pos=query_pos,
+            )
+            if self.return_intermediate:
+                intermediate.append(self.norm(output))
+
+        if self.norm is not None:
+            output = self.norm(output)
+            if self.return_intermediate:
+                intermediate.pop()
+                intermediate.append(output)
+
+        if self.return_intermediate:
+            return torch.stack(intermediate)
+
+        return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        q = k = self.with_pos_embed(src, pos)
+        src2 = self.self_attn(
+            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
+
+    def forward_pre(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        src2 = self.norm1(src)
+        q = k = self.with_pos_embed(src2, pos)
+        src2 = self.self_attn(
+            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src2 = self.norm2(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+        src = src + self.dropout2(src2)
+        return src
+
+    def forward(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        if self.normalize_before:
+            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        q = k = self.with_pos_embed(tgt, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        return tgt
+
+    def forward_pre(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        tgt2 = self.norm1(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt2, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        if self.normalize_before:
+            return self.forward_pre(
+                tgt,
+                memory,
+                tgt_mask,
+                memory_mask,
+                tgt_key_padding_mask,
+                memory_key_padding_mask,
+                pos,
+                query_pos,
+            )
+        return self.forward_post(
+            tgt,
+            memory,
+            tgt_mask,
+            memory_mask,
+            tgt_key_padding_mask,
+            memory_key_padding_mask,
+            pos,
+            query_pos,
+        )
+
+
+def _get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
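+
+
+# Minimal usage sketch (illustrative only; the sizes below are arbitrary examples and
+# follow the (sequence, batch, d_model) layout expected by nn.MultiheadAttention):
+#
+#     enc_layer = TransformerEncoderLayer(d_model=256, nhead=8)
+#     src = torch.rand(100, 2, 256)                      # flattened image features
+#     pos = torch.rand(100, 2, 256)                      # positional encodings (added to q/k only)
+#     memory = enc_layer(src, pos=pos)                   # same shape as src
+#
+#     dec_layer = TransformerDecoderLayer(d_model=256, nhead=8)
+#     tgt = torch.zeros(50, 2, 256)                      # object queries
+#     out = dec_layer(tgt, memory, pos=pos, query_pos=torch.rand(50, 2, 256))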
diff --git a/annotator/entityseg/mask2former/test_time_augmentation.py b/annotator/entityseg/mask2former/test_time_augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b02568d1b1ed32efb9316b5c4d53c4d71e5cef78
--- /dev/null
+++ b/annotator/entityseg/mask2former/test_time_augmentation.py
@@ -0,0 +1,103 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import logging
+from itertools import count
+
+import numpy as np
+import torch
+from fvcore.transforms import HFlipTransform
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel
+
+from detectron2.data.detection_utils import read_image
+from detectron2.modeling import DatasetMapperTTA
+
+
+__all__ = [
+    "SemanticSegmentorWithTTA",
+]
+
+
+class SemanticSegmentorWithTTA(nn.Module):
+    """
+    A SemanticSegmentor with test-time augmentation enabled.
+    Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.
+    """
+
+    def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
+        """
+        Args:
+            cfg (CfgNode):
+            model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
+            tta_mapper (callable): takes a dataset dict and returns a list of
+                augmented versions of the dataset dict. Defaults to
+                `DatasetMapperTTA(cfg)`.
+            batch_size (int): batch the augmented images into this batch size for inference.
+        """
+        super().__init__()
+        if isinstance(model, DistributedDataParallel):
+            model = model.module
+        self.cfg = cfg.clone()
+
+        self.model = model
+
+        if tta_mapper is None:
+            tta_mapper = DatasetMapperTTA(cfg)
+        self.tta_mapper = tta_mapper
+        self.batch_size = batch_size
+
+    def __call__(self, batched_inputs):
+        """
+        Same input/output format as :meth:`SemanticSegmentor.forward`
+        """
+
+        def _maybe_read_image(dataset_dict):
+            ret = copy.copy(dataset_dict)
+            if "image" not in ret:
+                image = read_image(ret.pop("file_name"), self.model.input_format)
+                image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))  # CHW
+                ret["image"] = image
+            if "height" not in ret and "width" not in ret:
+                ret["height"] = image.shape[1]
+                ret["width"] = image.shape[2]
+            return ret
+
+        processed_results = []
+        for x in batched_inputs:
+            result = self._inference_one_image(_maybe_read_image(x))
+            processed_results.append(result)
+        return processed_results
+
+    def _inference_one_image(self, input):
+        """
+        Args:
+            input (dict): one dataset dict with "image" field being a CHW tensor
+        Returns:
+            dict: one output dict
+        """
+        orig_shape = (input["height"], input["width"])
+        augmented_inputs, tfms = self._get_augmented_inputs(input)
+
+        final_predictions = None
+        count_predictions = 0
+        for input, tfm in zip(augmented_inputs, tfms):
+            count_predictions += 1
+            with torch.no_grad():
+                if final_predictions is None:
+                    if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
+                        final_predictions = self.model([input])[0].pop("sem_seg").flip(dims=[2])
+                    else:
+                        final_predictions = self.model([input])[0].pop("sem_seg")
+                else:
+                    if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
+                        final_predictions += self.model([input])[0].pop("sem_seg").flip(dims=[2])
+                    else:
+                        final_predictions += self.model([input])[0].pop("sem_seg")
+
+        final_predictions = final_predictions / count_predictions
+        return {"sem_seg": final_predictions}
+
+    def _get_augmented_inputs(self, input):
+        augmented_inputs = self.tta_mapper(input)
+        tfms = [x.pop("transforms") for x in augmented_inputs]
+        return augmented_inputs, tfms
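+
+
+# Minimal usage sketch (illustrative only; assumes a detectron2 CfgNode `cfg` and a
+# trained SemanticSegmentor `model`; "input.jpg" is a placeholder path):
+#
+#     tta_model = SemanticSegmentorWithTTA(cfg, model)
+#     outputs = tta_model([{"file_name": "input.jpg"}])
+#     sem_seg = outputs[0]["sem_seg"]                    # averaged over all TTA views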
diff --git a/annotator/entityseg/mask2former/utils/__init__.py b/annotator/entityseg/mask2former/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8
--- /dev/null
+++ b/annotator/entityseg/mask2former/utils/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
diff --git a/annotator/entityseg/mask2former/utils/misc.py b/annotator/entityseg/mask2former/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..874d9805b482f52bbffc1be620e36e0cffc07c46
--- /dev/null
+++ b/annotator/entityseg/mask2former/utils/misc.py
@@ -0,0 +1,111 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+import torchvision
+from torch import Tensor
+
+
+def _max_by_axis(the_list):
+    # type: (List[List[int]]) -> List[int]
+    maxes = the_list[0]
+    for sublist in the_list[1:]:
+        for index, item in enumerate(sublist):
+            maxes[index] = max(maxes[index], item)
+    return maxes
+
+
+class NestedTensor(object):
+    def __init__(self, tensors, mask: Optional[Tensor]):
+        self.tensors = tensors
+        self.mask = mask
+
+    def to(self, device):
+        # type: (Device) -> NestedTensor # noqa
+        cast_tensor = self.tensors.to(device)
+        mask = self.mask
+        if mask is not None:
+            assert mask is not None
+            cast_mask = mask.to(device)
+        else:
+            cast_mask = None
+        return NestedTensor(cast_tensor, cast_mask)
+
+    def decompose(self):
+        return self.tensors, self.mask
+
+    def __repr__(self):
+        return str(self.tensors)
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+    # TODO make this more general
+    if tensor_list[0].ndim == 3:
+        if torchvision._is_tracing():
+            # nested_tensor_from_tensor_list() does not export well to ONNX
+            # call _onnx_nested_tensor_from_tensor_list() instead
+            return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+        # TODO make it support different-sized images
+        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+        batch_shape = [len(tensor_list)] + max_size
+        b, c, h, w = batch_shape
+        dtype = tensor_list[0].dtype
+        device = tensor_list[0].device
+        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+        for img, pad_img, m in zip(tensor_list, tensor, mask):
+            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+            m[: img.shape[1], : img.shape[2]] = False
+    else:
+        raise ValueError("not supported")
+    return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+    max_size = []
+    for i in range(tensor_list[0].dim()):
+        max_size_i = torch.max(
+            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
+        ).to(torch.int64)
+        max_size.append(max_size_i)
+    max_size = tuple(max_size)
+
+    # work around for
+    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+    # m[: img.shape[1], :img.shape[2]] = False
+    # which is not yet supported in onnx
+    padded_imgs = []
+    padded_masks = []
+    for img in tensor_list:
+        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+        padded_imgs.append(padded_img)
+
+        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+        padded_masks.append(padded_mask.to(torch.bool))
+
+    tensor = torch.stack(padded_imgs)
+    mask = torch.stack(padded_masks)
+
+    return NestedTensor(tensor, mask=mask)
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
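+
+
+# Minimal usage sketch (illustrative only; shapes are arbitrary examples):
+#
+#     imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 600)]
+#     nested = nested_tensor_from_tensor_list(imgs)
+#     tensors, mask = nested.decompose()                 # (2, 3, 512, 640) and (2, 512, 640)
+#     # mask is True at padded positions and False where real pixels are.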
diff --git a/annotator/entityseg/predictor.py b/annotator/entityseg/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..92ea056d0e68153b24d68adb4de8c35d2ae58a5c
--- /dev/null
+++ b/annotator/entityseg/predictor.py
@@ -0,0 +1,227 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py
+import atexit
+import bisect
+import multiprocessing as mp
+from collections import deque
+
+import pdb
+import cv2
+import copy
+import torch
+import numpy as np
+
+import detectron2.data.transforms as T
+from detectron2.data import MetadataCatalog
+from detectron2.engine.defaults import DefaultPredictor
+from detectron2.utils.video_visualizer import VideoVisualizer
+from detectron2.utils.visualizer import ColorMode, Visualizer
+
+from mask2former.data.dataset_mappers.crop_augmentations import BatchResizeShortestEdge, EntityCrop, EntityCropTransform
+
+
+class VisualizationDemo(object):
+    def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
+        """
+        Args:
+            cfg (CfgNode):
+            instance_mode (ColorMode):
+            parallel (bool): whether to run the model in different processes from visualization.
+                Useful since the visualization logic can be slow.
+        """
+        self.metadata = MetadataCatalog.get(
+            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
+        )
+        self.cpu_device = torch.device("cpu")
+        self.instance_mode = instance_mode
+
+        self.parallel = parallel
+        if parallel:
+            num_gpu = torch.cuda.device_count()
+            self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
+        else:
+            self.predictor = CropFormerPredictor(cfg)
+
+    def run_on_image(self, image):
+        """
+        Args:
+            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+                This is the format used by OpenCV.
+        Returns:
+            predictions (dict): the output of the model.
+            vis_output (VisImage): the visualized image output.
+        """
+        predictions = self.predictor(image)
+        return predictions
+
+class CropFormerPredictor(DefaultPredictor):
+    """
+    """
+
+    def __init__(self, cfg):
+        super().__init__(cfg)
+    
+    def generate_img_augs(self):
+        shortest_side = np.random.choice([self.cfg.INPUT.MIN_SIZE_TEST])
+
+        augs = [
+            T.ResizeShortestEdge(
+                (shortest_side,),
+                self.cfg.INPUT.MAX_SIZE_TEST,
+                self.cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
+            ),
+            
+        ]
+
+        # Build entity-crop augmentations
+        crop_augs = []
+        entity_crops = EntityCrop(self.cfg.ENTITY.CROP_AREA_RATIO, 
+                                    self.cfg.ENTITY.CROP_STRIDE_RATIO,
+                                    self.cfg.ENTITY.CROP_SAMPLE_NUM_TEST, 
+                                    False)
+        crop_augs.append(entity_crops)
+        
+        entity_resize = BatchResizeShortestEdge((shortest_side,), self.cfg.INPUT.MAX_SIZE_TEST, self.cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING)
+        crop_augs.append(entity_resize)
+
+        # augs      = T.AugmentationList(augs)
+        crop_augs = T.AugmentationList(crop_augs)
+        return augs, crop_augs
+
+    def __call__(self, original_image):
+        """
+        Args:
+            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+
+        Returns:
+            predictions (dict):
+                the output of the model for one image only.
+                See :doc:`/tutorials/models` for details about the format.
+        """
+        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
+            # Apply pre-processing to image.
+            if self.input_format == "RGB":
+                # whether the model expects BGR inputs or RGB
+                original_image = original_image[:, :, ::-1]
+            
+            # build cropformer augmentations
+            augs, crop_augs = self.generate_img_augs()
+
+            height, width = original_image.shape[:2]
+            aug_input_ori = T.AugInput(copy.deepcopy(original_image))
+
+            aug_input_ori, _ = T.apply_transform_gens(augs, aug_input_ori)
+            image_ori = aug_input_ori.image
+            image_ori = torch.as_tensor(image_ori.astype("float32").transpose(2, 0, 1))
+
+            aug_input_crop = T.AugInput(copy.deepcopy(original_image))
+            transforms_crop = crop_augs(aug_input_crop)
+            image_crop = aug_input_crop.image
+            assert len(image_crop.shape)==4, "the image shape must be [N, H, W, C]"
+            image_crop = torch.as_tensor(image_crop.astype("float32").transpose(0, 3, 1, 2))
+            
+            for transform_type in transforms_crop:
+                if isinstance(transform_type, EntityCropTransform):
+                    crop_axises = transform_type.crop_axises
+                    crop_indexes = transform_type.crop_indexes
+
+            inputs = {"image": image_ori, 
+                      "height": height, 
+                      "width": width,
+                      "image_crop": image_crop,
+                      "crop_region": crop_axises,
+                      "crop_indexes": crop_indexes
+                      }
+            # pdb.set_trace()
+            predictions = self.model([inputs])[0]
+            return predictions
+
+class AsyncPredictor:
+    """
+    A predictor that runs the model asynchronously, possibly on more than one GPU.
+    Because rendering the visualization takes a considerable amount of time,
+    this helps improve throughput a little when rendering videos.
+    """
+
+    class _StopToken:
+        pass
+
+    class _PredictWorker(mp.Process):
+        def __init__(self, cfg, task_queue, result_queue):
+            self.cfg = cfg
+            self.task_queue = task_queue
+            self.result_queue = result_queue
+            super().__init__()
+
+        def run(self):
+            predictor = CropFormerPredictor(self.cfg)
+
+            while True:
+                task = self.task_queue.get()
+                if isinstance(task, AsyncPredictor._StopToken):
+                    break
+                idx, data = task
+                result = predictor(data)
+                self.result_queue.put((idx, result))
+
+    def __init__(self, cfg, num_gpus: int = 1):
+        """
+        Args:
+            cfg (CfgNode):
+            num_gpus (int): if 0, will run on CPU
+        """
+        num_workers = max(num_gpus, 1)
+        self.task_queue = mp.Queue(maxsize=num_workers * 3)
+        self.result_queue = mp.Queue(maxsize=num_workers * 3)
+        self.procs = []
+        for gpuid in range(max(num_gpus, 1)):
+            cfg = cfg.clone()
+            cfg.defrost()
+            cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
+            self.procs.append(
+                AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
+            )
+
+        self.put_idx = 0
+        self.get_idx = 0
+        self.result_rank = []
+        self.result_data = []
+
+        for p in self.procs:
+            p.start()
+        atexit.register(self.shutdown)
+
+    def put(self, image):
+        self.put_idx += 1
+        self.task_queue.put((self.put_idx, image))
+
+    def get(self):
+        self.get_idx += 1  # the index needed for this request
+        if len(self.result_rank) and self.result_rank[0] == self.get_idx:
+            res = self.result_data[0]
+            del self.result_data[0], self.result_rank[0]
+            return res
+
+        while True:
+            # make sure the results are returned in the correct order
+            idx, res = self.result_queue.get()
+            if idx == self.get_idx:
+                return res
+            insert = bisect.bisect(self.result_rank, idx)
+            self.result_rank.insert(insert, idx)
+            self.result_data.insert(insert, res)
+
+    def __len__(self):
+        return self.put_idx - self.get_idx
+
+    def __call__(self, image):
+        self.put(image)
+        return self.get()
+
+    def shutdown(self):
+        for _ in self.procs:
+            self.task_queue.put(AsyncPredictor._StopToken())
+
+    @property
+    def default_buffer_size(self):
+        return len(self.procs) * 5
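+
+
+# Minimal usage sketch (illustrative only; assumes a CropFormer CfgNode `cfg` and an
+# image loaded with OpenCV; "image.jpg" is a placeholder path):
+#
+#     demo = VisualizationDemo(cfg)                      # parallel=True spawns AsyncPredictor workers
+#     predictions = demo.run_on_image(cv2.imread("image.jpg"))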
diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56532c374df5c26f9ec53e2ac0dd924f4534bbdd
--- /dev/null
+++ b/annotator/hed/__init__.py
@@ -0,0 +1,132 @@
+import numpy as np
+import cv2
+import os
+import torch
+from einops import rearrange
+from annotator.util import annotator_ckpts_path
+
+
+class Network(torch.nn.Module):
+    def __init__(self, model_path):
+        super().__init__()
+
+        self.netVggOne = torch.nn.Sequential(
+            torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False)
+        )
+
+        self.netVggTwo = torch.nn.Sequential(
+            torch.nn.MaxPool2d(kernel_size=2, stride=2),
+            torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False)
+        )
+
+        self.netVggThr = torch.nn.Sequential(
+            torch.nn.MaxPool2d(kernel_size=2, stride=2),
+            torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False)
+        )
+
+        self.netVggFou = torch.nn.Sequential(
+            torch.nn.MaxPool2d(kernel_size=2, stride=2),
+            torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False)
+        )
+
+        self.netVggFiv = torch.nn.Sequential(
+            torch.nn.MaxPool2d(kernel_size=2, stride=2),
+            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False)
+        )
+
+        self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
+
+        self.netCombine = torch.nn.Sequential(
+            torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
+            torch.nn.Sigmoid()
+        )
+
+        self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
+
+    def forward(self, tenInput):
+        tenInput = tenInput * 255.0
+        tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)
+
+        tenVggOne = self.netVggOne(tenInput)
+        tenVggTwo = self.netVggTwo(tenVggOne)
+        tenVggThr = self.netVggThr(tenVggTwo)
+        tenVggFou = self.netVggFou(tenVggThr)
+        tenVggFiv = self.netVggFiv(tenVggFou)
+
+        tenScoreOne = self.netScoreOne(tenVggOne)
+        tenScoreTwo = self.netScoreTwo(tenVggTwo)
+        tenScoreThr = self.netScoreThr(tenVggThr)
+        tenScoreFou = self.netScoreFou(tenVggFou)
+        tenScoreFiv = self.netScoreFiv(tenVggFiv)
+
+        tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+        tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+        tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+        tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+        tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+
+        return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1))
+
+
+class HEDdetector:
+    def __init__(self):
+        remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth"
+        modelpath = os.path.join(annotator_ckpts_path, "network-bsds500.pth")
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+        self.netNetwork = Network(modelpath).cuda().eval()
+
+    def __call__(self, input_image):
+        assert input_image.ndim == 3
+        input_image = input_image[:, :, ::-1].copy()
+        with torch.no_grad():
+            image_hed = torch.from_numpy(input_image).float().cuda()
+            image_hed = image_hed / 255.0
+            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
+            edge = self.netNetwork(image_hed)[0]
+            edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
+            return edge[0]
+
+
+def nms(x, t, s):
+    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
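+    # Four 3x3 directional kernels (horizontal, vertical and the two diagonals); a pixel
+    # is kept only where it equals the maximum of its neighborhood along some direction.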
+    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+    y = np.zeros_like(x)
+
+    for f in [f1, f2, f3, f4]:
+        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+    z = np.zeros_like(y, dtype=np.uint8)
+    z[y > t] = 255
+    return z
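+
+
+# Minimal usage sketch (illustrative only; assumes a CUDA device and an HxWx3 uint8
+# image `img`; the threshold t and blur sigma s are example values, not tuned defaults):
+#
+#     apply_hed = HEDdetector()
+#     edge = apply_hed(img)                              # HxW uint8 soft edge map
+#     edge_thin = nms(edge, t=127, s=3.0)                # optional binarized thinning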
diff --git a/annotator/midas/__init__.py b/annotator/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8181dc8eaedf09e4fb698b350ee46a37dcacfe8
--- /dev/null
+++ b/annotator/midas/__init__.py
@@ -0,0 +1,35 @@
+import cv2
+import numpy as np
+import torch
+
+from einops import rearrange
+from .api import MiDaSInference
+
+
+class MidasDetector:
+    def __init__(self):
+        self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
+
+    def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
+        assert input_image.ndim == 3
+
+        oh, ow = input_image.shape[:2]
+        nh = oh // 32 * 32
+        nw = ow // 32 * 32
+        input_image = cv2.resize(input_image, (nw, nh))
+        image_depth = input_image
+        with torch.no_grad():
+            image_depth = torch.from_numpy(image_depth).float().cuda()
+            image_depth = image_depth / 127.5 - 1.0
+            image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+            depth = self.model(image_depth)[0]
+
+            depth_pt = depth.clone()
+            depth_pt -= torch.min(depth_pt)
+            depth_pt /= torch.max(depth_pt)
+            depth_pt = depth_pt.cpu().numpy()
+            depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+
+            depth_image = cv2.resize(depth_image, (nw, nh))
+
+            return depth_image
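+
+
+# Minimal usage sketch (illustrative only; assumes a CUDA device and an HxWx3 uint8
+# image `img`; the a and bg_th arguments are unused in this implementation):
+#
+#     apply_midas = MidasDetector()
+#     depth = apply_midas(img)                           # uint8 depth map scaled to 0..255,
+#                                                        # sides rounded down to multiples of 32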
diff --git a/annotator/midas/api.py b/annotator/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ab9f15bf96bbaffcee0e3e29fc9d3979d6c32e8
--- /dev/null
+++ b/annotator/midas/api.py
@@ -0,0 +1,169 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import os
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from .midas.dpt_depth import DPTDepthModel
+from .midas.midas_net import MidasNet
+from .midas.midas_net_custom import MidasNet_small
+from .midas.transforms import Resize, NormalizeImage, PrepareForNet
+from annotator.util import annotator_ckpts_path
+
+
+ISL_PATHS = {
+    "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
+    "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
+    "midas_v21": "",
+    "midas_v21_small": "",
+}
+
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def load_midas_transform(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load transform only
+    if model_type == "dpt_large":  # DPT-Large
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    elif model_type == "midas_v21_small":
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    else:
+        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return transform
+
+
+def load_model(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load network
+    model_path = ISL_PATHS[model_type]
+    if model_type == "dpt_large":  # DPT-Large
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitl16_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        if not os.path.exists(model_path):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitb_rn50_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        model = MidasNet(model_path, non_negative=True)
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    elif model_type == "midas_v21_small":
+        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+                               non_negative=True, blocks={'expand': True})
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    else:
+        print(f"model_type '{model_type}' not implemented, use: --model_type large")
+        assert False
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+    MODEL_TYPES_TORCH_HUB = [
+        "DPT_Large",
+        "DPT_Hybrid",
+        "MiDaS_small"
+    ]
+    MODEL_TYPES_ISL = [
+        "dpt_large",
+        "dpt_hybrid",
+        "midas_v21",
+        "midas_v21_small",
+    ]
+
+    def __init__(self, model_type):
+        super().__init__()
+        assert (model_type in self.MODEL_TYPES_ISL)
+        model, _ = load_model(model_type)
+        self.model = model
+        self.model.train = disabled_train
+
+    def forward(self, x):
+        with torch.no_grad():
+            prediction = self.model(x)
+        return prediction
+
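+
+# Minimal usage sketch (illustrative only; assumes the dpt_hybrid checkpoint can be
+# found or downloaded, and `img_float01` is an HxWx3 float image in [0, 1]):
+#
+#     transform = load_midas_transform("dpt_hybrid")
+#     model = MiDaSInference(model_type="dpt_hybrid").cuda()
+#     sample = transform({"image": img_float01})
+#     x = torch.from_numpy(sample["image"]).unsqueeze(0).cuda()
+#     depth = model(x)                                   # (1, H', W') relative inverse depth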
diff --git a/annotator/midas/midas/__init__.py b/annotator/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/annotator/midas/midas/base_model.py b/annotator/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/annotator/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+    def load(self, path):
+        """Load model from file.
+
+        Args:
+            path (str): file path
+        """
+        parameters = torch.load(path, map_location=torch.device('cpu'))
+
+        if "optimizer" in parameters:
+            parameters = parameters["model"]
+
+        self.load_state_dict(parameters)
diff --git a/annotator/midas/midas/blocks.py b/annotator/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/annotator/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+    _make_pretrained_vitb_rn50_384,
+    _make_pretrained_vitl16_384,
+    _make_pretrained_vitb16_384,
+    forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+    if backbone == "vitl16_384":
+        pretrained = _make_pretrained_vitl16_384(
+            use_pretrained, hooks=hooks, use_readout=use_readout
+        )
+        scratch = _make_scratch(
+            [256, 512, 1024, 1024], features, groups=groups, expand=expand
+        )  # ViT-L/16 - 85.0% Top1 (backbone)
+    elif backbone == "vitb_rn50_384":
+        pretrained = _make_pretrained_vitb_rn50_384(
+            use_pretrained,
+            hooks=hooks,
+            use_vit_only=use_vit_only,
+            use_readout=use_readout,
+        )
+        scratch = _make_scratch(
+            [256, 512, 768, 768], features, groups=groups, expand=expand
+        )  # ViT-H/16 - 85.0% Top1 (backbone)
+    elif backbone == "vitb16_384":
+        pretrained = _make_pretrained_vitb16_384(
+            use_pretrained, hooks=hooks, use_readout=use_readout
+        )
+        scratch = _make_scratch(
+            [96, 192, 384, 768], features, groups=groups, expand=expand
+        )  # ViT-B/16 - 84.6% Top1 (backbone)
+    elif backbone == "resnext101_wsl":
+        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl
+    elif backbone == "efficientnet_lite3":
+        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3     
+    else:
+        print(f"Backbone '{backbone}' not implemented")
+        assert False
+        
+    return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+    scratch = nn.Module()
+
+    out_shape1 = out_shape
+    out_shape2 = out_shape
+    out_shape3 = out_shape
+    out_shape4 = out_shape
+    if expand==True:
+        out_shape1 = out_shape
+        out_shape2 = out_shape*2
+        out_shape3 = out_shape*4
+        out_shape4 = out_shape*8
+
+    scratch.layer1_rn = nn.Conv2d(
+        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer2_rn = nn.Conv2d(
+        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer3_rn = nn.Conv2d(
+        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer4_rn = nn.Conv2d(
+        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+
+    return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+    efficientnet = torch.hub.load(
+        "rwightman/gen-efficientnet-pytorch",
+        "tf_efficientnet_lite3",
+        pretrained=use_pretrained,
+        exportable=exportable
+    )
+    return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+    pretrained = nn.Module()
+
+    pretrained.layer1 = nn.Sequential(
+        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+    )
+    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+    return pretrained
+    
+
+def _make_resnet_backbone(resnet):
+    pretrained = nn.Module()
+    pretrained.layer1 = nn.Sequential(
+        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+    )
+
+    pretrained.layer2 = resnet.layer2
+    pretrained.layer3 = resnet.layer3
+    pretrained.layer4 = resnet.layer4
+
+    return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+    return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+    """Interpolation module.
+    """
+
+    def __init__(self, scale_factor, mode, align_corners=False):
+        """Init.
+
+        Args:
+            scale_factor (float): scaling
+            mode (str): interpolation mode
+        """
+        super(Interpolate, self).__init__()
+
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: interpolated data
+        """
+
+        x = self.interp(
+            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+        )
+
+        return x
+
+
+class ResidualConvUnit(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True
+        )
+
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True
+        )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        out = self.relu(x)
+        out = self.conv1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+
+        return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock, self).__init__()
+
+        self.resConfUnit1 = ResidualConvUnit(features)
+        self.resConfUnit2 = ResidualConvUnit(features)
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            output += self.resConfUnit1(xs[1])
+
+        output = self.resConfUnit2(output)
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=True
+        )
+
+        return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features, activation, bn):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.bn = bn
+
+        self.groups=1
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+        
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+
+        if self.bn==True:
+            self.bn1 = nn.BatchNorm2d(features)
+            self.bn2 = nn.BatchNorm2d(features)
+
+        self.activation = activation
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        
+        out = self.activation(x)
+        out = self.conv1(out)
+        if self.bn==True:
+            out = self.bn1(out)
+       
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn==True:
+            out = self.bn2(out)
+
+        if self.groups > 1:
+            out = self.conv_merge(out)
+
+        return self.skip_add.add(out, x)
+
+        # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock_custom, self).__init__()
+
+        self.deconv = deconv
+        self.align_corners = align_corners
+
+        self.groups=1
+
+        self.expand = expand
+        out_features = features
+        if self.expand==True:
+            out_features = features//2
+        
+        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+        
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            res = self.resConfUnit1(xs[1])
+            output = self.skip_add.add(output, res)
+            # output += res
+
+        output = self.resConfUnit2(output)
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+        )
+
+        output = self.out_conv(output)
+
+        return output
+
diff --git a/annotator/midas/midas/dpt_depth.py b/annotator/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/annotator/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+    FeatureFusionBlock,
+    FeatureFusionBlock_custom,
+    Interpolate,
+    _make_encoder,
+    forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+    return FeatureFusionBlock_custom(
+        features,
+        nn.ReLU(False),
+        deconv=False,
+        bn=use_bn,
+        expand=False,
+        align_corners=True,
+    )
+
+
+class DPT(BaseModel):
+    def __init__(
+        self,
+        head,
+        features=256,
+        backbone="vitb_rn50_384",
+        readout="project",
+        channels_last=False,
+        use_bn=False,
+    ):
+
+        super(DPT, self).__init__()
+
+        self.channels_last = channels_last
+
+        hooks = {
+            "vitb_rn50_384": [0, 1, 8, 11],
+            "vitb16_384": [2, 5, 8, 11],
+            "vitl16_384": [5, 11, 17, 23],
+        }
+
+        # Instantiate backbone and reassemble blocks
+        self.pretrained, self.scratch = _make_encoder(
+            backbone,
+            features,
+            False,  # use_pretrained: set to True to initialize the backbone with ImageNet weights
+            groups=1,
+            expand=False,
+            exportable=False,
+            hooks=hooks[backbone],
+            use_readout=readout,
+        )
+
+        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+        self.scratch.output_conv = head
+
+
+    def forward(self, x):
+        if self.channels_last == True:
+            x.contiguous(memory_format=torch.channels_last)
+
+        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return out
+
+
+class DPTDepthModel(DPT):
+    def __init__(self, path=None, non_negative=True, **kwargs):
+        features = kwargs["features"] if "features" in kwargs else 256
+
+        head = nn.Sequential(
+            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+
+        super().__init__(head, **kwargs)
+
+        if path is not None:
+           self.load(path)
+
+    def forward(self, x):
+        return super().forward(x).squeeze(dim=1)
+
diff --git a/annotator/midas/midas/midas_net.py b/annotator/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/annotator/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=256, non_negative=True):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+            non_negative (bool, optional): Clamp the predicted depth to be non-negative. Defaults to True.
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet, self).__init__()
+
+        use_pretrained = False if path is None else True
+
+        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+        self.scratch.refinenet4 = FeatureFusionBlock(features)
+        self.scratch.refinenet3 = FeatureFusionBlock(features)
+        self.scratch.refinenet2 = FeatureFusionBlock(features)
+        self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+        )
+
+        if path:
+            self.load(path)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
diff --git a/annotator/midas/midas/midas_net_custom.py b/annotator/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/annotator/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+        blocks={'expand': True}):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 64.
+            backbone (str, optional): Backbone network for the encoder. Defaults to efficientnet_lite3.
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet_small, self).__init__()
+
+        use_pretrained = False if path else True
+                
+        self.channels_last = channels_last
+        self.blocks = blocks
+        self.backbone = backbone
+
+        self.groups = 1
+
+        features1=features
+        features2=features
+        features3=features
+        features4=features
+        self.expand = False
+        if "expand" in self.blocks and self.blocks['expand'] == True:
+            self.expand = True
+            features1=features
+            features2=features*2
+            features3=features*4
+            features4=features*8
+
+        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+  
+        self.scratch.activation = nn.ReLU(False)    
+
+        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+        
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+            self.scratch.activation,
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+        
+        if path:
+            self.load(path)
+
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+        if self.channels_last==True:
+            print("self.channels_last = ", self.channels_last)
+            x.contiguous(memory_format=torch.channels_last)
+
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+        
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+        
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
+
+
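+# Illustrative usage sketch (the checkpoint filename below is only an example):
+# build the small model and run it on a normalized image batch; the network
+# returns relative inverse depth at the input resolution.
+#
+#   model = MidasNet_small("midas_v21_small.pt", features=64,
+#                          backbone="efficientnet_lite3", exportable=True,
+#                          non_negative=True, blocks={'expand': True})
+#   model.eval()
+#   with torch.no_grad():
+#       x = torch.rand(1, 3, 256, 256)      # (B, 3, H, W), normalized RGB
+#       depth = model(x)                    # (B, H, W) relative inverse depth
+#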
+
+def fuse_model(m):
+    prev_previous_type = nn.Identity()
+    prev_previous_name = ''
+    previous_type = nn.Identity()
+    previous_name = ''
+    for name, module in m.named_modules():
+        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+            # print("FUSED ", prev_previous_name, previous_name, name)
+            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+            # print("FUSED ", prev_previous_name, previous_name)
+            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+        #    print("FUSED ", previous_name, name)
+        #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+        prev_previous_type = previous_type
+        prev_previous_name = previous_name
+        previous_type = type(module)
+        previous_name = name
\ No newline at end of file
diff --git a/annotator/midas/midas/transforms.py b/annotator/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/annotator/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+    Args:
+        sample (dict): sample
+        size (tuple): image size
+
+    Returns:
+        tuple: new size
+    """
+    shape = list(sample["disparity"].shape)
+
+    if shape[0] >= size[0] and shape[1] >= size[1]:
+        return sample
+
+    scale = [0, 0]
+    scale[0] = size[0] / shape[0]
+    scale[1] = size[1] / shape[1]
+
+    scale = max(scale)
+
+    shape[0] = math.ceil(scale * shape[0])
+    shape[1] = math.ceil(scale * shape[1])
+
+    # resize
+    sample["image"] = cv2.resize(
+        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+    )
+
+    sample["disparity"] = cv2.resize(
+        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+    )
+    sample["mask"] = cv2.resize(
+        sample["mask"].astype(np.float32),
+        tuple(shape[::-1]),
+        interpolation=cv2.INTER_NEAREST,
+    )
+    sample["mask"] = sample["mask"].astype(bool)
+
+    return tuple(shape)
+
+
+class Resize(object):
+    """Resize sample to given size (width, height).
+    """
+
+    def __init__(
+        self,
+        width,
+        height,
+        resize_target=True,
+        keep_aspect_ratio=False,
+        ensure_multiple_of=1,
+        resize_method="lower_bound",
+        image_interpolation_method=cv2.INTER_AREA,
+    ):
+        """Init.
+
+        Args:
+            width (int): desired output width
+            height (int): desired output height
+            resize_target (bool, optional):
+                True: Resize the full sample (image, mask, target).
+                False: Resize image only.
+                Defaults to True.
+            keep_aspect_ratio (bool, optional):
+                True: Keep the aspect ratio of the input sample.
+                Output sample might not have the given width and height, and
+                resize behaviour depends on the parameter 'resize_method'.
+                Defaults to False.
+            ensure_multiple_of (int, optional):
+                Output width and height is constrained to be multiple of this parameter.
+                Defaults to 1.
+            resize_method (str, optional):
+                "lower_bound": Output will be at least as large as the given size.
+                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+                "minimal": Scale as least as possible.  (Output size might be smaller than given size.)
+                Defaults to "lower_bound".
+        """
+        self.__width = width
+        self.__height = height
+
+        self.__resize_target = resize_target
+        self.__keep_aspect_ratio = keep_aspect_ratio
+        self.__multiple_of = ensure_multiple_of
+        self.__resize_method = resize_method
+        self.__image_interpolation_method = image_interpolation_method
+
+    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if max_val is not None and y > max_val:
+            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if y < min_val:
+            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        return y
+
+    def get_size(self, width, height):
+        # determine new height and width
+        scale_height = self.__height / height
+        scale_width = self.__width / width
+
+        if self.__keep_aspect_ratio:
+            if self.__resize_method == "lower_bound":
+                # scale such that output size is lower bound
+                if scale_width > scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "upper_bound":
+                # scale such that output size is upper bound
+                if scale_width < scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "minimal":
+                # scale as little as possible
+                if abs(1 - scale_width) < abs(1 - scale_height):
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            else:
+                raise ValueError(
+                    f"resize_method {self.__resize_method} not implemented"
+                )
+
+        if self.__resize_method == "lower_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, min_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, min_val=self.__width
+            )
+        elif self.__resize_method == "upper_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, max_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, max_val=self.__width
+            )
+        elif self.__resize_method == "minimal":
+            new_height = self.constrain_to_multiple_of(scale_height * height)
+            new_width = self.constrain_to_multiple_of(scale_width * width)
+        else:
+            raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+        return (new_width, new_height)
+
+    def __call__(self, sample):
+        width, height = self.get_size(
+            sample["image"].shape[1], sample["image"].shape[0]
+        )
+
+        # resize sample
+        sample["image"] = cv2.resize(
+            sample["image"],
+            (width, height),
+            interpolation=self.__image_interpolation_method,
+        )
+
+        if self.__resize_target:
+            if "disparity" in sample:
+                sample["disparity"] = cv2.resize(
+                    sample["disparity"],
+                    (width, height),
+                    interpolation=cv2.INTER_NEAREST,
+                )
+
+            if "depth" in sample:
+                sample["depth"] = cv2.resize(
+                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+                )
+
+            sample["mask"] = cv2.resize(
+                sample["mask"].astype(np.float32),
+                (width, height),
+                interpolation=cv2.INTER_NEAREST,
+            )
+            sample["mask"] = sample["mask"].astype(bool)
+
+        return sample
+
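+# Illustrative sketch of how the output size is chosen (values are only an
+# example): with keep_aspect_ratio=True, resize_method="lower_bound" and
+# ensure_multiple_of=32, both sides are scaled so the target is a lower bound
+# and then snapped to multiples of 32.
+#
+#   resize = Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
+#                   ensure_multiple_of=32, resize_method="lower_bound")
+#   sample = {"image": np.zeros((1080, 1920, 3), dtype=np.float32)}
+#   resize(sample)["image"].shape       # -> (384, 672, 3)
+#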
+
+class NormalizeImage(object):
+    """Normlize image by given mean and std.
+    """
+
+    def __init__(self, mean, std):
+        self.__mean = mean
+        self.__std = std
+
+    def __call__(self, sample):
+        sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+        return sample
+
+
+class PrepareForNet(object):
+    """Prepare sample for usage as network input.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, sample):
+        image = np.transpose(sample["image"], (2, 0, 1))
+        sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+        if "mask" in sample:
+            sample["mask"] = sample["mask"].astype(np.float32)
+            sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+        if "disparity" in sample:
+            disparity = sample["disparity"].astype(np.float32)
+            sample["disparity"] = np.ascontiguousarray(disparity)
+
+        if "depth" in sample:
+            depth = sample["depth"].astype(np.float32)
+            sample["depth"] = np.ascontiguousarray(depth)
+
+        return sample
diff --git a/annotator/midas/midas/vit.py b/annotator/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/annotator/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+    def __init__(self, start_index=1):
+        super(Slice, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+    def __init__(self, start_index=1):
+        super(AddReadout, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        if self.start_index == 2:
+            readout = (x[:, 0] + x[:, 1]) / 2
+        else:
+            readout = x[:, 0]
+        return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+    def __init__(self, in_features, start_index=1):
+        super(ProjectReadout, self).__init__()
+        self.start_index = start_index
+
+        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+    def forward(self, x):
+        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+        features = torch.cat((x[:, self.start_index :], readout), -1)
+
+        return self.project(features)
+
+
+class Transpose(nn.Module):
+    def __init__(self, dim0, dim1):
+        super(Transpose, self).__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x):
+        x = x.transpose(self.dim0, self.dim1)
+        return x
+
+
+def forward_vit(pretrained, x):
+    b, c, h, w = x.shape
+
+    glob = pretrained.model.forward_flex(x)
+
+    layer_1 = pretrained.activations["1"]
+    layer_2 = pretrained.activations["2"]
+    layer_3 = pretrained.activations["3"]
+    layer_4 = pretrained.activations["4"]
+
+    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+    layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+    unflatten = nn.Sequential(
+        nn.Unflatten(
+            2,
+            torch.Size(
+                [
+                    h // pretrained.model.patch_size[1],
+                    w // pretrained.model.patch_size[0],
+                ]
+            ),
+        )
+    )
+
+    if layer_1.ndim == 3:
+        layer_1 = unflatten(layer_1)
+    if layer_2.ndim == 3:
+        layer_2 = unflatten(layer_2)
+    if layer_3.ndim == 3:
+        layer_3 = unflatten(layer_3)
+    if layer_4.ndim == 3:
+        layer_4 = unflatten(layer_4)
+
+    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+    return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+    posemb_tok, posemb_grid = (
+        posemb[:, : self.start_index],
+        posemb[0, self.start_index :],
+    )
+
+    gs_old = int(math.sqrt(len(posemb_grid)))
+
+    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+    return posemb
+
+
+def forward_flex(self, x):
+    b, c, h, w = x.shape
+
+    pos_embed = self._resize_pos_embed(
+        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+    )
+
+    B = x.shape[0]
+
+    if hasattr(self.patch_embed, "backbone"):
+        x = self.patch_embed.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+
+    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+    if getattr(self, "dist_token", None) is not None:
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        dist_token = self.dist_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, dist_token, x), dim=1)
+    else:
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+
+    x = x + pos_embed
+    x = self.pos_drop(x)
+
+    for blk in self.blocks:
+        x = blk(x)
+
+    x = self.norm(x)
+
+    return x
+
+
+activations = {}
+
+
+def get_activation(name):
+    def hook(model, input, output):
+        activations[name] = output
+
+    return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+    if use_readout == "ignore":
+        readout_oper = [Slice(start_index)] * len(features)
+    elif use_readout == "add":
+        readout_oper = [AddReadout(start_index)] * len(features)
+    elif use_readout == "project":
+        readout_oper = [
+            ProjectReadout(vit_features, start_index) for out_feat in features
+        ]
+    else:
+        raise ValueError(
+            "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+        )
+
+    return readout_oper
+
+
+def _make_vit_b16_backbone(
+    model,
+    features=[96, 192, 384, 768],
+    size=[384, 384],
+    hooks=[2, 5, 8, 11],
+    vit_features=768,
+    use_readout="ignore",
+    start_index=1,
+):
+    pretrained = nn.Module()
+
+    pretrained.model = model
+    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+    pretrained.activations = activations
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    # 32, 48, 136, 384
+    pretrained.act_postprocess1 = nn.Sequential(
+        readout_oper[0],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[0],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[0],
+            out_channels=features[0],
+            kernel_size=4,
+            stride=4,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        ),
+    )
+
+    pretrained.act_postprocess2 = nn.Sequential(
+        readout_oper[1],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[1],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[1],
+            out_channels=features[1],
+            kernel_size=2,
+            stride=2,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        ),
+    )
+
+    pretrained.act_postprocess3 = nn.Sequential(
+        readout_oper[2],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[2],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+    )
+
+    pretrained.act_postprocess4 = nn.Sequential(
+        readout_oper[3],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[3],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+    return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+    hooks = [5, 11, 17, 23] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[256, 512, 1024, 1024],
+        hooks=hooks,
+        vit_features=1024,
+        use_readout=use_readout,
+    )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+    )
+
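+# Illustrative sketch (not a fixed recipe): because forward_flex interpolates the
+# position embeddings, the hooked backbone accepts inputs whose spatial size is
+# any multiple of the 16-pixel patch size, not just the 384x384 pretraining
+# resolution; forward_vit then returns the four re-assembled feature maps.
+#
+#   pretrained = _make_pretrained_vitb16_384(True, use_readout="project")
+#   x = torch.rand(1, 3, 480, 352)       # 30x22 patch grid
+#   layer_1, layer_2, layer_3, layer_4 = forward_vit(pretrained, x)
+#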
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+    )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model(
+        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+    )
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[96, 192, 384, 768],
+        hooks=hooks,
+        use_readout=use_readout,
+        start_index=2,
+    )
+
+
+def _make_vit_b_rn50_backbone(
+    model,
+    features=[256, 512, 768, 768],
+    size=[384, 384],
+    hooks=[0, 1, 8, 11],
+    vit_features=768,
+    use_vit_only=False,
+    use_readout="ignore",
+    start_index=1,
+):
+    pretrained = nn.Module()
+
+    pretrained.model = model
+
+    if use_vit_only:
+        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    else:
+        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+            get_activation("1")
+        )
+        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+            get_activation("2")
+        )
+
+    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+    pretrained.activations = activations
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    if use_vit_only:
+        pretrained.act_postprocess1 = nn.Sequential(
+            readout_oper[0],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[0],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+            nn.ConvTranspose2d(
+                in_channels=features[0],
+                out_channels=features[0],
+                kernel_size=4,
+                stride=4,
+                padding=0,
+                bias=True,
+                dilation=1,
+                groups=1,
+            ),
+        )
+
+        pretrained.act_postprocess2 = nn.Sequential(
+            readout_oper[1],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[1],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+            nn.ConvTranspose2d(
+                in_channels=features[1],
+                out_channels=features[1],
+                kernel_size=2,
+                stride=2,
+                padding=0,
+                bias=True,
+                dilation=1,
+                groups=1,
+            ),
+        )
+    else:
+        pretrained.act_postprocess1 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+        pretrained.act_postprocess2 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+
+    pretrained.act_postprocess3 = nn.Sequential(
+        readout_oper[2],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[2],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+    )
+
+    pretrained.act_postprocess4 = nn.Sequential(
+        readout_oper[3],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[3],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+    return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+    hooks = [0, 1, 8, 11] if hooks is None else hooks
+    return _make_vit_b_rn50_backbone(
+        model,
+        features=[256, 512, 768, 768],
+        size=[384, 384],
+        hooks=hooks,
+        use_vit_only=use_vit_only,
+        use_readout=use_readout,
+    )
diff --git a/annotator/midas/utils.py b/annotator/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/annotator/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+    """Read pfm file.
+
+    Args:
+        path (str): path to file
+
+    Returns:
+        tuple: (data, scale)
+    """
+    with open(path, "rb") as file:
+
+        color = None
+        width = None
+        height = None
+        scale = None
+        endian = None
+
+        header = file.readline().rstrip()
+        if header.decode("ascii") == "PF":
+            color = True
+        elif header.decode("ascii") == "Pf":
+            color = False
+        else:
+            raise Exception("Not a PFM file: " + path)
+
+        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+        if dim_match:
+            width, height = list(map(int, dim_match.groups()))
+        else:
+            raise Exception("Malformed PFM header.")
+
+        scale = float(file.readline().decode("ascii").rstrip())
+        if scale < 0:
+            # little-endian
+            endian = "<"
+            scale = -scale
+        else:
+            # big-endian
+            endian = ">"
+
+        data = np.fromfile(file, endian + "f")
+        shape = (height, width, 3) if color else (height, width)
+
+        data = np.reshape(data, shape)
+        data = np.flipud(data)
+
+        return data, scale
+
+
+def write_pfm(path, image, scale=1):
+    """Write pfm file.
+
+    Args:
+        path (str): path to file
+        image (array): data
+        scale (int, optional): Scale. Defaults to 1.
+    """
+
+    with open(path, "wb") as file:
+        color = None
+
+        if image.dtype.name != "float32":
+            raise Exception("Image dtype must be float32.")
+
+        image = np.flipud(image)
+
+        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
+            color = True
+        elif (
+            len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+        ):  # greyscale
+            color = False
+        else:
+            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+        file.write("PF\n" if color else "Pf\n".encode())
+        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+        endian = image.dtype.byteorder
+
+        if endian == "<" or endian == "=" and sys.byteorder == "little":
+            scale = -scale
+
+        file.write("%f\n".encode() % scale)
+
+        image.tofile(file)
+
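+# Illustrative round trip (the temporary path is only an example): a float32
+# depth map written as a single-channel PFM can be read back unchanged; the sign
+# of the scale field in the header encodes the byte order.
+#
+#   depth = np.random.rand(480, 640).astype(np.float32)
+#   write_pfm("/tmp/depth.pfm", depth)
+#   data, scale = read_pfm("/tmp/depth.pfm")
+#   assert data.shape == (480, 640) and scale == 1.0
+#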
+
+def read_image(path):
+    """Read image and output RGB image (0-1).
+
+    Args:
+        path (str): path to file
+
+    Returns:
+        array: RGB image (0-1)
+    """
+    img = cv2.imread(path)
+
+    if img.ndim == 2:
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+    return img
+
+
+def resize_image(img):
+    """Resize image and make it fit for network.
+
+    Args:
+        img (array): image
+
+    Returns:
+        tensor: data ready for network
+    """
+    height_orig = img.shape[0]
+    width_orig = img.shape[1]
+
+    if width_orig > height_orig:
+        scale = width_orig / 384
+    else:
+        scale = height_orig / 384
+
+    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+    img_resized = (
+        torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+    )
+    img_resized = img_resized.unsqueeze(0)
+
+    return img_resized
+
+
+def resize_depth(depth, width, height):
+    """Resize depth map and bring to CPU (numpy).
+
+    Args:
+        depth (tensor): depth
+        width (int): image width
+        height (int): image height
+
+    Returns:
+        array: processed depth
+    """
+    depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+    depth_resized = cv2.resize(
+        depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+    )
+
+    return depth_resized
+
+def write_depth(path, depth, bits=1):
+    """Write depth map to pfm and png file.
+
+    Args:
+        path (str): filepath without extension
+        depth (array): depth
+    """
+    write_pfm(path + ".pfm", depth.astype(np.float32))
+
+    depth_min = depth.min()
+    depth_max = depth.max()
+
+    max_val = (2**(8*bits))-1
+
+    if depth_max - depth_min > np.finfo("float").eps:
+        out = max_val * (depth - depth_min) / (depth_max - depth_min)
+    else:
+        out = np.zeros(depth.shape, dtype=depth.dtype)
+
+    if bits == 1:
+        cv2.imwrite(path + ".png", out.astype("uint8"))
+    elif bits == 2:
+        cv2.imwrite(path + ".png", out.astype("uint16"))
+
+    return
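+
+
+# Illustrative usage (the output prefix is only an example): rescale a relative
+# depth prediction to the full 16-bit range and store it as <path>.pfm and
+# <path>.png.
+#
+#   prediction = np.random.rand(480, 640).astype(np.float32)
+#   write_depth("/tmp/depth", prediction, bits=2)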
diff --git a/annotator/openpose/__init__.py b/annotator/openpose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e259c0d5bc9362c83c1b5cbe695f798aa349ed32
--- /dev/null
+++ b/annotator/openpose/__init__.py
@@ -0,0 +1,44 @@
+import os
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+import torch
+import numpy as np
+from . import util
+from .body import Body
+from .hand import Hand
+from annotator.util import annotator_ckpts_path
+
+
+body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
+hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"
+
+
+class OpenposeDetector:
+    def __init__(self):
+        body_modelpath = os.path.join(annotator_ckpts_path, "body_pose_model.pth")
+        hand_modelpath = os.path.join(annotator_ckpts_path, "hand_pose_model.pth")
+
+        if not os.path.exists(hand_modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(body_model_path, model_dir=annotator_ckpts_path)
+            load_file_from_url(hand_model_path, model_dir=annotator_ckpts_path)
+
+        self.body_estimation = Body(body_modelpath)
+        self.hand_estimation = Hand(hand_modelpath)
+
+    def __call__(self, oriImg, hand=False):
+        oriImg = oriImg[:, :, ::-1].copy()
+        with torch.no_grad():
+            candidate, subset = self.body_estimation(oriImg)
+            canvas = np.zeros_like(oriImg)
+            canvas = util.draw_bodypose(canvas, candidate, subset)
+            if hand:
+                hands_list = util.handDetect(candidate, subset, oriImg)
+                all_hand_peaks = []
+                for x, y, w, is_left in hands_list:
+                    peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :])
+                    peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x)
+                    peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y)
+                    all_hand_peaks.append(peaks)
+                canvas = util.draw_handpose(canvas, all_hand_peaks)
+            return canvas 
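+
+
+# Illustrative usage sketch (the image path is only an example): the detector
+# downloads the body/hand checkpoints on first use, takes an HxWx3 uint8 image
+# in RGB order (the channels are flipped internally for the pose models) and
+# returns an HxWx3 uint8 rendering of the detected skeleton, plus hand keypoints
+# when hand=True, drawn on a black canvas of the same size.
+#
+#   import cv2
+#   detector = OpenposeDetector()
+#   img = cv2.cvtColor(cv2.imread("example.png"), cv2.COLOR_BGR2RGB)
+#   pose_map = detector(img, hand=True)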
diff --git a/annotator/openpose/body.py b/annotator/openpose/body.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c3cf7a388b4ac81004524e64125e383bdd455bd
--- /dev/null
+++ b/annotator/openpose/body.py
@@ -0,0 +1,219 @@
+import cv2
+import numpy as np
+import math
+import time
+from scipy.ndimage import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from torchvision import transforms
+
+from . import util
+from .model import bodypose_model
+
+class Body(object):
+    def __init__(self, model_path):
+        self.model = bodypose_model()
+        if torch.cuda.is_available():
+            self.model = self.model.cuda()
+            print('cuda')
+        model_dict = util.transfer(self.model, torch.load(model_path))
+        self.model.load_state_dict(model_dict)
+        self.model.eval()
+
+    def __call__(self, oriImg):
+        # scale_search = [0.5, 1.0, 1.5, 2.0]
+        scale_search = [0.5]
+        boxsize = 368
+        stride = 8
+        padValue = 128
+        thre1 = 0.1
+        thre2 = 0.05
+        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
+        paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+        for m in range(len(multiplier)):
+            scale = multiplier[m]
+            imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+            im = np.ascontiguousarray(im)
+
+            data = torch.from_numpy(im).float()
+            if torch.cuda.is_available():
+                data = data.cuda()
+            # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+            with torch.no_grad():
+                Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
+            Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
+            Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
+
+            # extract outputs, resize, and remove padding
+            # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0))  # output 1 is heatmaps
+            heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0))  # output 1 is heatmaps
+            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+            heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+            # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0))  # output 0 is PAFs
+            paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
+            paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+            paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+            paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+            heatmap_avg += heatmap / len(multiplier)
+            paf_avg += paf / len(multiplier)
+
+        all_peaks = []
+        peak_counter = 0
+
+        for part in range(18):
+            map_ori = heatmap_avg[:, :, part]
+            one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+            map_left = np.zeros(one_heatmap.shape)
+            map_left[1:, :] = one_heatmap[:-1, :]
+            map_right = np.zeros(one_heatmap.shape)
+            map_right[:-1, :] = one_heatmap[1:, :]
+            map_up = np.zeros(one_heatmap.shape)
+            map_up[:, 1:] = one_heatmap[:, :-1]
+            map_down = np.zeros(one_heatmap.shape)
+            map_down[:, :-1] = one_heatmap[:, 1:]
+
+            peaks_binary = np.logical_and.reduce(
+                (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
+            peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]))  # note reverse
+            peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
+            peak_id = range(peak_counter, peak_counter + len(peaks))
+            peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
+
+            all_peaks.append(peaks_with_score_and_id)
+            peak_counter += len(peaks)
+
+        # find connection in the specified sequence, center 29 is in the position 15
+        limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+                   [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+                   [1, 16], [16, 18], [3, 17], [6, 18]]
+        # the middle joints heatmap correspondence
+        mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+                  [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+                  [55, 56], [37, 38], [45, 46]]
+
+        connection_all = []
+        special_k = []
+        mid_num = 10
+
+        for k in range(len(mapIdx)):
+            score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
+            candA = all_peaks[limbSeq[k][0] - 1]
+            candB = all_peaks[limbSeq[k][1] - 1]
+            nA = len(candA)
+            nB = len(candB)
+            indexA, indexB = limbSeq[k]
+            if (nA != 0 and nB != 0):
+                connection_candidate = []
+                for i in range(nA):
+                    for j in range(nB):
+                        vec = np.subtract(candB[j][:2], candA[i][:2])
+                        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
+                        norm = max(0.001, norm)
+                        vec = np.divide(vec, norm)
+
+                        startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
+                                            np.linspace(candA[i][1], candB[j][1], num=mid_num)))
+
+                        vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+                                          for I in range(len(startend))])
+                        vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+                                          for I in range(len(startend))])
+
+                        score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+                        score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
+                            0.5 * oriImg.shape[0] / norm - 1, 0)
+                        criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
+                        criterion2 = score_with_dist_prior > 0
+                        if criterion1 and criterion2:
+                            connection_candidate.append(
+                                [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+
+                connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+                connection = np.zeros((0, 5))
+                for c in range(len(connection_candidate)):
+                    i, j, s = connection_candidate[c][0:3]
+                    if (i not in connection[:, 3] and j not in connection[:, 4]):
+                        connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+                        if (len(connection) >= min(nA, nB)):
+                            break
+
+                connection_all.append(connection)
+            else:
+                special_k.append(k)
+                connection_all.append([])
+
+        # last number in each row is the total parts number of that person
+        # the second last number in each row is the score of the overall configuration
+        subset = -1 * np.ones((0, 20))
+        candidate = np.array([item for sublist in all_peaks for item in sublist])
+
+        for k in range(len(mapIdx)):
+            if k not in special_k:
+                partAs = connection_all[k][:, 0]
+                partBs = connection_all[k][:, 1]
+                indexA, indexB = np.array(limbSeq[k]) - 1
+
+                for i in range(len(connection_all[k])):  # = 1:size(temp,1)
+                    found = 0
+                    subset_idx = [-1, -1]
+                    for j in range(len(subset)):  # 1:size(subset,1):
+                        if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
+                            subset_idx[found] = j
+                            found += 1
+
+                    if found == 1:
+                        j = subset_idx[0]
+                        if subset[j][indexB] != partBs[i]:
+                            subset[j][indexB] = partBs[i]
+                            subset[j][-1] += 1
+                            subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+                    elif found == 2:  # if found 2 and disjoint, merge them
+                        j1, j2 = subset_idx
+                        membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
+                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
+                            subset[j1][:-2] += (subset[j2][:-2] + 1)
+                            subset[j1][-2:] += subset[j2][-2:]
+                            subset[j1][-2] += connection_all[k][i][2]
+                            subset = np.delete(subset, j2, 0)
+                        else:  # as like found == 1
+                            subset[j1][indexB] = partBs[i]
+                            subset[j1][-1] += 1
+                            subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+
+                    # if find no partA in the subset, create a new subset
+                    elif not found and k < 17:
+                        row = -1 * np.ones(20)
+                        row[indexA] = partAs[i]
+                        row[indexB] = partBs[i]
+                        row[-1] = 2
+                        row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
+                        subset = np.vstack([subset, row])
+        # delete rows of subset that have too few parts
+        deleteIdx = []
+        for i in range(len(subset)):
+            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
+                deleteIdx.append(i)
+        subset = np.delete(subset, deleteIdx, axis=0)
+
+        # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
+        # candidate: x, y, score, id
+        return candidate, subset
+
+if __name__ == "__main__":
+    body_estimation = Body('../model/body_pose_model.pth')
+
+    test_image = '../images/ski.jpg'
+    oriImg = cv2.imread(test_image)  # B,G,R order
+    candidate, subset = body_estimation(oriImg)
+    canvas = util.draw_bodypose(oriImg, candidate, subset)
+    plt.imshow(canvas[:, :, [2, 1, 0]])
+    plt.show()
diff --git a/annotator/openpose/hand.py b/annotator/openpose/hand.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0bf17165ad7eb225332b51f4a2aa16718664b2
--- /dev/null
+++ b/annotator/openpose/hand.py
@@ -0,0 +1,86 @@
+import cv2
+import json
+import numpy as np
+import math
+import time
+from scipy.ndimage import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from skimage.measure import label
+
+from .model import handpose_model
+from . import util
+
+class Hand(object):
+    def __init__(self, model_path):
+        self.model = handpose_model()
+        if torch.cuda.is_available():
+            self.model = self.model.cuda()
+            print('cuda')
+        model_dict = util.transfer(self.model, torch.load(model_path))
+        self.model.load_state_dict(model_dict)
+        self.model.eval()
+
+    def __call__(self, oriImg):
+        scale_search = [0.5, 1.0, 1.5, 2.0]
+        # scale_search = [0.5]
+        boxsize = 368
+        stride = 8
+        padValue = 128
+        thre = 0.05
+        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
+        # paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+        for m in range(len(multiplier)):
+            scale = multiplier[m]
+            imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+            im = np.ascontiguousarray(im)
+
+            data = torch.from_numpy(im).float()
+            if torch.cuda.is_available():
+                data = data.cuda()
+            # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+            with torch.no_grad():
+                output = self.model(data).cpu().numpy()
+                # output = self.model(data).numpy()
+
+            # extract outputs, resize, and remove padding
+            heatmap = np.transpose(np.squeeze(output), (1, 2, 0))  # output 1 is heatmaps
+            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+            heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+            heatmap_avg += heatmap / len(multiplier)
+
+        all_peaks = []
+        for part in range(21):
+            map_ori = heatmap_avg[:, :, part]
+            one_heatmap = gaussian_filter(map_ori, sigma=3)
+            binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
+            # all values are below the threshold
+            if np.sum(binary) == 0:
+                all_peaks.append([0, 0])
+                continue
+            label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
+            max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
+            label_img[label_img != max_index] = 0
+            map_ori[label_img == 0] = 0
+
+            y, x = util.npmax(map_ori)
+            all_peaks.append([x, y])
+        return np.array(all_peaks)
+
+if __name__ == "__main__":
+    hand_estimation = Hand('../model/hand_pose_model.pth')
+
+    # test_image = '../images/hand.jpg'
+    test_image = '../images/hand.jpg'
+    oriImg = cv2.imread(test_image)  # B,G,R order
+    peaks = hand_estimation(oriImg)
+    canvas = util.draw_handpose(oriImg, peaks, True)
+    cv2.imshow('', canvas)
+    cv2.waitKey(0)
\ No newline at end of file
diff --git a/annotator/openpose/model.py b/annotator/openpose/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dfc80de827a17beccb9b0f3f7588545be78c9de
--- /dev/null
+++ b/annotator/openpose/model.py
@@ -0,0 +1,219 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+def make_layers(block, no_relu_layers):
+    layers = []
+    for layer_name, v in block.items():
+        if 'pool' in layer_name:
+            layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
+                                    padding=v[2])
+            layers.append((layer_name, layer))
+        else:
+            conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
+                               kernel_size=v[2], stride=v[3],
+                               padding=v[4])
+            layers.append((layer_name, conv2d))
+            if layer_name not in no_relu_layers:
+                layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
+
+    return nn.Sequential(OrderedDict(layers))
+
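+# Illustrative sketch of the block specification consumed by make_layers: conv
+# entries are [in_channels, out_channels, kernel, stride, padding] and pooling
+# entries are [kernel, stride, padding]; a ReLU is appended after every conv
+# whose name is not listed in no_relu_layers.
+#
+#   block = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]),
+#                        ('pool1_stage1', [2, 2, 0]),
+#                        ('conv1_2', [64, 38, 1, 1, 0])])
+#   head = make_layers(block, no_relu_layers=['conv1_2'])
+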
+class bodypose_model(nn.Module):
+    def __init__(self):
+        super(bodypose_model, self).__init__()
+
+        # these layers have no relu layer
+        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
+                          'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
+                          'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
+                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
+        blocks = {}
+        block0 = OrderedDict([
+                      ('conv1_1', [3, 64, 3, 1, 1]),
+                      ('conv1_2', [64, 64, 3, 1, 1]),
+                      ('pool1_stage1', [2, 2, 0]),
+                      ('conv2_1', [64, 128, 3, 1, 1]),
+                      ('conv2_2', [128, 128, 3, 1, 1]),
+                      ('pool2_stage1', [2, 2, 0]),
+                      ('conv3_1', [128, 256, 3, 1, 1]),
+                      ('conv3_2', [256, 256, 3, 1, 1]),
+                      ('conv3_3', [256, 256, 3, 1, 1]),
+                      ('conv3_4', [256, 256, 3, 1, 1]),
+                      ('pool3_stage1', [2, 2, 0]),
+                      ('conv4_1', [256, 512, 3, 1, 1]),
+                      ('conv4_2', [512, 512, 3, 1, 1]),
+                      ('conv4_3_CPM', [512, 256, 3, 1, 1]),
+                      ('conv4_4_CPM', [256, 128, 3, 1, 1])
+                  ])
+
+
+        # Stage 1
+        block1_1 = OrderedDict([
+                        ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
+                        ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+                        ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
+                        ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+                        ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
+                    ])
+
+        block1_2 = OrderedDict([
+                        ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
+                        ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+                        ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
+                        ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+                        ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
+                    ])
+        blocks['block1_1'] = block1_1
+        blocks['block1_2'] = block1_2
+
+        self.model0 = make_layers(block0, no_relu_layers)
+
+        # Stages 2 - 6
+        for i in range(2, 7):
+            blocks['block%d_1' % i] = OrderedDict([
+                    ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+                    ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+                    ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
+                ])
+
+            blocks['block%d_2' % i] = OrderedDict([
+                    ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+                    ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+                    ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
+                ])
+
+        for k in blocks.keys():
+            blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+        self.model1_1 = blocks['block1_1']
+        self.model2_1 = blocks['block2_1']
+        self.model3_1 = blocks['block3_1']
+        self.model4_1 = blocks['block4_1']
+        self.model5_1 = blocks['block5_1']
+        self.model6_1 = blocks['block6_1']
+
+        self.model1_2 = blocks['block1_2']
+        self.model2_2 = blocks['block2_2']
+        self.model3_2 = blocks['block3_2']
+        self.model4_2 = blocks['block4_2']
+        self.model5_2 = blocks['block5_2']
+        self.model6_2 = blocks['block6_2']
+
+
+    def forward(self, x):
+
+        out1 = self.model0(x)
+
+        out1_1 = self.model1_1(out1)
+        out1_2 = self.model1_2(out1)
+        out2 = torch.cat([out1_1, out1_2, out1], 1)
+
+        out2_1 = self.model2_1(out2)
+        out2_2 = self.model2_2(out2)
+        out3 = torch.cat([out2_1, out2_2, out1], 1)
+
+        out3_1 = self.model3_1(out3)
+        out3_2 = self.model3_2(out3)
+        out4 = torch.cat([out3_1, out3_2, out1], 1)
+
+        out4_1 = self.model4_1(out4)
+        out4_2 = self.model4_2(out4)
+        out5 = torch.cat([out4_1, out4_2, out1], 1)
+
+        out5_1 = self.model5_1(out5)
+        out5_2 = self.model5_2(out5)
+        out6 = torch.cat([out5_1, out5_2, out1], 1)
+
+        out6_1 = self.model6_1(out6)
+        out6_2 = self.model6_2(out6)
+
+        return out6_1, out6_2
+
+class handpose_model(nn.Module):
+    def __init__(self):
+        super(handpose_model, self).__init__()
+
+        # these layers have no relu layer
+        no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
+                          'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
+        # stage 1
+        block1_0 = OrderedDict([
+                ('conv1_1', [3, 64, 3, 1, 1]),
+                ('conv1_2', [64, 64, 3, 1, 1]),
+                ('pool1_stage1', [2, 2, 0]),
+                ('conv2_1', [64, 128, 3, 1, 1]),
+                ('conv2_2', [128, 128, 3, 1, 1]),
+                ('pool2_stage1', [2, 2, 0]),
+                ('conv3_1', [128, 256, 3, 1, 1]),
+                ('conv3_2', [256, 256, 3, 1, 1]),
+                ('conv3_3', [256, 256, 3, 1, 1]),
+                ('conv3_4', [256, 256, 3, 1, 1]),
+                ('pool3_stage1', [2, 2, 0]),
+                ('conv4_1', [256, 512, 3, 1, 1]),
+                ('conv4_2', [512, 512, 3, 1, 1]),
+                ('conv4_3', [512, 512, 3, 1, 1]),
+                ('conv4_4', [512, 512, 3, 1, 1]),
+                ('conv5_1', [512, 512, 3, 1, 1]),
+                ('conv5_2', [512, 512, 3, 1, 1]),
+                ('conv5_3_CPM', [512, 128, 3, 1, 1])
+            ])
+
+        block1_1 = OrderedDict([
+            ('conv6_1_CPM', [128, 512, 1, 1, 0]),
+            ('conv6_2_CPM', [512, 22, 1, 1, 0])
+        ])
+
+        blocks = {}
+        blocks['block1_0'] = block1_0
+        blocks['block1_1'] = block1_1
+
+        # stage 2-6
+        for i in range(2, 7):
+            blocks['block%d' % i] = OrderedDict([
+                    ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
+                    ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
+                    ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
+                ])
+
+        for k in blocks.keys():
+            blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+        self.model1_0 = blocks['block1_0']
+        self.model1_1 = blocks['block1_1']
+        self.model2 = blocks['block2']
+        self.model3 = blocks['block3']
+        self.model4 = blocks['block4']
+        self.model5 = blocks['block5']
+        self.model6 = blocks['block6']
+
+    def forward(self, x):
+        out1_0 = self.model1_0(x)
+        out1_1 = self.model1_1(out1_0)
+        concat_stage2 = torch.cat([out1_1, out1_0], 1)
+        out_stage2 = self.model2(concat_stage2)
+        concat_stage3 = torch.cat([out_stage2, out1_0], 1)
+        out_stage3 = self.model3(concat_stage3)
+        concat_stage4 = torch.cat([out_stage3, out1_0], 1)
+        out_stage4 = self.model4(concat_stage4)
+        concat_stage5 = torch.cat([out_stage4, out1_0], 1)
+        out_stage5 = self.model5(concat_stage5)
+        concat_stage6 = torch.cat([out_stage5, out1_0], 1)
+        out_stage6 = self.model6(concat_stage6)
+        return out_stage6
+
+
diff --git a/annotator/openpose/util.py b/annotator/openpose/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f91ae0e65abaf0cbd62d803f56498991141e61b
--- /dev/null
+++ b/annotator/openpose/util.py
@@ -0,0 +1,164 @@
+import math
+import numpy as np
+import matplotlib
+import cv2
+
+
+def padRightDownCorner(img, stride, padValue):
+    h = img.shape[0]
+    w = img.shape[1]
+
+    pad = 4 * [None]
+    pad[0] = 0 # up
+    pad[1] = 0 # left
+    pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
+    pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
+
+    img_padded = img
+    pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
+    img_padded = np.concatenate((pad_up, img_padded), axis=0)
+    pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
+    img_padded = np.concatenate((pad_left, img_padded), axis=1)
+    pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
+    img_padded = np.concatenate((img_padded, pad_down), axis=0)
+    pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
+    img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+    return img_padded, pad
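+
+# Illustrative usage of padRightDownCorner (stride and pad value assumed):
+#   img_padded, pad = padRightDownCorner(img, stride=8, padValue=128)
+#   # crop a prediction computed on img_padded back to the original size:
+#   out = out[:img_padded.shape[0] - pad[2], :img_padded.shape[1] - pad[3], :]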
+
+# remap Caffe-exported weights to the PyTorch model's layer names
+def transfer(model, model_weights):
+    transfered_model_weights = {}
+    for weights_name in model.state_dict().keys():
+        transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
+    return transfered_model_weights
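+
+# Illustrative usage (checkpoint path assumed):
+#   model = handpose_model()
+#   model.load_state_dict(transfer(model, torch.load('ckpts/hand_pose_model.pth')))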
+
+# draw the body keypoints and limbs
+def draw_bodypose(canvas, candidate, subset):
+    stickwidth = 4
+    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+               [1, 16], [16, 18], [3, 17], [6, 18]]
+
+    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+    for i in range(18):
+        for n in range(len(subset)):
+            index = int(subset[n][i])
+            if index == -1:
+                continue
+            x, y = candidate[index][0:2]
+            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
+    for i in range(17):
+        for n in range(len(subset)):
+            index = subset[n][np.array(limbSeq[i]) - 1]
+            if -1 in index:
+                continue
+            cur_canvas = canvas.copy()
+            Y = candidate[index.astype(int), 0]
+            X = candidate[index.astype(int), 1]
+            mX = np.mean(X)
+            mY = np.mean(Y)
+            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+            cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
+            canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+    # plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]])
+    # plt.imshow(canvas[:, :, [2, 1, 0]])
+    return canvas
+
+
+# Note: images drawn by OpenCV alone do not look great.
+def draw_handpose(canvas, all_hand_peaks, show_number=False):
+    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+
+    for peaks in all_hand_peaks:
+        for ie, e in enumerate(edges):
+            if np.sum(np.all(peaks[e], axis=1)==0)==0:
+                x1, y1 = peaks[e[0]]
+                x2, y2 = peaks[e[1]]
+                cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie/float(len(edges)), 1.0, 1.0])*255, thickness=2)
+
+        for i, keypoint in enumerate(peaks):
+            x, y = keypoint
+            cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+            if show_number:
+                cv2.putText(canvas, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA)
+    return canvas
+
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(candidate, subset, oriImg):
+    # right hand: wrist 4, elbow 3, shoulder 2
+    # left hand: wrist 7, elbow 6, shoulder 5
+    ratioWristElbow = 0.33
+    detect_result = []
+    image_height, image_width = oriImg.shape[0:2]
+    for person in subset.astype(int):
+        # a hand is used only if all three keypoints (shoulder, elbow, wrist) are detected
+        has_left = np.sum(person[[5, 6, 7]] == -1) == 0
+        has_right = np.sum(person[[2, 3, 4]] == -1) == 0
+        if not (has_left or has_right):
+            continue
+        hands = []
+        #left hand
+        if has_left:
+            left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
+            x1, y1 = candidate[left_shoulder_index][:2]
+            x2, y2 = candidate[left_elbow_index][:2]
+            x3, y3 = candidate[left_wrist_index][:2]
+            hands.append([x1, y1, x2, y2, x3, y3, True])
+        # right hand
+        if has_right:
+            right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
+            x1, y1 = candidate[right_shoulder_index][:2]
+            x2, y2 = candidate[right_elbow_index][:2]
+            x3, y3 = candidate[right_wrist_index][:2]
+            hands.append([x1, y1, x2, y2, x3, y3, False])
+
+        for x1, y1, x2, y2, x3, y3, is_left in hands:
+            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbow) = (1 + ratio) * pos_wrist - ratio * pos_elbow
+            # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
+            # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
+            # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
+            # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
+            # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
+            x = x3 + ratioWristElbow * (x3 - x2)
+            y = y3 + ratioWristElbow * (y3 - y2)
+            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
+            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+            # x-y refers to the center --> offset to topLeft point
+            # handRectangle.x -= handRectangle.width / 2.f;
+            # handRectangle.y -= handRectangle.height / 2.f;
+            x -= width / 2
+            y -= width / 2  # width = height
+            # clip the box so it does not overflow the image
+            if x < 0: x = 0
+            if y < 0: y = 0
+            width1 = width
+            width2 = width
+            if x + width > image_width: width1 = image_width - x
+            if y + width > image_height: width2 = image_height - y
+            width = min(width1, width2)
+            # keep the box only if it is at least 20 pixels wide
+            if width >= 20:
+                detect_result.append([int(x), int(y), int(width), is_left])
+
+    '''
+    return value: [[x, y, w, True if left hand else False], ...].
+    width == height since the network requires a square input.
+    (x, y) is the coordinate of the top-left corner of the box.
+    '''
+    return detect_result
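+
+# Illustrative usage: crop the square hand regions from the original image.
+#   for x, y, w, is_left in handDetect(candidate, subset, oriImg):
+#       hand_crop = oriImg[y:y + w, x:x + w, :]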
+
+# get the (row, column) index of the maximum entry of a 2d array
+def npmax(array):
+    arrayindex = array.argmax(1)
+    arrayvalue = array.max(1)
+    i = arrayvalue.argmax()
+    j = arrayindex[i]
+    return i, j
diff --git a/annotator/util.py b/annotator/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05
--- /dev/null
+++ b/annotator/util.py
@@ -0,0 +1,38 @@
+import numpy as np
+import cv2
+import os
+
+
+annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
+
+
+def HWC3(x):
+    assert x.dtype == np.uint8
+    if x.ndim == 2:
+        x = x[:, :, None]
+    assert x.ndim == 3
+    H, W, C = x.shape
+    assert C == 1 or C == 3 or C == 4
+    if C == 3:
+        return x
+    if C == 1:
+        return np.concatenate([x, x, x], axis=2)
+    if C == 4:
+        color = x[:, :, 0:3].astype(np.float32)
+        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+        y = color * alpha + 255.0 * (1.0 - alpha)
+        y = y.clip(0, 255).astype(np.uint8)
+        return y
+
+
+def resize_image(input_image, resolution):
+    H, W, C = input_image.shape
+    H = float(H)
+    W = float(W)
+    k = float(resolution) / min(H, W)
+    H *= k
+    W *= k
+    H = int(np.round(H / 64.0)) * 64
+    W = int(np.round(W / 64.0)) * 64
+    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+    return img
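+
+# Illustrative usage: convert any input to 3-channel uint8 and snap its short
+# side to the requested resolution (rounded to a multiple of 64):
+#   img = resize_image(HWC3(raw_image), 512)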
diff --git a/configs/anycontrol.yaml b/configs/anycontrol.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fcd9b1b6f77d4fe5220fcbb2288b68d6e7d82b28
--- /dev/null
+++ b/configs/anycontrol.yaml
@@ -0,0 +1,109 @@
+model:
+  target: models.anycontrol.AnyControlNet
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+    mode: uni
+
+    qformer_config:
+      target: models.q_formers.blip2_qformer.Blip2Qformer
+      qformer_enabled: true
+      model_name: "blip2"
+      model_type: "pretrain"
+      pretrained: "ckpts/blip2_pretrained.pth"
+      params:
+        img_size: 224
+        drop_path_rate: 0
+        use_grad_checkpoint: False
+        vit_precision: "fp16"
+        num_query_token: 256
+        max_txt_len: 32
+        query_token_init_type: "uniform"
+        max_position_embeddings: 512
+        multilevels: [3, 10, 17, 24, 31, 38]
+
+    local_control_config:
+      target: models.local_adapter.LocalAdapter
+      params:
+        in_channels: 4
+        model_channels: 320
+        local_channels: 3
+        inject_channels: [192, 256, 384, 512]
+        inject_layers: [1, 4, 7, 10]
+        query_channels: [768, 768, 768, 768]
+        query_layers: [4, 6, 8, 12]
+        query_scales: [4, 2, 1, 0.5] 
+        num_res_blocks: 2
+        attention_resolutions: [4, 2, 1]
+        channel_mult: [1, 2, 4, 4]
+        use_checkpoint: True
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        legacy: False
+
+    global_control_config:
+      target: models.global_adapter.GlobalAdapter
+      params:
+        cross_attention_dim: 768
+        clip_embeddings_dim: 768
+        context_tokens: 4
+        color_in_dim: 180
+
+    unet_config:
+      target: models.local_adapter.LocalControlUNetModel
+      params:
+        image_size: 32
+        in_channels: 4
+        model_channels: 320
+        out_channels: 4
+        num_res_blocks: 2
+        attention_resolutions: [4, 2, 1]
+        channel_mult: [1, 2, 4, 4]
+        use_checkpoint: True
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        content_dim: 768
+        color_dim: 768
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
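+
+# Note (illustrative): configs like this are typically loaded with OmegaConf and
+# instantiated via ldm.util.instantiate_from_config, e.g.
+#   config = OmegaConf.load('configs/anycontrol.yaml')
+#   model = instantiate_from_config(config.model)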
diff --git a/configs/anycontrol_local.yaml b/configs/anycontrol_local.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8fddd659cc5c4555a9e8f914b661bcfdbec97ed0
--- /dev/null
+++ b/configs/anycontrol_local.yaml
@@ -0,0 +1,150 @@
+model:
+  target: models.anycontrol.AnyControlNet
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+    mode: local
+
+    qformer_config:
+      target: models.q_formers.blip2_qformer.Blip2Qformer
+      model_name: "blip2"
+      model_type: "pretrain"
+      pretrained: "ckpts/blip2_pretrained.pth"
+      params:
+        img_size: 224
+        drop_path_rate: 0
+        use_grad_checkpoint: False
+        vit_precision: "fp16"
+        num_query_token: 256
+        max_txt_len: 32
+        query_token_init_type: "uniform"
+        max_position_embeddings: 512
+        multilevels: [3, 10, 17, 24, 31, 38]
+
+    local_control_config:
+      target: models.local_adapter.LocalAdapter
+      params:
+        in_channels: 4
+        model_channels: 320
+        local_channels: 3
+        inject_channels: [192, 256, 384, 512]
+        inject_layers: [1, 4, 7, 10]
+        query_channels: [768, 768, 768, 768]
+        query_layers: [4, 6, 8, 12]
+        query_scales: [4, 2, 1, 0.5] 
+        num_res_blocks: 2
+        attention_resolutions: [4, 2, 1]
+        channel_mult: [1, 2, 4, 4]
+        use_checkpoint: False
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        legacy: False
+
+    global_control_config:
+      target: models.global_adapter.GlobalAdapter
+      params:
+        cross_attention_dim: 768
+        clip_embeddings_dim: 768
+        context_tokens: 4
+        color_in_dim: 180
+
+    unet_config:
+      target: models.local_adapter.LocalControlUNetModel
+      params:
+        image_size: 32
+        in_channels: 4
+        model_channels: 320
+        out_channels: 4
+        num_res_blocks: 2
+        attention_resolutions: [4, 2, 1]
+        channel_mult: [1, 2, 4, 4]
+        use_checkpoint: False
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+
+
+data:
+  target: src.train.dataset.CustomDataset
+  local_tasks: [canny, hed, depth, seg, openpose]
+  datasets: [multigen, coco, openimages]
+  json_files: 
+    multigen: "./datasets/MultiGen-20M/anycontrol_annotations.jsonl"
+    coco: "./datasets/MSCOCO/anycontrol_annotations.jsonl"
+    openimages: "./datasets/OpenImages/anycontrol_annotations.jsonl" 
+  params:
+    data_root:
+        multigen: ./datasets/MultiGen-20M
+        coco: ./datasets/MSCOCO
+        openimages: ./datasets/OpenImages
+    image_dir:
+        multigen: ./datasets/MultiGen-20M/images 
+        coco: ./datasets/MSCOCO/train2017
+        openimages: ./datasets/OpenImages/train
+    condition_root:
+        multigen: conditions
+        coco: conditions
+        openimages: conditions
+    resolution: 512
+    drop_txt_prob: 0.05
+    drop_all_prob: 0.05
+    keep_all_local_prob: 0.0
+    drop_all_local_prob: 0.0
+    drop_each_cond_prob: 
+      canny: 0.0
+      hed: 0.0
+      depth: 0.0
+      seg: 0.0
+      openpose: 0.0
+
+logger:
+    sample: false
+    N: 4
+    n_row: 4
+    ddim_steps: 50
+    ddim_eta: 0.0
+    plot_denoise_rows: false
+    plot_diffusion_rows: false
+    unconditional_guidance_scale: 7.5
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e9b7c4dfbfe11c0f18d98b750984a11ea05ef8f
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,32 @@
+name: anycontrol 
+channels:
+  - pytorch
+  - nvidia
+dependencies:
+  - python=3.10
+  - pytorch=2.0.1
+  - torchvision=0.15.2
+  - pytorch-cuda=11.8
+  - pip
+  - pip:
+    - numpy==1.23.5
+    - pillow==9.5.0
+    - scipy==1.13.0
+    - scikit-image==0.21.0
+    - scikit-learn==1.3.1
+    - pycocotools==2.0.7
+    - nltk==3.8.1
+    - mmagic
+    - salesforce-lavis
+    - einops==0.4.1
+    - pytorch-lightning==1.9.4
+    - accelerate==0.21.0
+    - diffusers==0.22.3
+    - mmcv==2.0.0
+    - gradio==4.37.1
+    - spacy
+    - opencv-python==4.9.0.80
+    - transformers==4.30.2
+    - basicsr
+    - clip
+    - open_clip_torch
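+
+# Create and activate the environment with:
+#   conda env create -f environment.yaml
+#   conda activate anycontrol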
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d122549995ce2cd64092c81a58419ed4a15a02fd
--- /dev/null
+++ b/ldm/models/autoencoder.py
@@ -0,0 +1,219 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+from ldm.modules.ema import LitEma
+
+
+class AutoencoderKL(pl.LightningModule):
+    def __init__(self,
+                 ddconfig,
+                 lossconfig,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 image_key="image",
+                 colorize_nlabels=None,
+                 monitor=None,
+                 ema_decay=None,
+                 learn_logvar=False
+                 ):
+        super().__init__()
+        self.learn_logvar = learn_logvar
+        self.image_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        assert ddconfig["double_z"]
+        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels)==int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+
+        self.use_ema = ema_decay is not None
+        if self.use_ema:
+            self.ema_decay = ema_decay
+            assert 0. < ema_decay < 1.
+            self.model_ema = LitEma(self, decay=ema_decay)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.parameters())
+            self.model_ema.copy_to(self)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self)
+
+    def encode(self, x):
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, input, sample_posterior=True):
+        posterior = self.encode(input)
+        if sample_posterior:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        return dec, posterior
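+
+    # Illustrative round trip (x is a float NCHW batch scaled to [-1, 1]):
+    #   posterior = ae.encode(x)   # DiagonalGaussianDistribution over latents
+    #   z = posterior.sample()
+    #   x_rec = ae.decode(z)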
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+
+        if optimizer_idx == 0:
+            # train encoder+decoder+logvar
+            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train")
+            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            return aeloss
+
+        if optimizer_idx == 1:
+            # train the discriminator
+            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                                last_layer=self.get_last_layer(), split="train")
+
+            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        log_dict = self._validation_step(batch, batch_idx)
+        with self.ema_scope():
+            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+        return log_dict
+
+    def _validation_step(self, batch, batch_idx, postfix=""):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+                                        last_layer=self.get_last_layer(), split="val"+postfix)
+
+        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+                                            last_layer=self.get_last_layer(), split="val"+postfix)
+
+        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        return self.log_dict
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+            self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+        if self.learn_logvar:
+            print(f"{self.__class__.__name__}: Learning logvar")
+            ae_params_list.append(self.loss.logvar)
+        opt_ae = torch.optim.Adam(ae_params_list,
+                                  lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                    lr=lr, betas=(0.5, 0.9))
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    @torch.no_grad()
+    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        if not only_inputs:
+            xrec, posterior = self(x)
+            if x.shape[1] > 3:
+                # colorize with random projection
+                assert xrec.shape[1] > 3
+                x = self.to_rgb(x)
+                xrec = self.to_rgb(xrec)
+            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+            log["reconstructions"] = xrec
+            if log_ema or self.use_ema:
+                with self.ema_scope():
+                    xrec_ema, posterior_ema = self(x)
+                    if x.shape[1] > 3:
+                        # colorize with random projection
+                        assert xrec_ema.shape[1] > 3
+                        xrec_ema = self.to_rgb(xrec_ema)
+                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+                    log["reconstructions_ema"] = xrec_ema
+        log["inputs"] = x
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+        return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        self.vq_interface = vq_interface
+        super().__init__()
+
+    def encode(self, x, *args, **kwargs):
+        return x
+
+    def decode(self, x, *args, **kwargs):
+        return x
+
+    def quantize(self, x, *args, **kwargs):
+        if self.vq_interface:
+            return x, None, [None, None, None]
+        return x
+
+    def forward(self, x, *args, **kwargs):
+        return x
+
diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..5578dbd361288620a8c91c8f591ffb9f15048e13
--- /dev/null
+++ b/ldm/models/diffusion/ddim.py
@@ -0,0 +1,354 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta,verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               ucg_schedule=None,
+               local_strength=1,
+               global_strength=1,
+               color_strength=1,
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+        samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    ucg_schedule=ucg_schedule,
+                                                    local_strength=local_strength,
+                                                    global_strength=global_strength,
+                                                    color_strength=color_strength,
+                                                    )
+        return samples, intermediates
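+
+    # Illustrative call (shape is the latent (C, H, W); values assumed):
+    #   sampler = DDIMSampler(model)
+    #   samples, _ = sampler.sample(S=50, batch_size=4, shape=(4, 64, 64),
+    #                               conditioning=cond, eta=0.0,
+    #                               unconditional_guidance_scale=7.5,
+    #                               unconditional_conditioning=uncond)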
+
+    @torch.no_grad()
+    def ddim_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                      ucg_schedule=None, local_strength=1, global_strength=1, color_strength=1):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold,
+                                      local_strength=local_strength,
+                                      global_strength=global_strength,
+                                      color_strength=color_strength)
+            img, pred_x0 = outs
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None,local_strength=1,global_strength=1,color_strength=1):
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        if isinstance(c[k][0], torch.Tensor):
+                            c_in[k] = [torch.cat([
+                                unconditional_conditioning[k][i],
+                                c[k][i]]) for i in range(len(c[k]))]
+                        elif isinstance(c[k][0], dict):
+                            c_in[k] = [{key:torch.cat([
+                                unconditional_conditioning[k][i][key],
+                                c[k][i][key]]) for key in c[k][i].keys()} for i in range(len(c[k]))]
+                        else:
+                            c_in[k] = [
+                                unconditional_conditioning[k][i] + c[k][i] for i in range(len(c[k]))]
+                    else:
+                        c_in[k] = torch.cat([
+                                unconditional_conditioning[k],
+                                c[k]])
+            elif isinstance(c, list):
+                c_in = list()
+                assert isinstance(unconditional_conditioning, list)
+                for i in range(len(c)):
+                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
+            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in, 
+                local_strength=local_strength, global_strength=global_strength, color_strength=color_strength).chunk(2)
+            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", 'not implemented'
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc='Encoding Image'):
+            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+            if unconditional_guidance_scale == 1.:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                           torch.cat((unconditional_conditioning, c))), 2)
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = alphas_next[i].sqrt() * (
+                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+            x_next = xt_weighted + weighted_noise_pred
+            if return_intermediates and i % (
+                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback: callback(i)
+
+        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+        if return_intermediates:
+            out.update({'intermediates': intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
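+        # i.e. draw x_t ~ q(x_t | x_0) = N(sqrt(alpha_cumprod_t) * x_0, (1 - alpha_cumprod_t) * I)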
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+    @torch.no_grad()
+    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+               use_original_steps=False, callback=None):
+
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                          unconditional_guidance_scale=unconditional_guidance_scale,
+                                          unconditional_conditioning=unconditional_conditioning)
+            if callback: callback(i)
+        return x_dec
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee17aa0cb7a966ac5bdde16b21b69f29b162b2a0
--- /dev/null
+++ b/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1806 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+try:
+    from pytorch_lightning.utilities.distributed import rank_zero_only
+except:
+    from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from omegaconf import ListConfig
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+                         'crossattn': 'c_crossattn',
+                         'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+    return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+    # classic DDPM with Gaussian diffusion, in image space
+    def __init__(self,
+                 unet_config,
+                 timesteps=1000,
+                 beta_schedule="linear",
+                 loss_type="l2",
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 load_only_unet=False,
+                 monitor="val/loss",
+                 use_ema=True,
+                 first_stage_key="image",
+                 image_size=256,
+                 channels=3,
+                 log_every_t=100,
+                 clip_denoised=True,
+                 linear_start=1e-4,
+                 linear_end=2e-2,
+                 cosine_s=8e-3,
+                 given_betas=None,
+                 original_elbo_weight=0.,
+                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+                 l_simple_weight=1.,
+                 conditioning_key=None,
+                 parameterization="eps",  # all assuming fixed variance schedules
+                 scheduler_config=None,
+                 use_positional_encodings=False,
+                 learn_logvar=False,
+                 logvar_init=0.,
+                 make_it_fit=False,
+                 ucg_training=None,
+                 reset_ema=False,
+                 reset_num_ema_updates=False,
+                 ):
+        super().__init__()
+        assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps", "x0" and "v"'
+        self.parameterization = parameterization
+        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+        self.cond_stage_model = None
+        self.clip_denoised = clip_denoised
+        self.log_every_t = log_every_t
+        self.first_stage_key = first_stage_key
+        self.image_size = image_size  # try conv?
+        self.channels = channels
+        self.use_positional_encodings = use_positional_encodings
+        self.model = DiffusionWrapper(unet_config, conditioning_key)
+        count_params(self.model, verbose=True)
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.model_ema = LitEma(self.model)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        self.use_scheduler = scheduler_config is not None
+        if self.use_scheduler:
+            self.scheduler_config = scheduler_config
+
+        self.v_posterior = v_posterior
+        self.original_elbo_weight = original_elbo_weight
+        self.l_simple_weight = l_simple_weight
+
+        if monitor is not None:
+            self.monitor = monitor
+        self.make_it_fit = make_it_fit
+        if reset_ema: assert exists(ckpt_path)
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+            if reset_ema:
+                assert self.use_ema
+                print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+                self.model_ema = LitEma(self.model)
+        if reset_num_ema_updates:
+            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+            assert self.use_ema
+            self.model_ema.reset_num_updates()
+
+        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+        self.loss_type = loss_type
+
+        self.learn_logvar = learn_logvar
+        logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+        if self.learn_logvar:
+            self.logvar = nn.Parameter(logvar, requires_grad=True)
+        else:
+            self.register_buffer('logvar', logvar)
+
+        self.ucg_training = ucg_training or dict()
+        if self.ucg_training:
+            self.ucg_prng = np.random.RandomState()
+
+    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        if exists(given_betas):
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                       cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+                1. - alphas_cumprod) + self.v_posterior * betas
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
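+        # posterior mean is posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t (Ho et al., 2020)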
+
+        if self.parameterization == "eps":
+            lvlb_weights = self.betas ** 2 / (
+                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+        elif self.parameterization == "x0":
+            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+        elif self.parameterization == "v":
+            lvlb_weights = torch.ones_like(self.betas ** 2 / (
+                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+        else:
+            raise NotImplementedError("mu not supported")
+        lvlb_weights[0] = lvlb_weights[1]
+        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+        assert not torch.isnan(self.lvlb_weights).all()
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.model.parameters())
+            self.model_ema.copy_to(self.model)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.model.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    @torch.no_grad()
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        if self.make_it_fit:
+            n_params = len([name for name, _ in
+                            itertools.chain(self.named_parameters(),
+                                            self.named_buffers())])
+            for name, param in tqdm(
+                    itertools.chain(self.named_parameters(),
+                                    self.named_buffers()),
+                    desc="Fitting old weights to new weights",
+                    total=n_params
+            ):
+                if not name in sd:
+                    continue
+                old_shape = sd[name].shape
+                new_shape = param.shape
+                assert len(old_shape) == len(new_shape)
+                if len(new_shape) > 2:
+                    # we only modify first two axes
+                    assert new_shape[2:] == old_shape[2:]
+                # assumes first axis corresponds to output dim
+                if not new_shape == old_shape:
+                    new_param = param.clone()
+                    old_param = sd[name]
+                    if len(new_shape) == 1:
+                        for i in range(new_param.shape[0]):
+                            new_param[i] = old_param[i % old_shape[0]]
+                    elif len(new_shape) >= 2:
+                        for i in range(new_param.shape[0]):
+                            for j in range(new_param.shape[1]):
+                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+                        n_used_old = torch.ones(old_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_old[j % old_shape[1]] += 1
+                        n_used_new = torch.zeros(new_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_new[j] = n_used_old[j % old_shape[1]]
+
+                        n_used_new = n_used_new[None, :]
+                        while len(n_used_new.shape) < len(new_shape):
+                            n_used_new = n_used_new.unsqueeze(-1)
+                        new_param /= n_used_new
+
+                    sd[name] = new_param
+
+        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+            sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys:\n {missing}")
+        if len(unexpected) > 0:
+            print(f"\nUnexpected Keys:\n {unexpected}")
+
+    def q_mean_variance(self, x_start, t):
+        """
+        Get the distribution q(x_t | x_0).
+        :param x_start: the [N x C x ...] tensor of noiseless inputs.
+        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+        """
+        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
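+        # Invert x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps to obtain the x_0
+        # estimate implied by a predicted noise eps.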
+        return (
+                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def predict_start_from_z_and_v(self, x_t, t, v):
+        # v-parameterization: v = sqrt(a_bar_t) * eps - sqrt(1 - a_bar_t) * x_0,
+        # hence x_0 = sqrt(a_bar_t) * x_t - sqrt(1 - a_bar_t) * v.
+        return (
+                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
+    def predict_eps_from_z_and_v(self, x_t, t, v):
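+        # Recover the noise estimate from a v prediction:
+        # eps = sqrt(a_bar_t) * v + sqrt(1 - a_bar_t) * x_t.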
+        return (
+                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+        )
+
+    def q_posterior(self, x_start, x_t, t):
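+        # Mean and (clipped log-)variance of the Gaussian posterior q(x_{t-1} | x_t, x_0),
+        # using the precomputed posterior_mean_coef1/2 and posterior_variance buffers.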
+        posterior_mean = (
+                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, clip_denoised: bool):
+        model_out = self.model(x, t)
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
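+        # One reverse diffusion step: sample x_{t-1} ~ p(x_{t-1} | x_t), with the noise
+        # term masked out at t == 0.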
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, return_intermediates=False):
+        device = self.betas.device
+        b = shape[0]
+        img = torch.randn(shape, device=device)
+        intermediates = [img]
+        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+                                clip_denoised=self.clip_denoised)
+            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+                intermediates.append(img)
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, batch_size=16, return_intermediates=False):
+        image_size = self.image_size
+        channels = self.channels
+        return self.p_sample_loop((batch_size, channels, image_size, image_size),
+                                  return_intermediates=return_intermediates)
+
+    def q_sample(self, x_start, t, noise=None):
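+        # Forward noising: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps.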
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def get_v(self, x, noise, t):
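+        # Regression target for the "v" parameterization:
+        # v = sqrt(a_bar_t) * eps - sqrt(1 - a_bar_t) * x_0.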
+        return (
+                extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+        )
+
+    def get_loss(self, pred, target, mean=True):
+        if self.loss_type == 'l1':
+            loss = (target - pred).abs()
+            if mean:
+                loss = loss.mean()
+        elif self.loss_type == 'l2':
+            if mean:
+                loss = torch.nn.functional.mse_loss(target, pred)
+            else:
+                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+        else:
+            raise NotImplementedError("unknown loss type '{loss_type}'")
+
+        return loss
+
+    def p_losses(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_out = self.model(x_noisy, t)
+
+        loss_dict = {}
+        if self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
+        else:
+            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+        log_prefix = 'train' if self.training else 'val'
+
+        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+        loss_simple = loss.mean() * self.l_simple_weight
+
+        loss_vlb = (self.lvlb_weights[t] * loss).mean()
+        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+        loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+        loss_dict.update({f'{log_prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    def forward(self, x, *args, **kwargs):
+        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        return self.p_losses(x, t, *args, **kwargs)
+
+    def get_input(self, batch, k):
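+        # Fetch key k from the batch and convert images from (b, h, w, c) to contiguous
+        # (b, c, h, w) float tensors, adding a channel axis for single-channel inputs.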
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = rearrange(x, 'b h w c -> b c h w')
+        x = x.to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def shared_step(self, batch):
+        x = self.get_input(batch, self.first_stage_key)
+        loss, loss_dict = self(x)
+        return loss, loss_dict
+
+    def training_step(self, batch, batch_idx):
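+        # Unconditional-guidance training: with probability p, replace the selected batch
+        # entries by a fixed "unconditional" value (empty string by default).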
+        for k in self.ucg_training:
+            p = self.ucg_training[k]["p"]
+            val = self.ucg_training[k]["val"]
+            if val is None:
+                val = ""
+            for i in range(len(batch[k])):
+                if self.ucg_prng.choice(2, p=[1 - p, p]):
+                    batch[k][i] = val
+
+        loss, loss_dict = self.shared_step(batch)
+
+        self.log_dict(loss_dict, prog_bar=True,
+                      logger=True, on_step=True, on_epoch=True)
+
+        self.log("global_step", self.global_step,
+                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        if self.use_scheduler:
+            lr = self.optimizers().param_groups[0]['lr']
+            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        return loss
+
+    @torch.no_grad()
+    def validation_step(self, batch, batch_idx):
+        _, loss_dict_no_ema = self.shared_step(batch)
+        with self.ema_scope():
+            _, loss_dict_ema = self.shared_step(batch)
+            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self.model)
+
+    def _get_rows_from_list(self, samples):
+        n_imgs_per_row = len(samples)
+        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.first_stage_key)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        x = x.to(self.device)[:N]
+        log["inputs"] = x
+
+        # get diffusion row
+        diffusion_row = list()
+        x_start = x[:n_row]
+
+        for t in range(self.num_timesteps):
+            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                t = t.to(self.device).long()
+                noise = torch.randn_like(x_start)
+                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+                diffusion_row.append(x_noisy)
+
+        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+        if sample:
+            # get denoise row
+            with self.ema_scope("Plotting"):
+                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+            log["samples"] = samples
+            log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.learn_logvar:
+            params = params + [self.logvar]
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+
+class LatentDiffusion(DDPM):
+    """main class"""
+
+    def __init__(self,
+                 first_stage_config,
+                 cond_stage_config,
+                 num_timesteps_cond=None,
+                 cond_stage_key="image",
+                 cond_stage_trainable=False,
+                 concat_mode=True,
+                 cond_stage_forward=None,
+                 conditioning_key=None,
+                 scale_factor=1.0,
+                 scale_by_std=False,
+                 force_null_conditioning=False,
+                 *args, **kwargs):
+        self.force_null_conditioning = force_null_conditioning
+        self.num_timesteps_cond = default(num_timesteps_cond, 1)
+        self.scale_by_std = scale_by_std
+        assert self.num_timesteps_cond <= kwargs['timesteps']
+        # for backwards compatibility after implementation of DiffusionWrapper
+        if conditioning_key is None:
+            conditioning_key = 'concat' if concat_mode else 'crossattn'
+        if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
+            conditioning_key = None
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        reset_ema = kwargs.pop("reset_ema", False)
+        reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
+        ignore_keys = kwargs.pop("ignore_keys", [])
+        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        self.concat_mode = concat_mode
+        self.cond_stage_trainable = cond_stage_trainable
+        self.cond_stage_key = cond_stage_key
+        try:
+            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+            self.num_downs = 0
+        if not scale_by_std:
+            self.scale_factor = scale_factor
+        else:
+            self.register_buffer('scale_factor', torch.tensor(scale_factor))
+        self.instantiate_first_stage(first_stage_config)
+        self.instantiate_cond_stage(cond_stage_config)
+        self.cond_stage_forward = cond_stage_forward
+        self.clip_denoised = False
+        self.bbox_tokenizer = None
+
+        self.restarted_from_ckpt = False
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+            self.restarted_from_ckpt = True
+            if reset_ema:
+                assert self.use_ema
+                print(
+                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+                self.model_ema = LitEma(self.model)
+        if reset_num_ema_updates:
+            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+            assert self.use_ema
+            self.model_ema.reset_num_updates()
+
+    def make_cond_schedule(self):
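+        # Build cond_ids: the first num_timesteps_cond diffusion steps get evenly spaced
+        # conditioning timesteps, all later steps fall back to the final timestep.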
+        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+        self.cond_ids[:self.num_timesteps_cond] = ids
+
+    @rank_zero_only
+    @torch.no_grad()
+    def on_train_batch_start(self, batch, batch_idx):
+        # only for very first batch
+        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+            # set rescale weight to 1./std of encodings
+            print("### USING STD-RESCALING ###")
+            x = super().get_input(batch, self.first_stage_key)
+            x = x.to(self.device)
+            encoder_posterior = self.encode_first_stage(x)
+            z = self.get_first_stage_encoding(encoder_posterior).detach()
+            del self.scale_factor
+            self.register_buffer('scale_factor', 1. / z.flatten().std())
+            print(f"setting self.scale_factor to {self.scale_factor}")
+            print("### USING STD-RESCALING ###")
+
+    def register_schedule(self,
+                          given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+        self.shorten_cond_schedule = self.num_timesteps_cond > 1
+        if self.shorten_cond_schedule:
+            self.make_cond_schedule()
+
+    def instantiate_first_stage(self, config):
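+        # The first-stage autoencoder is frozen: eval mode, train() disabled, and all
+        # parameters excluded from gradient updates.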
+        model = instantiate_from_config(config)
+        self.first_stage_model = model.eval()
+        self.first_stage_model.train = disabled_train
+        for param in self.first_stage_model.parameters():
+            param.requires_grad = False
+
+    def instantiate_cond_stage(self, config):
+        if not self.cond_stage_trainable:
+            if config == "__is_first_stage__":
+                print("Using first stage also as cond stage.")
+                self.cond_stage_model = self.first_stage_model
+            elif config == "__is_unconditional__":
+                print(f"Training {self.__class__.__name__} as an unconditional model.")
+                self.cond_stage_model = None
+                # self.be_unconditional = True
+            else:
+                model = instantiate_from_config(config)
+                self.cond_stage_model = model.eval()
+                self.cond_stage_model.train = disabled_train
+                for param in self.cond_stage_model.parameters():
+                    param.requires_grad = False
+        else:
+            assert config != '__is_first_stage__'
+            assert config != '__is_unconditional__'
+            model = instantiate_from_config(config)
+            self.cond_stage_model = model
+
+    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+        denoise_row = []
+        for zd in tqdm(samples, desc=desc):
+            denoise_row.append(self.decode_first_stage(zd.to(self.device),
+                                                       force_not_quantize=force_no_decoder_quantization))
+        n_imgs_per_row = len(denoise_row)
+        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
+        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    def get_first_stage_encoding(self, encoder_posterior):
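+        # Sample a latent from the encoder posterior (or use it directly if it is already
+        # a tensor) and rescale it by scale_factor.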
+        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+            z = encoder_posterior.sample()
+        elif isinstance(encoder_posterior, torch.Tensor):
+            z = encoder_posterior
+        else:
+            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+        return self.scale_factor * z
+
+    def get_learned_conditioning(self, c):
+        if self.cond_stage_forward is None:
+            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+                c = self.cond_stage_model.encode(c)
+                if isinstance(c, DiagonalGaussianDistribution):
+                    c = c.mode()
+            else:
+                c = self.cond_stage_model(c)
+        else:
+            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+        return c
+
+    def meshgrid(self, h, w):
+        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+        arr = torch.cat([y, x], dim=-1)
+        return arr
+
+    def delta_border(self, h, w):
+        """
+        :param h: height
+        :param w: width
+        :return: normalized distance to image border,
+         with min distance = 0 at border and max dist = 0.5 at image center
+        """
+        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+        arr = self.meshgrid(h, w) / lower_right_corner
+        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+        return edge_dist
+
+    def get_weighting(self, h, w, Ly, Lx, device):
+        weighting = self.delta_border(h, w)
+        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+                               self.split_input_params["clip_max_weight"], )
+        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+        if self.split_input_params["tie_braker"]:
+            L_weighting = self.delta_border(Ly, Lx)
+            L_weighting = torch.clip(L_weighting,
+                                     self.split_input_params["clip_min_tie_weight"],
+                                     self.split_input_params["clip_max_tie_weight"])
+
+            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+            weighting = weighting * L_weighting
+        return weighting
+
+    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
+        """
+        :param x: img of size (bs, c, h, w)
+        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+        """
+        bs, nc, h, w = x.shape
+
+        # number of crops in image
+        Ly = (h - kernel_size[0]) // stride[0] + 1
+        Lx = (w - kernel_size[1]) // stride[1] + 1
+
+        if uf == 1 and df == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+        elif uf > 1 and df == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[1] * uf),
+                                dilation=1, padding=0,
+                                stride=(stride[0] * uf, stride[1] * uf))
+            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+        elif df > 1 and uf == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[1] // df),
+                                dilation=1, padding=0,
+                                stride=(stride[0] // df, stride[1] // df))
+            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+        else:
+            raise NotImplementedError
+
+        return fold, unfold, normalization, weighting
+
+    @torch.no_grad()
+    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+                  cond_key=None, return_original_cond=False, bs=None, return_x=False):
+        # get image from batch
+        x = super().get_input(batch, k)
+        if bs is not None:
+            x = x[:bs]
+        x = x.to(self.device)
+        # encode image to latent
+        encoder_posterior = self.encode_first_stage(x)
+        # sample (from vae)
+        z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+        # encode condition caption or class labels
+        if self.model.conditioning_key is not None and not self.force_null_conditioning:
+            if cond_key is None:
+                cond_key = self.cond_stage_key
+            if cond_key != self.first_stage_key:
+                if cond_key in ['caption', 'coordinates_bbox', "txt"]:
+                    xc = batch[cond_key]
+                elif cond_key in ['class_label', 'cls']:
+                    xc = batch
+                else:
+                    xc = super().get_input(batch, cond_key).to(self.device)
+            else:
+                xc = x
+            if not self.cond_stage_trainable or force_c_encode:
+                if isinstance(xc, dict) or isinstance(xc, list):
+                    c = self.get_learned_conditioning(xc)
+                else:
+                    c = self.get_learned_conditioning(xc.to(self.device))
+            else:
+                c = xc
+            if bs is not None:
+                c = c[:bs]
+
+            if self.use_positional_encodings:
+                pos_x, pos_y = self.compute_latent_shifts(batch)
+                ckey = __conditioning_keys__[self.model.conditioning_key]
+                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+        else:
+            c = None
+            xc = None
+            if self.use_positional_encodings:
+                pos_x, pos_y = self.compute_latent_shifts(batch)
+                c = {'pos_x': pos_x, 'pos_y': pos_y}
+
+        # return
+        out = [z, c]
+        if return_first_stage_outputs:
+            xrec = self.decode_first_stage(z)
+            out.extend([x, xrec])
+        if return_x:
+            out.extend([x])
+        if return_original_cond:
+            out.append(xc)
+        return out
+
+    @torch.no_grad()
+    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+        if predict_cids:
+            if z.dim() == 4:
+                z = torch.argmax(z.exp(), dim=1).long()
+            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+            z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+        z = 1. / self.scale_factor * z
+        return self.first_stage_model.decode(z)
+
+    @torch.no_grad()
+    def encode_first_stage(self, x):
+        return self.first_stage_model.encode(x)
+
+    def shared_step(self, batch, **kwargs):
+        x, c = self.get_input(batch, self.first_stage_key)
+        loss = self(x, c)
+        return loss
+
+    def forward(self, x, c, *args, **kwargs):
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        if self.model.conditioning_key is not None:
+            assert c is not None
+            if self.cond_stage_trainable:
+                c = self.get_learned_conditioning(c)
+            if self.shorten_cond_schedule:  # TODO: drop this option
+                tc = self.cond_ids[t].to(self.device)
+                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+        return self.p_losses(x, c, t, *args, **kwargs)
+
+    def apply_model(self, x_noisy, t, cond, return_ids=False):
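+        # Normalize the conditioning into the dict format expected by DiffusionWrapper
+        # ({'c_concat': [...]} or {'c_crossattn': [...]}) before running the diffusion model.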
+        if isinstance(cond, dict):
+            # hybrid case, cond is expected to be a dict
+            pass
+        else:
+            if not isinstance(cond, list):
+                cond = [cond]
+            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+            cond = {key: cond}
+
+        x_recon = self.model(x_noisy, t, **cond)
+
+        if isinstance(x_recon, tuple) and not return_ids:
+            return x_recon[0]
+        else:
+            return x_recon
+
+    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+    def _prior_bpd(self, x_start):
+        """
+        Get the prior KL term for the variational lower-bound, measured in
+        bits-per-dim.
+        This term can't be optimized, as it only depends on the encoder.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :return: a batch of [N] KL values (in bits), one per batch element.
+        """
+        batch_size = x_start.shape[0]
+        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+        return mean_flat(kl_prior) / np.log(2.0)
+
+    def p_losses(self, x_start, cond, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_output = self.apply_model(x_noisy, t, cond)
+
+        loss_dict = {}
+        prefix = 'train' if self.training else 'val'
+
+        if self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
+        else:
+            raise NotImplementedError()
+
+        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+        logvar_t = self.logvar[t].to(self.device)
+        loss = loss_simple / torch.exp(logvar_t) + logvar_t
+        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+        if self.learn_logvar:
+            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+            loss_dict.update({'logvar': self.logvar.data.mean()})
+
+        loss = self.l_simple_weight * loss.mean()
+
+        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+        loss += (self.original_elbo_weight * loss_vlb)
+        loss_dict.update({f'{prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+                        return_x0=False, score_corrector=None, corrector_kwargs=None):
+        t_in = t
+        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+        if score_corrector is not None:
+            assert self.parameterization == "eps"
+            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+        if return_codebook_ids:
+            model_out, logits = model_out
+
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        else:
+            raise NotImplementedError()
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+        if quantize_denoised:
+            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        if return_codebook_ids:
+            return model_mean, posterior_variance, posterior_log_variance, logits
+        elif return_x0:
+            return model_mean, posterior_variance, posterior_log_variance, x_recon
+        else:
+            return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+        b, *_, device = *x.shape, x.device
+        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+                                       return_codebook_ids=return_codebook_ids,
+                                       quantize_denoised=quantize_denoised,
+                                       return_x0=return_x0,
+                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+        if return_codebook_ids:
+            raise DeprecationWarning("Support dropped.")
+            model_mean, _, model_log_variance, logits = outputs
+        elif return_x0:
+            model_mean, _, model_log_variance, x0 = outputs
+        else:
+            model_mean, _, model_log_variance = outputs
+
+        noise = noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+        if return_codebook_ids:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+        if return_x0:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+        else:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+                              log_every_t=None):
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        timesteps = self.num_timesteps
+        if batch_size is not None:
+            b = batch_size
+            shape = [batch_size] + list(shape)
+        else:
+            b = batch_size = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=self.device)
+        else:
+            img = x_T
+        intermediates = []
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+            else:
+                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+                        total=timesteps) if verbose else reversed(
+            range(0, timesteps))
+        if isinstance(temperature, float):
+            temperature = [temperature] * timesteps
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != 'hybrid'
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img, x0_partial = self.p_sample(img, cond, ts,
+                                            clip_denoised=self.clip_denoised,
+                                            quantize_denoised=quantize_denoised, return_x0=True,
+                                            temperature=temperature[i], noise_dropout=noise_dropout,
+                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1. - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(x0_partial)
+            if callback: callback(i)
+            if img_callback: img_callback(img, i)
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_loop(self, cond, shape, return_intermediates=False,
+                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, start_T=None,
+                      log_every_t=None):
+
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        device = self.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        intermediates = [img]
+        if timesteps is None:
+            timesteps = self.num_timesteps
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+            range(0, timesteps))
+
+        if mask is not None:
+            assert x0 is not None
+            assert x0.shape[2:] == mask.shape[2:]  # spatial size has to match
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != 'hybrid'
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img = self.p_sample(img, cond, ts,
+                                clip_denoised=self.clip_denoised,
+                                quantize_denoised=quantize_denoised)
+            if mask is not None:
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1. - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(img)
+            if callback: callback(i)
+            if img_callback: img_callback(img, i)
+
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+               verbose=True, timesteps=None, quantize_denoised=False,
+               mask=None, x0=None, shape=None, **kwargs):
+        if shape is None:
+            shape = (batch_size, self.channels, self.image_size, self.image_size)
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+            else:
+                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+        return self.p_sample_loop(cond,
+                                  shape,
+                                  return_intermediates=return_intermediates, x_T=x_T,
+                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+                                  mask=mask, x0=x0)
+
+    @torch.no_grad()
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+        if ddim:
+            ddim_sampler = DDIMSampler(self)
+            shape = (self.channels, self.image_size, self.image_size)
+            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+                                                         shape, cond, verbose=False, **kwargs)
+
+        else:
+            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+                                                 return_intermediates=True, **kwargs)
+
+        return samples, intermediates
+
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                if hasattr(xc, "to"):
+                    xc = xc.to(self.device)
+                c = self.get_learned_conditioning(xc)
+        else:
+            if self.cond_stage_key in ["class_label", "cls"]:
+                xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
+                return self.get_learned_conditioning(xc)
+            else:
+                raise NotImplementedError("todo")
+        if isinstance(c, list):  # in case the encoder gives us a list
+            for i in range(len(c)):
+                c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
+        else:
+            c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+        return c
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
+                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                   use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+                                           return_first_stage_outputs=True,
+                                           force_c_encode=True,
+                                           return_original_cond=True,
+                                           bs=N)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ['class_label', "cls"]:
+                try:
+                    xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                    log['conditioning'] = xc
+                except KeyError:
+                    # probably no "human_label" in batch
+                    pass
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+                    self.first_stage_model, IdentityFirstStage):
+                # also display when quantizing x0 while sampling
+                with ema_scope("Plotting Quantized Denoised"):
+                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                             ddim_steps=ddim_steps, eta=ddim_eta,
+                                                             quantize_denoised=True)
+                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+                    #                                      quantize_denoised=True)
+                x_samples = self.decode_first_stage(samples.to(self.device))
+                log["samples_x0_quantized"] = x_samples
+
+        if unconditional_guidance_scale > 1.0:
+            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            if self.model.conditioning_key == "crossattn-adm":
+                uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        if inpaint:
+            # make a simple center square
+            b, h, w = z.shape[0], z.shape[2], z.shape[3]
+            mask = torch.ones(N, h, w).to(self.device)
+            # zeros will be filled in
+            mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+            mask = mask[:, None, ...]
+            with ema_scope("Plotting Inpaint"):
+                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+            x_samples = self.decode_first_stage(samples.to(self.device))
+            log["samples_inpainting"] = x_samples
+            log["mask"] = mask
+
+            # outpaint
+            mask = 1. - mask
+            with ema_scope("Plotting Outpaint"):
+                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+            x_samples = self.decode_first_stage(samples.to(self.device))
+            log["samples_outpainting"] = x_samples
+
+        if plot_progressive_rows:
+            with ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(c,
+                                                               shape=(self.channels, self.image_size, self.image_size),
+                                                               batch_size=N)
+            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+            log["progressive_row"] = prog_row
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.cond_stage_trainable:
+            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+            params = params + list(self.cond_stage_model.parameters())
+        if self.learn_logvar:
+            print('Diffusion model optimizing logvar')
+            params.append(self.logvar)
+        opt = torch.optim.AdamW(params, lr=lr)
+        if self.use_scheduler:
+            assert 'target' in self.scheduler_config
+            scheduler = instantiate_from_config(self.scheduler_config)
+
+            print("Setting up LambdaLR scheduler...")
+            scheduler = [
+                {
+                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+                    'interval': 'step',
+                    'frequency': 1
+                }]
+            return [opt], scheduler
+        return opt
+
+    @torch.no_grad()
+    def to_rgb(self, x):
+        x = x.float()
+        if not hasattr(self, "colorize"):
+            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+        x = nn.functional.conv2d(x, weight=self.colorize)
+        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+        return x
+
+
+class DiffusionWrapper(pl.LightningModule):
+    def __init__(self, diff_model_config, conditioning_key):
+        super().__init__()
+        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
+        self.diffusion_model = instantiate_from_config(diff_model_config)
+        self.conditioning_key = conditioning_key
+        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
+
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
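+        # Route the conditioning into the diffusion model according to conditioning_key:
+        # channel concatenation (concat), cross-attention context (crossattn), class/adm
+        # embedding (adm), or combinations thereof.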
+        if self.conditioning_key is None:
+            out = self.diffusion_model(x, t)
+        elif self.conditioning_key == 'concat':
+            xc = torch.cat([x] + c_concat, dim=1)
+            out = self.diffusion_model(xc, t)
+        elif self.conditioning_key == 'crossattn':
+            if not self.sequential_cross_attn:
+                cc = torch.cat(c_crossattn, 1)
+            else:
+                cc = c_crossattn
+            out = self.diffusion_model(x, t, context=cc)
+        elif self.conditioning_key == 'hybrid':
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc)
+        elif self.conditioning_key == 'hybrid-adm':
+            assert c_adm is not None
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+        elif self.conditioning_key == 'crossattn-adm':
+            assert c_adm is not None
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(x, t, context=cc, y=c_adm)
+        elif self.conditioning_key == 'adm':
+            cc = c_crossattn[0]
+            out = self.diffusion_model(x, t, y=cc)
+        else:
+            raise NotImplementedError()
+
+        return out
+
+
+class LatentUpscaleDiffusion(LatentDiffusion):
+    def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+        assert not self.cond_stage_trainable
+        self.instantiate_low_stage(low_scale_config)
+        self.low_scale_key = low_scale_key
+        self.noise_level_key = noise_level_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+        if not log_mode:
+            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+        else:
+            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                                  force_c_encode=True, return_original_cond=True, bs=bs)
+        x_low = batch[self.low_scale_key][:bs]
+        x_low = rearrange(x_low, 'b h w c -> b c h w')
+        x_low = x_low.to(memory_format=torch.contiguous_format).float()
+        zx, noise_level = self.low_scale_model(x_low)
+        if self.noise_level_key is not None:
+            # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
+            raise NotImplementedError('TODO')
+
+        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        if log_mode:
+            # TODO: maybe disable if too expensive
+            x_low_rec = self.low_scale_model.decode(zx)
+            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+                                                                          log_mode=True)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        log["x_lr"] = x_low
+        log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ['class_label', 'cls']:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                log['conditioning'] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            # TODO explore better "unconditional" choices for the other keys
+            # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+            uc = dict()
+            for k in c:
+                if k == "c_crossattn":
+                    assert isinstance(c[k], list) and len(c[k]) == 1
+                    uc[k] = [uc_tmp]
+                elif k == "c_adm":  # todo: only run with text-based guidance?
+                    assert isinstance(c[k], torch.Tensor)
+                    #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+                    uc[k] = c[k]
+                elif isinstance(c[k], list):
+                    uc[k] = [c[k][i] for i in range(len(c[k]))]
+                else:
+                    uc[k] = c[k]
+
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        if plot_progressive_rows:
+            with ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(c,
+                                                               shape=(self.channels, self.image_size, self.image_size),
+                                                               batch_size=N)
+            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+            log["progressive_row"] = prog_row
+
+        return log
+
+
+class LatentFinetuneDiffusion(LatentDiffusion):
+    """
+         Basis for different finetunas, such as inpainting or depth2image
+         To disable finetuning mode, set finetune_keys to None
+    """
+
+    def __init__(self,
+                 concat_keys: tuple,
+                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+                                "model_ema.diffusion_modelinput_blocks00weight"
+                                ),
+                 keep_finetune_dims=4,
+                 # if model was trained without concat mode before and we would like to keep these channels
+                 c_concat_log_start=None,  # to log reconstruction of c_concat codes
+                 c_concat_log_end=None,
+                 *args, **kwargs
+                 ):
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        ignore_keys = kwargs.pop("ignore_keys", list())
+        super().__init__(*args, **kwargs)
+        self.finetune_keys = finetune_keys
+        self.concat_keys = concat_keys
+        self.keep_dims = keep_finetune_dims
+        self.c_concat_log_start = c_concat_log_start
+        self.c_concat_log_end = c_concat_log_end
+        if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
+        if exists(ckpt_path):
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+
+            # make it explicit, finetune by including extra input channels
+            if exists(self.finetune_keys) and k in self.finetune_keys:
+                new_entry = None
+                for name, param in self.named_parameters():
+                    if name in self.finetune_keys:
+                        print(
+                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
+                        new_entry = torch.zeros_like(param)  # zero init
+                assert exists(new_entry), 'did not find matching parameter to modify'
+                new_entry[:, :self.keep_dims, ...] = sd[k]
+                sd[k] = new_entry
+
+        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+            sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+        if len(unexpected) > 0:
+            print(f"Unexpected Keys: {unexpected}")
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                   use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ['class_label', 'cls']:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                log['conditioning'] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+            log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                         batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            uc_cat = c_cat
+            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                 batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc_full,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        return log
+
+
+class LatentInpaintDiffusion(LatentFinetuneDiffusion):
+    """
+    can either run as pure inpainting model (only concat mode) or with mixed conditionings,
+    e.g. mask as concat and text via cross-attn.
+    To disable finetuning mode, set finetune_keys to None
+     """
+
+    def __init__(self,
+                 concat_keys=("mask", "masked_image"),
+                 masked_image_key="masked_image",
+                 *args, **kwargs
+                 ):
+        super().__init__(concat_keys, *args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+
+        assert exists(self.concat_keys)
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
+        log["masked_image"] = rearrange(args[0]["masked_image"],
+                                        'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+        return log
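+    # Shape sketch for the concat conditioning assembled in get_input above (sizes are assumptions
+    # for a 512x512 RGB input with a 4-channel latent space downsampled by a factor of 8):
+    #
+    #   mask:         (B, 1, 512, 512) --interpolate--> (B, 1, 64, 64)
+    #   masked_image: (B, 3, 512, 512) --encode-------> (B, 4, 64, 64)
+    #   c_cat = torch.cat([mask, masked_image_latent], dim=1)  ->  (B, 5, 64, 64)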
+
+
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+    """
+    condition on monocular depth estimation
+    """
+
+    def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
+        super().__init__(concat_keys=concat_keys, *args, **kwargs)
+        self.depth_model = instantiate_from_config(depth_stage_config)
+        self.depth_stage_key = concat_keys[0]
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+
+        assert exists(self.concat_keys)
+        assert len(self.concat_keys) == 1
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            cc = self.depth_model(cc)
+            cc = torch.nn.functional.interpolate(
+                cc,
+                size=z.shape[2:],
+                mode="bicubic",
+                align_corners=False,
+            )
+
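+            # Normalize the predicted depth per sample to [-1, 1] so the concat conditioning
+            # matches the value range expected by the diffusion model.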
+            depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+                                                                                           keepdim=True)
+            cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        depth = self.depth_model(args[0][self.depth_stage_key])
+        depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
+                               torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+        log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
+        return log
+
+
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+    """
+        condition on low-res image (and optionally on some spatial noise augmentation)
+    """
+    def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
+                 low_scale_config=None, low_scale_key=None, *args, **kwargs):
+        super().__init__(concat_keys=concat_keys, *args, **kwargs)
+        self.reshuffle_patch_size = reshuffle_patch_size
+        self.low_scale_model = None
+        if low_scale_config is not None:
+            print("Initializing a low-scale model")
+            assert exists(low_scale_key)
+            self.instantiate_low_stage(low_scale_config)
+            self.low_scale_key = low_scale_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+
+        assert exists(self.concat_keys)
+        assert len(self.concat_keys) == 1
+        # optionally make spatial noise_level here
+        c_cat = list()
+        noise_level = None
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            cc = rearrange(cc, 'b h w c -> b c h w')
+            if exists(self.reshuffle_patch_size):
+                assert isinstance(self.reshuffle_patch_size, int)
+                cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+                               p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            if exists(self.low_scale_model) and ck == self.low_scale_key:
+                cc, noise_level = self.low_scale_model(cc)
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        if exists(noise_level):
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+        else:
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
+        return log
diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1154 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+    def __init__(
+            self,
+            schedule='discrete',
+            betas=None,
+            alphas_cumprod=None,
+            continuous_beta_0=0.1,
+            continuous_beta_1=20.,
+    ):
+        """Create a wrapper class for the forward SDE (VP type).
+        ***
+        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
+                We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
+        ***
+        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+            log_alpha_t = self.marginal_log_mean_coeff(t)
+            sigma_t = self.marginal_std(t)
+            lambda_t = self.marginal_lambda(t)
+        Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+            t = self.inverse_lambda(lambda_t)
+        ===============================================================
+        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+        1. For discrete-time DPMs:
+            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+                t_i = (i + 1) / N
+            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+            Args:
+                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+            **Important**: Please pay special attention to the argument `alphas_cumprod`:
+                The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
+                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
+                and
+                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+        2. For continuous-time DPMs:
+            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+            schedule are the default settings in DDPM and improved-DDPM:
+            Args:
+                beta_min: A `float` number. The smallest beta for the linear schedule.
+                beta_max: A `float` number. The largest beta for the linear schedule.
+                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+                T: A `float` number. The ending time of the forward process.
+        ===============================================================
+        Args:
+            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+                    'linear' or 'cosine' for continuous-time DPMs.
+        Returns:
+            A wrapper object of the forward SDE (VP type).
+
+        ===============================================================
+        Example:
+        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', betas=betas)
+        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+        # For continuous-time DPMs (VPSDE), linear schedule:
+        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+        """
+
+        if schedule not in ['discrete', 'linear', 'cosine']:
+            raise ValueError(
+                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+                    schedule))
+
+        self.schedule = schedule
+        if schedule == 'discrete':
+            if betas is not None:
+                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+            else:
+                assert alphas_cumprod is not None
+                log_alphas = 0.5 * torch.log(alphas_cumprod)
+            self.total_N = len(log_alphas)
+            self.T = 1.
+            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+            self.log_alpha_array = log_alphas.reshape((1, -1,))
+        else:
+            self.total_N = 1000
+            self.beta_0 = continuous_beta_0
+            self.beta_1 = continuous_beta_1
+            self.cosine_s = 0.008
+            self.cosine_beta_max = 999.
+            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+                        1. + self.cosine_s) / math.pi - self.cosine_s
+            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+            self.schedule = schedule
+            if schedule == 'cosine':
+                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+                self.T = 0.9946
+            else:
+                self.T = 1.
+
+    def marginal_log_mean_coeff(self, t):
+        """
+        Compute log(alpha_t) of a given continuous-time label t in [0, T].
+        """
+        if self.schedule == 'discrete':
+            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+                                  self.log_alpha_array.to(t.device)).reshape((-1))
+        elif self.schedule == 'linear':
+            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+        elif self.schedule == 'cosine':
+            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+            return log_alpha_t
+
+    def marginal_alpha(self, t):
+        """
+        Compute alpha_t of a given continuous-time label t in [0, T].
+        """
+        return torch.exp(self.marginal_log_mean_coeff(t))
+
+    def marginal_std(self, t):
+        """
+        Compute sigma_t of a given continuous-time label t in [0, T].
+        """
+        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+    def marginal_lambda(self, t):
+        """
+        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+        """
+        log_mean_coeff = self.marginal_log_mean_coeff(t)
+        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+        return log_mean_coeff - log_std
+
+    def inverse_lambda(self, lamb):
+        """
+        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+        """
+        if self.schedule == 'linear':
+            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+            Delta = self.beta_0 ** 2 + tmp
+            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+        elif self.schedule == 'discrete':
+            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+                               torch.flip(self.t_array.to(lamb.device), [1]))
+            return t.reshape((-1,))
+        else:
+            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+                        1. + self.cosine_s) / math.pi - self.cosine_s
+            t = t_fn(log_alpha)
+            return t
+
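+# Minimal sanity-check sketch for the schedule above (not part of the module; `betas` here is a
+# hypothetical DDPM-style linear beta schedule):
+#
+#   betas = torch.linspace(1e-4, 2e-2, 1000)
+#   ns = NoiseScheduleVP('discrete', betas=betas)
+#   t = torch.tensor([0.5])
+#   alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
+#   assert torch.allclose(ns.marginal_lambda(t), torch.log(alpha_t) - torch.log(sigma_t))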
+
+def model_wrapper(
+        model,
+        noise_schedule,
+        model_type="noise",
+        model_kwargs={},
+        guidance_type="uncond",
+        condition=None,
+        unconditional_condition=None,
+        guidance_scale=1.,
+        classifier_fn=None,
+        classifier_kwargs={},
+):
+    """Create a wrapper function for the noise prediction model.
+    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+    first wrap the model function into a noise prediction model that accepts the continuous time as input.
+    We support four types of the diffusion model by setting `model_type`:
+        1. "noise": noise prediction model. (Trained by predicting noise).
+        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+        3. "v": velocity prediction model. (Trained by predicting the velocity).
+            The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+                arXiv preprint arXiv:2202.00512 (2022).
+            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+                arXiv preprint arXiv:2210.02303 (2022).
+
+        4. "score": marginal score function. (Trained by denoising score matching).
+            Note that the score function and the noise prediction model follow a simple relationship:
+            ```
+                noise(x_t, t) = -sigma_t * score(x_t, t)
+            ```
+    We support three types of guided sampling by DPMs by setting `guidance_type`:
+        1. "uncond": unconditional sampling by DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+            The input `classifier_fn` has the following format:
+            ``
+                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+            ``
+            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+            ``
+            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+                arXiv preprint arXiv:2207.12598 (2022).
+
+    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+    or continuous-time labels (i.e. epsilon to T).
+    We wrap the model function so that it accepts only `x` and `t_continuous` as inputs and outputs the predicted noise:
+    ``
+        def model_fn(x, t_continuous) -> noise:
+            t_input = get_model_input_time(t_continuous)
+            return noise_pred(model, x, t_input, **model_kwargs)
+    ``
+    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+    ===============================================================
+    Args:
+        model: A diffusion model with the corresponding format described above.
+        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+        model_type: A `str`. The parameterization type of the diffusion model.
+                    "noise" or "x_start" or "v" or "score".
+        model_kwargs: A `dict`. A dict for the other inputs of the model function.
+        guidance_type: A `str`. The type of the guidance for sampling.
+                    "uncond" or "classifier" or "classifier-free".
+        condition: A pytorch tensor. The condition for the guided sampling.
+                    Only used for "classifier" or "classifier-free" guidance type.
+        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+                    Only used for "classifier-free" guidance type.
+        guidance_scale: A `float`. The scale for the guided sampling.
+        classifier_fn: A classifier function. Only used for the classifier guidance.
+        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+    Returns:
+        A noise prediction model that accepts the noised data and the continuous time as the inputs.
+    """
+
+    def get_model_input_time(t_continuous):
+        """
+        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+        For continuous-time DPMs, we just use `t_continuous`.
+        """
+        if noise_schedule.schedule == 'discrete':
+            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+        else:
+            return t_continuous
+
+    def noise_pred_fn(x, t_continuous, cond=None):
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        t_input = get_model_input_time(t_continuous)
+        if cond is None:
+            output = model(x, t_input, **model_kwargs)
+        else:
+            output = model(x, t_input, cond, **model_kwargs)
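+        # Convert whatever the network predicts into an epsilon (noise) prediction, using
+        # x_t = alpha_t * x_0 + sigma_t * eps and (for "v") eps = alpha_t * v + sigma_t * x_t.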
+        if model_type == "noise":
+            return output
+        elif model_type == "x_start":
+            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+        elif model_type == "v":
+            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+        elif model_type == "score":
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return -expand_dims(sigma_t, dims) * output
+
+    def cond_grad_fn(x, t_input):
+        """
+        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+        """
+        with torch.enable_grad():
+            x_in = x.detach().requires_grad_(True)
+            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+            return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+    def model_fn(x, t_continuous):
+        """
+        The noise prediction model function that is used for DPM-Solver.
+        """
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        if guidance_type == "uncond":
+            return noise_pred_fn(x, t_continuous)
+        elif guidance_type == "classifier":
+            assert classifier_fn is not None
+            t_input = get_model_input_time(t_continuous)
+            cond_grad = cond_grad_fn(x, t_input)
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            noise = noise_pred_fn(x, t_continuous)
+            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+        elif guidance_type == "classifier-free":
+            if guidance_scale == 1. or unconditional_condition is None:
+                return noise_pred_fn(x, t_continuous, cond=condition)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t_continuous] * 2)
+                c_in = torch.cat([unconditional_condition, condition])
+                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+                return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+    assert model_type in ["noise", "x_start", "v", "score"]
+    assert guidance_type in ["uncond", "classifier", "classifier-free"]
+    return model_fn
+
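+# Illustrative wiring (names such as `unet`, `alphas_cumprod`, `cond` and `uc` are assumptions,
+# not defined in this file): wrap a discrete-time epsilon-prediction model for classifier-free
+# guidance and hand the result to DPM_Solver below.
+#
+#   ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+#   model_fn = model_wrapper(unet, ns, model_type="noise",
+#                            guidance_type="classifier-free",
+#                            condition=cond, unconditional_condition=uc,
+#                            guidance_scale=7.5)
+#   dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True)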
+
+class DPM_Solver:
+    def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+        """Construct a DPM-Solver.
+        We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+        If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+        If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+            In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+            The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+        Args:
+            model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+                ``
+                def model_fn(x, t_continuous):
+                    return noise
+                ``
+            noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+            predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+            thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+            max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+        [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+        """
+        self.model = model_fn
+        self.noise_schedule = noise_schedule
+        self.predict_x0 = predict_x0
+        self.thresholding = thresholding
+        self.max_val = max_val
+
+    def noise_prediction_fn(self, x, t):
+        """
+        Return the predicted noise of the model at (x, t).
+        """
+        return self.model(x, t)
+
+    def data_prediction_fn(self, x, t):
+        """
+        Return the predicted data (x0), with optional dynamic thresholding.
+        """
+        noise = self.noise_prediction_fn(x, t)
+        dims = x.dim()
+        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+        if self.thresholding:
+            p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
+            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+            x0 = torch.clamp(x0, -s, s) / s
+        return x0
+
+    def model_fn(self, x, t):
+        """
+        Convert the model to the noise prediction model or the data prediction model.
+        """
+        if self.predict_x0:
+            return self.data_prediction_fn(x, t)
+        else:
+            return self.noise_prediction_fn(x, t)
+
+    def get_time_steps(self, skip_type, t_T, t_0, N, device):
+        """Compute the intermediate time steps for sampling.
+        Args:
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            N: An `int`. The total number of time-step intervals.
+            device: A torch device.
+        Returns:
+            A pytorch tensor of the time steps, with the shape (N + 1,).
+        """
+        if skip_type == 'logSNR':
+            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+            return self.noise_schedule.inverse_lambda(logSNR_steps)
+        elif skip_type == 'time_uniform':
+            return torch.linspace(t_T, t_0, N + 1).to(device)
+        elif skip_type == 'time_quadratic':
+            t_order = 2
+            t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
+            return t
+        else:
+            raise ValueError(
+                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+        """
+        Get the order of each step for sampling by the singlestep DPM-Solver.
+        We combine DPM-Solver-1, 2 and 3 to use all the function evaluations, which is named "DPM-Solver-fast".
+        Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+            - If order == 1:
+                We take `steps` of DPM-Solver-1 (i.e. DDIM).
+            - If order == 2:
+                - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+                - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+                - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+            - If order == 3:
+                - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+                - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+                - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+                - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+        ============================================
+        Args:
+            order: An `int`. The max order for the solver (1, 2 or 3).
+            steps: An `int`. The total number of function evaluations (NFE).
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            device: A torch device.
+        Returns:
+            orders: A list of the solver order of each step.
+        """
+        if order == 3:
+            K = steps // 3 + 1
+            if steps % 3 == 0:
+                orders = [3, ] * (K - 2) + [2, 1]
+            elif steps % 3 == 1:
+                orders = [3, ] * (K - 1) + [1]
+            else:
+                orders = [3, ] * (K - 1) + [2]
+        elif order == 2:
+            if steps % 2 == 0:
+                K = steps // 2
+                orders = [2, ] * K
+            else:
+                K = steps // 2 + 1
+                orders = [2, ] * (K - 1) + [1]
+        elif order == 1:
+            K = 1
+            orders = [1, ] * steps
+        else:
+            raise ValueError("'order' must be '1' or '2' or '3'.")
+        if skip_type == 'logSNR':
+            # To reproduce the results in DPM-Solver paper
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+        else:
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+                torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
+        return timesteps_outer, orders
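+    # Worked example of the order splitting above: for steps=20 and order=3 we get
+    # K = 20 // 3 + 1 = 7 and steps % 3 == 2, so orders = [3] * 6 + [2], i.e. six
+    # DPM-Solver-3 steps plus one DPM-Solver-2 step for exactly 6 * 3 + 2 = 20 NFE.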
+
+    def denoise_to_zero_fn(self, x, s):
+        """
+        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
+        """
+        return self.data_prediction_fn(x, s)
+
+    def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+        """
+        DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
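+        # First-order (DDIM) update in lambda (half-logSNR) space, with h = lambda_t - lambda_s:
+        #   data prediction (DPM-Solver++): x_t = (sigma_t / sigma_s) * x - alpha_t * (exp(-h) - 1) * x0
+        #   noise prediction (DPM-Solver):  x_t = (alpha_t / alpha_s) * x - sigma_t * (exp(h) - 1) * eps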
+        if self.predict_x0:
+            phi_1 = torch.expm1(-h)
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_t = (
+                    expand_dims(sigma_t / sigma_s, dims) * x
+                    - expand_dims(alpha_t * phi_1, dims) * model_s
+            )
+            if return_intermediate:
+                return x_t, {'model_s': model_s}
+            else:
+                return x_t
+        else:
+            phi_1 = torch.expm1(h)
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                    - expand_dims(sigma_t * phi_1, dims) * model_s
+            )
+            if return_intermediate:
+                return x_t, {'model_s': model_s}
+            else:
+                return x_t
+
+    def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+                                            solver_type='dpm_solver'):
+        """
+        Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            r1: A `float`. The hyperparameter of the second-order solver.
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ['dpm_solver', 'taylor']:
+            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+        if r1 is None:
+            r1 = 0.5
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        lambda_s1 = lambda_s + r1 * h
+        s1 = ns.inverse_lambda(lambda_s1)
+        log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+            s1), ns.marginal_log_mean_coeff(t)
+        sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+        if self.predict_x0:
+            phi_11 = torch.expm1(-r1 * h)
+            phi_1 = torch.expm1(-h)
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_s1 = (
+                    expand_dims(sigma_s1 / sigma_s, dims) * x
+                    - expand_dims(alpha_s1 * phi_11, dims) * model_s
+            )
+            model_s1 = self.model_fn(x_s1, s1)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+                                    model_s1 - model_s)
+                )
+        else:
+            phi_11 = torch.expm1(r1 * h)
+            phi_1 = torch.expm1(h)
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_s1 = (
+                    expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+                    - expand_dims(sigma_s1 * phi_11, dims) * model_s
+            )
+            model_s1 = self.model_fn(x_s1, s1)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+                )
+        if return_intermediate:
+            return x_t, {'model_s': model_s, 'model_s1': model_s1}
+        else:
+            return x_t
+
+    def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+                                           return_intermediate=False, solver_type='dpm_solver'):
+        """
+        Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            r1: A `float`. The hyperparameter of the third-order solver.
+            r2: A `float`. The hyperparameter of the third-order solver.
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ['dpm_solver', 'taylor']:
+            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+        if r1 is None:
+            r1 = 1. / 3.
+        if r2 is None:
+            r2 = 2. / 3.
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        lambda_s1 = lambda_s + r1 * h
+        lambda_s2 = lambda_s + r2 * h
+        s1 = ns.inverse_lambda(lambda_s1)
+        s2 = ns.inverse_lambda(lambda_s2)
+        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+            s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+        sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+            s2), ns.marginal_std(t)
+        alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+        if self.predict_x0:
+            phi_11 = torch.expm1(-r1 * h)
+            phi_12 = torch.expm1(-r2 * h)
+            phi_1 = torch.expm1(-h)
+            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+            phi_2 = phi_1 / h + 1.
+            phi_3 = phi_2 / h - 0.5
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            if model_s1 is None:
+                x_s1 = (
+                        expand_dims(sigma_s1 / sigma_s, dims) * x
+                        - expand_dims(alpha_s1 * phi_11, dims) * model_s
+                )
+                model_s1 = self.model_fn(x_s1, s1)
+            x_s2 = (
+                    expand_dims(sigma_s2 / sigma_s, dims) * x
+                    - expand_dims(alpha_s2 * phi_12, dims) * model_s
+                    + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+            )
+            model_s2 = self.model_fn(x_s2, s2)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+                )
+            elif solver_type == 'taylor':
+                D1_0 = (1. / r1) * (model_s1 - model_s)
+                D1_1 = (1. / r2) * (model_s2 - model_s)
+                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        + expand_dims(alpha_t * phi_2, dims) * D1
+                        - expand_dims(alpha_t * phi_3, dims) * D2
+                )
+        else:
+            phi_11 = torch.expm1(r1 * h)
+            phi_12 = torch.expm1(r2 * h)
+            phi_1 = torch.expm1(h)
+            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+            phi_2 = phi_1 / h - 1.
+            phi_3 = phi_2 / h - 0.5
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            if model_s1 is None:
+                x_s1 = (
+                        expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+                        - expand_dims(sigma_s1 * phi_11, dims) * model_s
+                )
+                model_s1 = self.model_fn(x_s1, s1)
+            x_s2 = (
+                    expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+                    - expand_dims(sigma_s2 * phi_12, dims) * model_s
+                    - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+            )
+            model_s2 = self.model_fn(x_s2, s2)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+                )
+            elif solver_type == 'taylor':
+                D1_0 = (1. / r1) * (model_s1 - model_s)
+                D1_1 = (1. / r2) * (model_s2 - model_s)
+                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - expand_dims(sigma_t * phi_2, dims) * D1
+                        - expand_dims(sigma_t * phi_3, dims) * D2
+                )
+
+        if return_intermediate:
+            return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+        else:
+            return x_t
+
+    def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+        """
+        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+            model_prev_list: A list of pytorch tensors. The previously computed model values.
+            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ['dpm_solver', 'taylor']:
+            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+        ns = self.noise_schedule
+        dims = x.dim()
+        model_prev_1, model_prev_0 = model_prev_list
+        t_prev_1, t_prev_0 = t_prev_list
+        lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+            t_prev_0), ns.marginal_lambda(t)
+        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        h_0 = lambda_prev_0 - lambda_prev_1
+        h = lambda_t - lambda_prev_0
+        r0 = h_0 / h
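+        # D1_0 rescales the backward difference of the two previous model outputs by
+        # 1/r0 = h/h_0, giving the first-order correction term for the current step in
+        # log-SNR space.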
+        D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+        if self.predict_x0:
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(sigma_t / sigma_prev_0, dims) * x
+                        - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                        - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(sigma_t / sigma_prev_0, dims) * x
+                        - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                        + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+                )
+        else:
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                        - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                        - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                        - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                        - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+                )
+        return x_t
+
+    def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+        """
+        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+            model_prev_list: A list of pytorch tensors. The previously computed model values.
+            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        ns = self.noise_schedule
+        dims = x.dim()
+        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+            t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        h_1 = lambda_prev_1 - lambda_prev_2
+        h_0 = lambda_prev_0 - lambda_prev_1
+        h = lambda_t - lambda_prev_0
+        r0, r1 = h_0 / h, h_1 / h
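+        # D1_0 and D1_1 are rescaled backward differences over the two previous steps;
+        # D1 and D2 combine them into the first- and second-order correction terms used
+        # by the third-order update below.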
+        D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+        D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+        D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+        D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+        if self.predict_x0:
+            x_t = (
+                    expand_dims(sigma_t / sigma_prev_0, dims) * x
+                    - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                    + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+                    - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
+            )
+        else:
+            x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                    - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                    - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+                    - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
+            )
+        return x_t
+
+    def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+                                     r2=None):
+        """
+        Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+            r1: A `float`. The hyperparameter of the second-order or third-order solver.
+            r2: A `float`. The hyperparameter of the third-order solver.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if order == 1:
+            return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+        elif order == 2:
+            return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
+                                                            solver_type=solver_type, r1=r1)
+        elif order == 3:
+            return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
+                                                           solver_type=solver_type, r1=r1, r2=r2)
+        else:
+            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+    def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+        """
+        Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+            model_prev_list: A list of pytorch tensors. The previously computed model values.
+            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if order == 1:
+            return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+        elif order == 2:
+            return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+        elif order == 3:
+            return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+        else:
+            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+    def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+                            solver_type='dpm_solver'):
+        """
+        The adaptive step size solver based on singlestep DPM-Solver.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_T`.
+            order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            h_init: A `float`. The initial step size (for logSNR).
+            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
+            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
+            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_0: A pytorch tensor. The approximated solution at time `t_0`.
+        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+        """
+        ns = self.noise_schedule
+        s = t_T * torch.ones((x.shape[0],)).to(x)
+        lambda_s = ns.marginal_lambda(s)
+        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+        h = h_init * torch.ones_like(s).to(x)
+        x_prev = x
+        nfe = 0
+        if order == 2:
+            r1 = 0.5
+            lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+                                                                                               solver_type=solver_type,
+                                                                                               **kwargs)
+        elif order == 3:
+            r1, r2 = 1. / 3., 2. / 3.
+            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+                                                                                    return_intermediate=True,
+                                                                                    solver_type=solver_type)
+            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
+                                                                                              solver_type=solver_type,
+                                                                                              **kwargs)
+        else:
+            raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+        while torch.abs((s - t_0)).mean() > t_err:
+            t = ns.inverse_lambda(lambda_s + h)
+            x_lower, lower_noise_kwargs = lower_update(x, s, t)
+            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+            delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+            norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+            E = norm_fn((x_higher - x_lower) / delta).max()
+            if torch.all(E <= 1.):
+                x = x_higher
+                s = t
+                x_prev = x_lower
+                lambda_s = ns.marginal_lambda(s)
+            h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+            nfe += order
+        print('adaptive solver nfe', nfe)
+        return x
+
+    def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+               method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+               atol=0.0078, rtol=0.05,
+               ):
+        """
+        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+        =====================================================
+        We support the following algorithms for both noise prediction model and data prediction model:
+            - 'singlestep':
+                Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+                We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+                The total number of function evaluations (NFE) == `steps`.
+                Given a fixed NFE == `steps`, the sampling procedure is:
+                    - If `order` == 1:
+                        - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+                    - If `order` == 2:
+                        - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+                        - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+                        - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                    - If `order` == 3:
+                        - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+                        - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                        - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+                        - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+            - 'multistep':
+                Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+                We initialize the first `order` values by lower order multistep solvers.
+                Given a fixed NFE == `steps`, the sampling procedure is:
+                    Denote K = steps.
+                    - If `order` == 1:
+                        - We use K steps of DPM-Solver-1 (i.e. DDIM).
+                    - If `order` == 2:
+                        - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
+                    - If `order` == 3:
+                        - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
+            - 'singlestep_fixed':
+                Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+                We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+            - 'adaptive':
+                Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+                We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+                You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation
+                cost (NFE) and the sample quality.
+                    - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+                    - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+        =====================================================
+        Some advice for choosing the algorithm:
+            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+                Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+                e.g.
+                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+                            skip_type='time_uniform', method='singlestep')
+            - For **guided sampling with large guidance scale** by DPMs:
+                Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+                e.g.
+                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+                            skip_type='time_uniform', method='multistep')
+        We support three types of `skip_type`:
+            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
+            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
+            - 'time_quadratic': quadratic time for the time steps.
+        =====================================================
+        Args:
+            x: A pytorch tensor. The initial value at time `t_start`,
+                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+            steps: An `int`. The total number of function evaluations (NFE).
+            t_start: A `float`. The starting time of the sampling.
+                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+            t_end: A `float`. The ending time of the sampling.
+                If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+                e.g. if total_N == 1000, we have `t_end` == 1e-3.
+                For discrete-time DPMs:
+                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+                For continuous-time DPMs:
+                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+            order: An `int`. The order of DPM-Solver.
+            skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+            method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+                This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+                score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
+                sampling diffusion SDEs for low-resolution images (such as CIFAR-10). However,
+                we observed that it does not matter for high-resolution images; as it costs an
+                additional NFE, we do not recommend it for high-resolution images.
+            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+                Only valid for `method=multistep` and `steps < 15`. We empirically find that
+                this trick is key to stabilizing the sampling by DPM-Solver with very few steps
+                (especially for steps <= 10), so we recommend setting it to `True`.
+            solver_type: A `str`. The Taylor expansion type for the solver: `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+        Returns:
+            x_end: A pytorch tensor. The approximated solution at time `t_end`.
+        """
+        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+        t_T = self.noise_schedule.T if t_start is None else t_start
+        device = x.device
+        if method == 'adaptive':
+            with torch.no_grad():
+                x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+                                             solver_type=solver_type)
+        elif method == 'multistep':
+            assert steps >= order
+            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+            assert timesteps.shape[0] - 1 == steps
+            with torch.no_grad():
+                vec_t = timesteps[0].expand((x.shape[0]))
+                model_prev_list = [self.model_fn(x, vec_t)]
+                t_prev_list = [vec_t]
+                # Init the first `order` values by lower order multistep DPM-Solver.
+                for init_order in tqdm(range(1, order), desc="DPM init order"):
+                    vec_t = timesteps[init_order].expand(x.shape[0])
+                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+                                                         solver_type=solver_type)
+                    model_prev_list.append(self.model_fn(x, vec_t))
+                    t_prev_list.append(vec_t)
+                # Compute the remaining values by `order`-th order multistep DPM-Solver.
+                for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+                    vec_t = timesteps[step].expand(x.shape[0])
+                    if lower_order_final and steps < 15:
+                        step_order = min(order, steps + 1 - step)
+                    else:
+                        step_order = order
+                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+                                                         solver_type=solver_type)
+                    for i in range(order - 1):
+                        t_prev_list[i] = t_prev_list[i + 1]
+                        model_prev_list[i] = model_prev_list[i + 1]
+                    t_prev_list[-1] = vec_t
+                    # We do not need to evaluate the final model value.
+                    if step < steps:
+                        model_prev_list[-1] = self.model_fn(x, vec_t)
+        elif method in ['singlestep', 'singlestep_fixed']:
+            if method == 'singlestep':
+                timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
+                                                                                              skip_type=skip_type,
+                                                                                              t_T=t_T, t_0=t_0,
+                                                                                              device=device)
+            elif method == 'singlestep_fixed':
+                K = steps // order
+                orders = [order, ] * K
+                timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+            for i, order in enumerate(orders):
+                t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+                timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+                                                      N=order, device=device)
+                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+                vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
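+                # The inner time grid fixes the intermediate ratios r1, r2 (fractions of the
+                # log-SNR step h) used by the order-2/3 singlestep updates.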
+                h = lambda_inner[-1] - lambda_inner[0]
+                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+                x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+        if denoise_to_zero:
+            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+        return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+    """
+    A piecewise linear function y = f(x), using xp and yp as keypoints.
+    We implement f(x) in a differentiable way (i.e. applicable for autograd).
+    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the two outermost points of xp to define the linear function.)
+    Args:
+        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+        yp: PyTorch tensor with shape [C, K].
+    Returns:
+        The function values f(x), with shape [N, C].
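+    Example:
+        With xp = [[0., 1.]] and yp = [[0., 2.]], evaluating f at x = [[0.5]] gives [[1.0]]
+        (linear interpolation between the keypoints (0, 0) and (1, 2)).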
+    """
+    N, K = x.shape[0], xp.shape[1]
+    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+    x_idx = torch.argmin(x_indices, dim=2)
+    cand_start_idx = x_idx - 1
+    start_idx = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(1, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+        ),
+    )
+    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+    start_idx2 = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(0, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+        ),
+    )
+    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+    return cand
+
+
+def expand_dims(v, dims):
+    """
+    Expand the tensor `v` to the dim `dims`.
+    Args:
+        `v`: a PyTorch tensor with shape [N].
+        `dims`: an `int`.
+    Returns:
+        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+    """
+    return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,87 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
+MODEL_TYPES = {
+    "eps": "noise",
+    "v": "v"
+}
+
+
+class DPMSolverSampler(object):
+    def __init__(self, model, **kwargs):
+        super().__init__()
+        self.model = model
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+
+        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+        device = self.model.betas.device
+        if x_T is None:
+            img = torch.randn(size, device=device)
+        else:
+            img = x_T
+
+        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
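+        # Wrap the LDM apply_model as a noise/v-prediction model with classifier-free
+        # guidance; DPM-Solver then runs a second-order multistep sampler in
+        # data-prediction (predict_x0) mode.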
+        model_fn = model_wrapper(
+            lambda x, t, c: self.model.apply_model(x, t, c),
+            ns,
+            model_type=MODEL_TYPES[self.model.parameterization],
+            guidance_type="classifier-free",
+            condition=conditioning,
+            unconditional_condition=unconditional_conditioning,
+            guidance_scale=unconditional_guidance_scale,
+        )
+
+        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+
+        return x.to(device), None
\ No newline at end of file
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae
--- /dev/null
+++ b/ldm/models/diffusion/plms.py
@@ -0,0 +1,244 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
+
+
+class PLMSSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        if ddim_eta != 0:
+            raise ValueError('ddim_eta must be 0 for PLMS')
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta,verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for PLMS sampling is {size}')
+
+        samples, intermediates = self.plms_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def plms_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
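+        # Cache of up to three previous epsilon predictions for the pseudo linear
+        # multistep (Adams-Bashforth style) combination in p_sample_plms.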
+        old_eps = []
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      old_eps=old_eps, t_next=ts_next,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0, e_t = outs
+            old_eps.append(e_t)
+            if len(old_eps) >= 4:
+                old_eps.pop(0)
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+                      dynamic_threshold=None):
+        b, *_, device = *x.shape, x.device
+
+        def get_model_output(x, t):
+            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+                e_t = self.model.apply_model(x, t, c)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t] * 2)
+                c_in = torch.cat([unconditional_conditioning, c])
+                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+            if score_corrector is not None:
+                assert self.model.parameterization == "eps"
+                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+            return e_t
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
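+        # The update below is the deterministic DDIM step (eta is forced to 0 for PLMS),
+        # applied to whichever epsilon estimate (e_t or e_t_prime) is passed in.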
+        def get_x_prev_and_pred_x0(e_t, index):
+            # select parameters corresponding to the currently considered timestep
+            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+            # current prediction for x_0
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+            if quantize_denoised:
+                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+            if dynamic_threshold is not None:
+                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+            # direction pointing to x_t
+            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+            if noise_dropout > 0.:
+                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+            return x_prev, pred_x0
+
+        e_t = get_model_output(x, t)
+        if len(old_eps) == 0:
+            # Pseudo Improved Euler (2nd order)
+            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+            e_t_next = get_model_output(x_prev, t_next)
+            e_t_prime = (e_t + e_t_next) / 2
+        elif len(old_eps) == 1:
+            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (3 * e_t - old_eps[-1]) / 2
+        elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+        elif len(old_eps) >= 3:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+        return x_prev, pred_x0, e_t
diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33
--- /dev/null
+++ b/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+
+def append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+    return x[(...,) + (None,) * dims_to_append]
+
+
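+# Dynamic thresholding helpers: rescale x0 so that its per-sample (or per-pixel) RMS norm
+# does not exceed `value`; samples already below the threshold are left unchanged.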
+def norm_thresholding(x0, value):
+    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+    return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+    # b c h w
+    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+    return x0 * (value / s)
\ No newline at end of file
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2943e1269675c2ed745c3920d6301f2c7f6a01d
--- /dev/null
+++ b/ldm/modules/attention.py
@@ -0,0 +1,408 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except ImportError:
+    XFORMERS_IS_AVAILBLE = False
+
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
+    return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+    dim = tensor.shape[-1]
+    std = 1 / math.sqrt(dim)
+    tensor.uniform_(-std, std)
+    return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = nn.Sequential(
+            nn.Linear(dim, inner_dim),
+            nn.GELU()
+        ) if not glu else GEGLU(dim, inner_dim)
+
+        self.net = nn.Sequential(
+            project_in,
+            nn.Dropout(dropout),
+            nn.Linear(inner_dim, dim_out)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class SpatialSelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b,c,h,w = q.shape
+        q = rearrange(q, 'b c h w -> b (h w) c')
+        k = rearrange(k, 'b c h w -> b c (h w)')
+        w_ = torch.einsum('bij,bjk->bik', q, k)
+
+        w_ = w_ * (int(c)**(-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = rearrange(v, 'b c h w -> b c (h w)')
+        w_ = rearrange(w_, 'b i j -> b j i')
+        h_ = torch.einsum('bij,bjk->bik', v, w_)
+        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+        h_ = self.proj_out(h_)
+
+        return x+h_
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, content_dim=0, color_dim=0, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.content_dim = content_dim
+        if self.content_dim > 0:
+            self.to_k_extra = nn.Linear(content_dim, inner_dim, bias=False)
+            self.to_v_extra = nn.Linear(content_dim, inner_dim, bias=False)
+
+        self.color_dim = color_dim
+        if self.color_dim > 0:
+            self.to_k_color = nn.Linear(color_dim, inner_dim, bias=False)
+            self.to_v_color = nn.Linear(color_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim),
+            nn.Dropout(dropout)
+        )
+
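+    # _forward runs one attention pass; the `content`/`color` flags select the extra
+    # key/value projections so the same queries can attend to the optional content and
+    # color control streams, whose outputs are added to the main result in forward().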
+    def _forward(self, q, context, mask, content=False, color=False):
+        h = self.heads
+        if content:
+            k = self.to_k_extra(context)
+            v = self.to_v_extra(context)
+        elif color:
+            k = self.to_k_color(context)
+            v = self.to_v_color(context)
+        else:
+            k = self.to_k(context)
+            v = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+        # force cast to fp32 to avoid overflowing
+        if _ATTN_PRECISION == "fp32":
+            with torch.autocast(enabled=False, device_type='cuda'):
+                q, k = q.float(), k.float()
+                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        else:
+            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        
+        del q, k
+    
+        if exists(mask):
+            mask = rearrange(mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        sim = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+        return out
+
+    def forward(self, x, context=None, mask=None, content_control=None, content_mask=None, content_w=1.0, 
+            color_control=None, color_mask=None, color_w=1.0):
+
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        out = self._forward(q, context, mask)
+
+        if self.content_dim > 0 and content_control is not None:
+            out_content = self._forward(q, content_control, content_mask, content=True)
+            out = out + content_w * out_content
+
+        if self.color_dim > 0 and color_control is not None:
+            out_color = self._forward(q, color_control, color_mask, color=True)
+            out = out + color_w * out_color
+        return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+    def __init__(self, query_dim, context_dim=None, content_dim=0, color_dim=0, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+              f"{heads} heads.")
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.content_dim = content_dim
+        if self.content_dim > 0:
+            self.to_k_extra = nn.Linear(content_dim, inner_dim, bias=False)
+            self.to_v_extra = nn.Linear(content_dim, inner_dim, bias=False)
+
+        self.color_dim = color_dim
+        if self.color_dim > 0:
+            self.to_k_color = nn.Linear(color_dim, inner_dim, bias=False)
+            self.to_v_color = nn.Linear(color_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def _forward(self, q, context, mask, content=False, color=False):
+        if content:
+            k = self.to_k_extra(context)
+            v = self.to_v_extra(context)
+        elif color:
+            k = self.to_k_color(context)
+            v = self.to_v_color(context)
+        else:
+            k = self.to_k(context)
+            v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+
+        return out
+
+    def forward(self, x, context=None, mask=None, content_control=None, content_mask=None, content_w=1.0, 
+            color_control=None, color_mask=None, color_w=1.0):
+        q = self.to_q(x)
+        context = default(context, x)
+        out = self._forward(q, context, mask)
+        if self.content_dim > 0 and content_control is not None:
+            out_content = self._forward(q, content_control, content_mask, content=True)
+            out = out + content_w * out_content
+        if self.color_dim > 0 and color_control is not None:
+            out_color = self._forward(q, color_control, color_mask, color=True)
+            out = out + color_w * out_color
+        return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+    ATTENTION_MODES = {
+        "softmax": CrossAttention,  # vanilla attention
+        "softmax-xformers": MemoryEfficientCrossAttention
+    }
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, content_dim=0, color_dim=0, gated_ff=True, checkpoint=True,
+                 disable_self_attn=False):
+        super().__init__()
+        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        assert attn_mode in self.ATTENTION_MODES
+        attn_cls = self.ATTENTION_MODES[attn_mode]
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+                              context_dim=context_dim if self.disable_self_attn else None, 
+                              content_dim=content_dim if self.disable_self_attn else 0, 
+                              color_dim=color_dim if self.disable_self_attn else 0)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, content_dim=content_dim, color_dim=color_dim,
+                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None, content_control=None, content_w=1.0, color_control=None, color_w=1.0):
+        return checkpoint(self._forward, (x, context, content_control, content_w, color_control, color_w), self.parameters(), self.checkpoint)
+
+    def _forward(self, x, context=None, content_control=None, content_w=1.0, color_control=None, color_w=1.0):
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, 
+                                      content_control=content_control if self.disable_self_attn else None, content_w=content_w, 
+                                      color_control=color_control if self.disable_self_attn else None, color_w=color_w) + x
+        x = self.attn2(self.norm2(x), context=context, content_control=content_control, content_w=content_w, color_control=color_control, color_w=color_w) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data.
+    First, project the input (aka embedding)
+    and reshape to b, t, d.
+    Then apply standard transformer action.
+    Finally, reshape to image
+    NEW: use_linear for more efficiency instead of the 1x1 convs
+    """
+    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, content_dim=0, color_dim=0,
+                 disable_self_attn=False, use_linear=False, use_checkpoint=True):
+        super().__init__()
+        if exists(context_dim) and not isinstance(context_dim, list):
+            context_dim = [context_dim]
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+        if not use_linear:
+            self.proj_in = nn.Conv2d(in_channels,
+                                     inner_dim,
+                                     kernel_size=1,
+                                     stride=1,
+                                     padding=0)
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], 
+                                   content_dim=content_dim, color_dim=color_dim,
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+                for d in range(depth)]
+        )
+        if not use_linear:
+            self.proj_out = zero_module(nn.Conv2d(inner_dim,
+                                                  in_channels,
+                                                  kernel_size=1,
+                                                  stride=1,
+                                                  padding=0))
+        else:
+            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+        self.use_linear = use_linear
+
+    def forward(self, x, context=None, content_control=None, content_w=1.0, color_control=None, color_w=1.0):
+        # note: if no context is given, cross-attention defaults to self-attention
+        if not isinstance(context, list):
+            context = [context]
+        if not isinstance(content_control, list):
+            content_control = [content_control]
+        if not isinstance(color_control, list):
+            color_control = [color_control]
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+        for i, block in enumerate(self.transformer_blocks):
+            x = block(x, context=context[i], content_control=content_control[i], content_w=content_w, color_control=color_control[i], color_w=color_w)
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = self.proj_out(x)
+        return x + x_in
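+
+# Shape note (illustrative sketch, not exercised by the module itself): a
+# SpatialTransformer maps (b, c, h, w) -> (b, c, h, w); internally tokens have
+# shape (b, h*w, inner_dim) with inner_dim = n_heads * d_head, e.g.
+#   st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, context_dim=768)
+#   y = st(torch.randn(2, 320, 16, 16), context=torch.randn(2, 77, 768))  # -> (2, 320, 16, 16)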
diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b089eebbe1676d8249005bb9def002ff5180715b
--- /dev/null
+++ b/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,852 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from typing import Optional, Any
+
+from ldm.modules.attention import MemoryEfficientCrossAttention
+
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except:
+    XFORMERS_IS_AVAILBLE = False
+    print("No module 'xformers'. Proceeding without it.")
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models:
+    From Fairseq.
+    Build sinusoidal embeddings.
+    This matches the implementation in tensor2tensor, but differs slightly
+    from the description in Section 3.5 of "Attention Is All You Need".
+    """
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    emb = math.log(10000) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps.float()[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad
+        emb = torch.nn.functional.pad(emb, (0,1,0,0))
+    return emb
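+
+# Example (a rough sketch, not part of the original module): for integer
+# timesteps the embedding concatenates sin and cos over log-spaced frequencies,
+#   t = torch.tensor([0, 10, 999])
+#   emb = get_timestep_embedding(t, 128)   # -> shape (3, 128)
+# An odd embedding_dim is zero-padded by one column.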
+
+
+def nonlinearity(x):
+    # swish
+    return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=2,
+                                        padding=0)
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0,1,0,1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                 dropout, temb_channels=512):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(in_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if temb_channels > 0:
+            self.temb_proj = torch.nn.Linear(temb_channels,
+                                             out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                                                     out_channels,
+                                                     kernel_size=3,
+                                                     stride=1,
+                                                     padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                                                    out_channels,
+                                                    kernel_size=1,
+                                                    stride=1,
+                                                    padding=0)
+
+    def forward(self, x, temb):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x+h
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b,c,h,w = q.shape
+        q = q.reshape(b,c,h*w)
+        q = q.permute(0,2,1)   # b,hw,c
+        k = k.reshape(b,c,h*w) # b,c,hw
+        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w_ = w_ * (int(c)**(-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b,c,h*w)
+        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b,c,h,w)
+
+        h_ = self.proj_out(h_)
+
+        return x+h_
+
+class MemoryEfficientAttnBlock(nn.Module):
+    """
+        Uses xformers efficient implementation,
+        see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+        Note: this is a single-head self-attention operation
+    """
+    #
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        B, C, H, W = q.shape
+        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(B, t.shape[1], 1, C)
+            .permute(0, 2, 1, 3)
+            .reshape(B * 1, t.shape[1], C)
+            .contiguous(),
+            (q, k, v),
+        )
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        out = (
+            out.unsqueeze(0)
+            .reshape(B, 1, out.shape[1], C)
+            .permute(0, 2, 1, 3)
+            .reshape(B, out.shape[1], C)
+        )
+        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        out = self.proj_out(out)
+        return x+out
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+    def forward(self, x, context=None, mask=None):
+        b, c, h, w = x.shape
+        x_in = x
+        x = rearrange(x, 'b c h w -> b (h w) c')
+        out = super().forward(x, context=context, mask=mask)
+        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+        # add the residual in image layout; `x` itself was reshaped to tokens above
+        return x_in + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+        attn_type = "vanilla-xformers"
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    if attn_type == "vanilla":
+        assert attn_kwargs is None
+        return AttnBlock(in_channels)
+    elif attn_type == "vanilla-xformers":
+        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+        return MemoryEfficientAttnBlock(in_channels)
+    elif type == "memory-efficient-cross-attn":
+        attn_kwargs["query_dim"] = in_channels
+        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+    elif attn_type == "none":
+        return nn.Identity(in_channels)
+    else:
+        raise NotImplementedError()
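+
+# Usage sketch (illustrative): make_attn(512) returns an AttnBlock, or a
+# MemoryEfficientAttnBlock when xformers is importable; attn_kwargs is only
+# consulted by the "memory-efficient-cross-attn" variant.
+#   attn = make_attn(512, attn_type="vanilla")
+#   y = attn(torch.randn(1, 512, 32, 32))   # same (b, c, h, w) shape out as in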
+
+
+class Model(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+        super().__init__()
+        if use_linear_attn: attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = self.ch*4
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        self.use_timestep = use_timestep
+        if self.use_timestep:
+            # timestep embedding
+            self.temb = nn.Module()
+            self.temb.dense = nn.ModuleList([
+                torch.nn.Linear(self.ch,
+                                self.temb_ch),
+                torch.nn.Linear(self.temb_ch,
+                                self.temb_ch),
+            ])
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            skip_in = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                if i_block == self.num_res_blocks:
+                    skip_in = ch*in_ch_mult[i_level]
+                block.append(ResnetBlock(in_channels=block_in+skip_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x, t=None, context=None):
+        #assert x.shape[2] == x.shape[3] == self.resolution
+        if context is not None:
+            # assume aligned context, cat along channel axis
+            x = torch.cat((x, context), dim=1)
+        if self.use_timestep:
+            # timestep embedding
+            assert t is not None
+            temb = get_timestep_embedding(t, self.ch)
+            temb = self.temb.dense[0](temb)
+            temb = nonlinearity(temb)
+            temb = self.temb.dense[1](temb)
+        else:
+            temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](
+                    torch.cat([h, hs.pop()], dim=1), temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+    def get_last_layer(self):
+        return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+                 **ignore_kwargs):
+        super().__init__()
+        if use_linear_attn: attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        2*z_channels if double_z else z_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        # timestep embedding
+        temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+                 attn_type="vanilla", **ignorekwargs):
+        super().__init__()
+        if use_linear_attn: attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.give_pre_end = give_pre_end
+        self.tanh_out = tanh_out
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        in_ch_mult = (1,)+tuple(ch_mult)
+        block_in = ch*ch_mult[self.num_resolutions-1]
+        curr_res = resolution // 2**(self.num_resolutions-1)
+        self.z_shape = (1,z_channels,curr_res,curr_res)
+        print("Working with z of shape {} = {} dimensions.".format(
+            self.z_shape, np.prod(self.z_shape)))
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(z_channels,
+                                       block_in,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, z):
+        #assert z.shape[1:] == self.z_shape[1:]
+        self.last_z_shape = z.shape
+
+        # timestep embedding
+        temb = None
+
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](h, temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        if self.give_pre_end:
+            return h
+
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        if self.tanh_out:
+            h = torch.tanh(h)
+        return h
+
+
+class SimpleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, *args, **kwargs):
+        super().__init__()
+        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+                                     ResnetBlock(in_channels=in_channels,
+                                                 out_channels=2 * in_channels,
+                                                 temb_channels=0, dropout=0.0),
+                                     ResnetBlock(in_channels=2 * in_channels,
+                                                out_channels=4 * in_channels,
+                                                temb_channels=0, dropout=0.0),
+                                     ResnetBlock(in_channels=4 * in_channels,
+                                                out_channels=2 * in_channels,
+                                                temb_channels=0, dropout=0.0),
+                                     nn.Conv2d(2*in_channels, in_channels, 1),
+                                     Upsample(in_channels, with_conv=True)])
+        # end
+        self.norm_out = Normalize(in_channels)
+        self.conv_out = torch.nn.Conv2d(in_channels,
+                                        out_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        for i, layer in enumerate(self.model):
+            if i in [1,2,3]:
+                x = layer(x, None)
+            else:
+                x = layer(x)
+
+        h = self.norm_out(x)
+        h = nonlinearity(h)
+        x = self.conv_out(h)
+        return x
+
+
+class UpsampleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+                 ch_mult=(2,2), dropout=0.0):
+        super().__init__()
+        # upsampling
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        block_in = in_channels
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.res_blocks = nn.ModuleList()
+        self.upsample_blocks = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            res_block = []
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                res_block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+            self.res_blocks.append(nn.ModuleList(res_block))
+            if i_level != self.num_resolutions - 1:
+                self.upsample_blocks.append(Upsample(block_in, True))
+                curr_res = curr_res * 2
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        # upsampling
+        h = x
+        for k, i_level in enumerate(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.res_blocks[i_level][i_block](h, None)
+            if i_level != self.num_resolutions - 1:
+                h = self.upsample_blocks[k](h)
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class LatentRescaler(nn.Module):
+    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+        super().__init__()
+        # residual block, interpolate, residual block
+        self.factor = factor
+        self.conv_in = nn.Conv2d(in_channels,
+                                 mid_channels,
+                                 kernel_size=3,
+                                 stride=1,
+                                 padding=1)
+        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+                                                     out_channels=mid_channels,
+                                                     temb_channels=0,
+                                                     dropout=0.0) for _ in range(depth)])
+        self.attn = AttnBlock(mid_channels)
+        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+                                                     out_channels=mid_channels,
+                                                     temb_channels=0,
+                                                     dropout=0.0) for _ in range(depth)])
+
+        self.conv_out = nn.Conv2d(mid_channels,
+                                  out_channels,
+                                  kernel_size=1,
+                                  )
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        for block in self.res_block1:
+            x = block(x, None)
+        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+        x = self.attn(x)
+        for block in self.res_block2:
+            x = block(x, None)
+        x = self.conv_out(x)
+        return x
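+
+# Sketch (illustrative values): LatentRescaler(factor=0.5, in_channels=4,
+# mid_channels=64, out_channels=4) maps a (b, 4, 64, 64) latent to (b, 4, 32, 32):
+# res blocks around an interpolate to the rounded factor-scaled size, with a
+# single AttnBlock at mid_channels in between.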
+
+
+class MergedRescaleEncoder(nn.Module):
+    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
+                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        intermediate_chn = ch * ch_mult[-1]
+        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+                               z_channels=intermediate_chn, double_z=False, resolution=resolution,
+                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+                               out_ch=None)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+                                       mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.rescaler(x)
+        return x
+
+
+class MergedRescaleDecoder(nn.Module):
+    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        tmp_chn = z_channels*ch_mult[-1]
+        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+                               ch_mult=ch_mult, resolution=resolution, ch=ch)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+                                       out_channels=tmp_chn, depth=rescale_module_depth)
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Upsampler(nn.Module):
+    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+        super().__init__()
+        assert out_size >= in_size
+        num_blocks = int(np.log2(out_size//in_size))+1
+        factor_up = 1.+ (out_size % in_size)
+        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+                                       out_channels=in_channels)
+        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+                               attn_resolutions=[], in_channels=None, ch=in_channels,
+                               ch_mult=[ch_mult for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Resize(nn.Module):
+    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+        super().__init__()
+        self.with_conv = learned
+        self.mode = mode
+        if self.with_conv:
+            print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+            raise NotImplementedError()
+            assert in_channels is not None
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=4,
+                                        stride=2,
+                                        padding=1)
+
+    def forward(self, x, scale_factor=1.0):
+        if scale_factor==1.0:
+            return x
+        else:
+            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+        return x
diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..71444590f7bbb56a757031ff6843da41059a902b
--- /dev/null
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,788 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+    checkpoint,
+    conv_nd,
+    linear,
+    avg_pool_nd,
+    zero_module,
+    normalization,
+    timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+    pass
+
+def convert_module_to_f32(x):
+    pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+    """
+    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+    """
+
+    def __init__(
+        self,
+        spacial_dim: int,
+        embed_dim: int,
+        num_heads_channels: int,
+        output_dim: int = None,
+    ):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+        self.num_heads = embed_dim // num_heads_channels
+        self.attention = QKVAttention(self.num_heads)
+
+    def forward(self, x):
+        b, c, *_spatial = x.shape
+        x = x.reshape(b, c, -1)  # NC(HW)
+        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+        x = self.qkv_proj(x)
+        x = self.attention(x)
+        x = self.c_proj(x)
+        return x[:, :, 0]
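+
+# Usage sketch (illustrative): pooling an 8x8 feature map into a single vector,
+#   pool = AttentionPool2d(spacial_dim=8, embed_dim=512, num_heads_channels=64, output_dim=768)
+#   v = pool(th.randn(4, 512, 8, 8))   # -> (4, 768)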
+
+
+class TimestepBlock(nn.Module):
+    """
+    Any module where forward() takes timestep embeddings as a second argument.
+    """
+
+    @abstractmethod
+    def forward(self, x, emb):
+        """
+        Apply the module to `x` given `emb` timestep embeddings.
+        """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+    """
+    A sequential module that passes timestep embeddings to the children that
+    support it as an extra input.
+    """
+
+    def forward(self, x, emb, context=None, content_control=None, color_control=None, content_w=1.0, color_w=1.0):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb)
+            elif isinstance(layer, SpatialTransformer):
+                x = layer(x, context, content_control=content_control, color_control=color_control, content_w=content_w, color_w=color_w)
+            else:
+                x = layer(x)
+        return x
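+
+# Dispatch note: each child only receives the inputs it understands, e.g. in
+# TimestepEmbedSequential(ResBlock(...), SpatialTransformer(...)) the ResBlock is
+# called with (x, emb), the transformer with (x, context) plus the optional
+# content/color controls and weights, and plain layers (convs, pooling) with x alone.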
+
+
+class Upsample(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if use_conv:
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims == 3:
+            x = F.interpolate(
+                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+            )
+        else:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+class TransposedUpsample(nn.Module):
+    'Learned 2x upsampling without padding'
+    def __init__(self, channels, out_channels=None, ks=5):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+
+        self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+    def forward(self,x):
+        return self.up(x)
+
+
+class Downsample(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        stride = 2 if dims != 3 else (1, 2, 2)
+        if use_conv:
+            self.op = conv_nd(
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+            )
+        else:
+            assert self.channels == self.out_channels
+            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.op(x)
+    
+
+class ResBlock(TimestepBlock):
+    """
+    A residual block that can optionally change the number of channels.
+    :param channels: the number of input channels.
+    :param emb_channels: the number of timestep embedding channels.
+    :param dropout: the rate of dropout.
+    :param out_channels: if specified, the number of out channels.
+    :param use_conv: if True and out_channels is specified, use a spatial
+        convolution instead of a smaller 1x1 convolution to change the
+        channels in the skip connection.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param use_checkpoint: if True, use gradient checkpointing on this module.
+    :param up: if True, use this block for upsampling.
+    :param down: if True, use this block for downsampling.
+    """
+
+    def __init__(
+        self,
+        channels,
+        emb_channels,
+        dropout,
+        out_channels=None,
+        use_conv=False,
+        use_scale_shift_norm=False,
+        dims=2,
+        use_checkpoint=False,
+        up=False,
+        down=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.emb_channels = emb_channels
+        self.dropout = dropout
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_checkpoint = use_checkpoint
+        self.use_scale_shift_norm = use_scale_shift_norm
+
+        self.in_layers = nn.Sequential(
+            normalization(channels),
+            nn.SiLU(),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1),
+        )
+
+        self.updown = up or down
+
+        if up:
+            self.h_upd = Upsample(channels, False, dims)
+            self.x_upd = Upsample(channels, False, dims)
+        elif down:
+            self.h_upd = Downsample(channels, False, dims)
+            self.x_upd = Downsample(channels, False, dims)
+        else:
+            self.h_upd = self.x_upd = nn.Identity()
+
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            linear(
+                emb_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+            ),
+        )
+        self.out_layers = nn.Sequential(
+            normalization(self.out_channels),
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+            ),
+        )
+
+        if self.out_channels == channels:
+            self.skip_connection = nn.Identity()
+        elif use_conv:
+            self.skip_connection = conv_nd(
+                dims, channels, self.out_channels, 3, padding=1
+            )
+        else:
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+    def forward(self, x, emb):
+        """
+        Apply the block to a Tensor, conditioned on a timestep embedding.
+        :param x: an [N x C x ...] Tensor of features.
+        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        return checkpoint(
+            self._forward, (x, emb), self.parameters(), self.use_checkpoint
+        )
+
+
+    def _forward(self, x, emb):
+        if self.updown:
+            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+            h = in_rest(x)
+            h = self.h_upd(h)
+            x = self.x_upd(x)
+            h = in_conv(h)
+        else:
+            h = self.in_layers(x)
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        while len(emb_out.shape) < len(h.shape):
+            emb_out = emb_out[..., None]
+        if self.use_scale_shift_norm:
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = th.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
+        return self.skip_connection(x) + h
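+
+# Shape sketch (hypothetical values, dims=2, use_checkpoint=False):
+#   block = ResBlock(channels=320, emb_channels=1280, dropout=0.0, out_channels=640)
+#   h = block(th.randn(2, 320, 32, 32), th.randn(2, 1280))   # -> (2, 640, 32, 32)
+# The skip connection is a 1x1 conv here because out_channels differs from channels.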
+
+
+class AttentionBlock(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other.
+    Originally ported from here, but adapted to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        use_checkpoint=False,
+        use_new_attention_order=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.use_checkpoint = use_checkpoint
+        self.norm = normalization(channels)
+        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        if use_new_attention_order:
+            # split qkv before split heads
+            self.attention = QKVAttention(self.num_heads)
+        else:
+            # split heads before split qkv
+            self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+    def forward(self, x):
+        return checkpoint(self._forward, (x,), self.parameters(), True)   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+        #return pt_checkpoint(self._forward, x)  # pytorch
+
+    def _forward(self, x):
+        b, c, *spatial = x.shape
+        x = x.reshape(b, c, -1)
+        qkv = self.qkv(self.norm(x))
+        h = self.attention(qkv)
+        h = self.proj_out(h)
+        return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+    """
+    A counter for the `thop` package to count the operations in an
+    attention operation.
+    Meant to be used like:
+        macs, params = thop.profile(
+            model,
+            inputs=(inputs, timestamps),
+            custom_ops={QKVAttention: QKVAttention.count_flops},
+        )
+    """
+    b, c, *spatial = y[0].shape
+    num_spatial = int(np.prod(spatial))
+    # We perform two matmuls with the same number of ops.
+    # The first computes the weight matrix, the second computes
+    # the combination of the value vectors.
+    matmul_ops = 2 * b * (num_spatial ** 2) * c
+    model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+    """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts", q * scale, k * scale
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v)
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+    """
+    A module which performs QKV attention and splits in a different order.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts",
+            (q * scale).view(bs * self.n_heads, ch, length),
+            (k * scale).view(bs * self.n_heads, ch, length),
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
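+# Minimal shape-check sketch (illustrative sizes; both variants accept
+# [N x (3*H*C) x T] and return [N x (H*C) x T]):
+def _example_qkv_attention_shapes():
+    n_heads, ch, length, bs = 4, 8, 16, 2
+    qkv = th.randn(bs, 3 * n_heads * ch, length)
+    out_legacy = QKVAttentionLegacy(n_heads)(qkv)   # split heads before qkv
+    out_new = QKVAttention(n_heads)(qkv)            # split qkv before heads
+    assert out_legacy.shape == out_new.shape == (bs, n_heads * ch, length)
+
+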
+class UNetModel(nn.Module):
+    """
+    The full UNet model with attention and timestep embedding.
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param num_res_blocks: number of residual blocks per downsample.
+    :param attention_resolutions: a collection of downsample rates at which
+        attention will take place. May be a set, list, or tuple.
+        For example, if this contains 4, then at 4x downsampling, attention
+        will be used.
+    :param dropout: the dropout probability.
+    :param channel_mult: channel multiplier for each level of the UNet.
+    :param conv_resample: if True, use learned convolutions for upsampling and
+        downsampling.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param num_classes: if specified (as an int), then this model will be
+        class-conditional with `num_classes` classes.
+    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+    :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+                               a fixed channel width per attention head.
+    :param num_heads_upsample: works with num_heads to set a different number
+                               of heads for upsampling. Deprecated.
+    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+    :param resblock_updown: use residual blocks for up/downsampling.
+    :param use_new_attention_order: use a different attention pattern for potentially
+                                    increased efficiency.
+    """
+
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        num_classes=None,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=-1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        use_spatial_transformer=False,    # custom transformer support
+        transformer_depth=1,              # custom transformer support
+        context_dim=None,                 # custom transformer support
+        content_dim=0,
+        color_dim=0,
+        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
+        legacy=True,
+        disable_self_attentions=None,
+        num_attention_blocks=None,
+        disable_middle_self_attn=False,
+        use_linear_in_transformer=False,
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+        if context_dim is not None:
+            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+            from omegaconf.listconfig import ListConfig
+            if type(context_dim) == ListConfig:
+                context_dim = list(context_dim)
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+        if num_head_channels == -1:
+            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")
+
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        if self.num_classes is not None:
+            if isinstance(self.num_classes, int):
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            elif self.num_classes == "continuous":
+                print("setting up linear c_adm embedding layer")
+                self.label_emb = nn.Linear(1, time_embed_dim)
+            else:
+                raise ValueError()
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        #num_heads = 1
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, content_dim=content_dim, color_dim=color_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            #num_heads = 1
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=dim_head,
+                use_new_attention_order=use_new_attention_order,
+            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, content_dim=content_dim, color_dim=color_dim,
+                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                            use_checkpoint=use_checkpoint
+                        ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self._feature_size += ch
+
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(self.num_res_blocks[level] + 1):
+                ich = input_block_chans.pop()
+                layers = [
+                    ResBlock(
+                        ch + ich,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=model_channels * mult,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = model_channels * mult
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        #num_heads = 1
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads_upsample,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, content_dim=content_dim, color_dim=color_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
+                if level and i == self.num_res_blocks[level]:
+                    out_ch = ch
+                    layers.append(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+        )
+        if self.predict_codebook_ids:
+            self.id_predictor = nn.Sequential(
+                normalization(ch),
+                conv_nd(dims, model_channels, n_embed, 1),
+                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
+            )
+
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+
+    def forward(self, x, timesteps=None, context=None, content_control=None, content_w=1.0, color_control=None, color_w=1.0, y=None, **kwargs):
+        """
+        Apply the model to an input batch.
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :param context: conditioning plugged in via crossattn
+        :param y: an [N] Tensor of labels, if class-conditional.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+        hs = []
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb, context=context, content_control=content_control, content_w=content_w, color_control=color_control, color_w=color_w)
+            hs.append(h)
+        h = self.middle_block(h, emb, context=context, content_control=content_control, content_w=content_w, color_control=color_control, color_w=color_w)
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(h, emb, context=context, content_control=content_control, content_w=content_w, color_control=color_control, color_w=color_w)
+        h = h.type(x.dtype)
+        if self.predict_codebook_ids:
+            return self.id_predictor(h)
+        else:
+            return self.out(h)
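+
+
+# Hedged instantiation sketch (illustrative hyper-parameters, not a verified
+# config): a text-conditional UNet over 4-channel latents with cross-attention.
+#     unet = UNetModel(
+#         image_size=32, in_channels=4, model_channels=320, out_channels=4,
+#         num_res_blocks=2, attention_resolutions=(4, 2, 1),
+#         channel_mult=(1, 2, 4, 4), num_heads=8,
+#         use_spatial_transformer=True, transformer_depth=1, context_dim=768,
+#     )
+#     eps = unet(x, timesteps=t, context=text_embeddings)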
diff --git a/ldm/modules/diffusionmodules/upscaling.py b/ldm/modules/diffusionmodules/upscaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988
--- /dev/null
+++ b/ldm/modules/diffusionmodules/upscaling.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ldm.util import default
+
+
+class AbstractLowScaleModel(nn.Module):
+    # for concatenating a downsampled image to the latent representation
+    def __init__(self, noise_schedule_config=None):
+        super(AbstractLowScaleModel, self).__init__()
+        if noise_schedule_config is not None:
+            self.register_schedule(**noise_schedule_config)
+
+    def register_schedule(self, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def forward(self, x):
+        return x, None
+
+    def decode(self, x):
+        return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+    # no noise level conditioning
+    def __init__(self):
+        super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
+        self.max_noise_level = 0
+
+    def forward(self, x):
+        # fix to constant noise level
+        return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+        super().__init__(noise_schedule_config=noise_schedule_config)
+        self.max_noise_level = max_noise_level
+
+    def forward(self, x, noise_level=None):
+        if noise_level is None:
+            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+        else:
+            assert isinstance(noise_level, torch.Tensor)
+        z = self.q_sample(x, noise_level)
+        return z, noise_level
+
+
+
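+# Hedged usage sketch: with a registered schedule, the low-resolution
+# conditioning image is noised to a random (or given) level, and that level is
+# returned so it can be passed to the model as extra conditioning.
+#     aug = ImageConcatWithNoiseAugmentation({"beta_schedule": "linear", "timesteps": 1000})
+#     z, level = aug(x_lowres)   # z: noised image, level: [B] long tensor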
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b847f872dc9df7cc978e4efb73a0421f90a9c08
--- /dev/null
+++ b/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,270 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+    if schedule == "linear":
+        betas = (
+                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+        )
+
+    elif schedule == "cosine":
+        timesteps = (
+                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+        )
+        alphas = timesteps / (1 + cosine_s) * np.pi / 2
+        alphas = torch.cos(alphas).pow(2)
+        alphas = alphas / alphas[0]
+        betas = 1 - alphas[1:] / alphas[:-1]
+        betas = np.clip(betas, a_min=0, a_max=0.999)
+
+    elif schedule == "sqrt_linear":
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+    elif schedule == "sqrt":
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+    else:
+        raise ValueError(f"schedule '{schedule}' unknown.")
+    return betas.numpy()
+
+
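+# Hedged sketch: the default "linear" schedule interpolates sqrt(beta) linearly,
+# so betas rise monotonically from roughly linear_start to linear_end.
+def _example_linear_betas():
+    betas = make_beta_schedule("linear", n_timestep=1000)
+    assert betas.shape == (1000,) and betas[0] < betas[-1]
+    return betas
+
+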
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+    if ddim_discr_method == 'uniform':
+        c = num_ddpm_timesteps // num_ddim_timesteps
+        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+    elif ddim_discr_method == 'quad':
+        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+    else:
+        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+    # add one to get the final alpha values right (the ones from first scale to data during sampling)
+    steps_out = ddim_timesteps + 1
+    if verbose:
+        print(f'Selected timesteps for ddim sampler: {steps_out}')
+    return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+    # select alphas for computing the variance schedule
+    alphas = alphacums[ddim_timesteps]
+    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+    # according to the formula provided in https://arxiv.org/abs/2010.02502
+    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+    if verbose:
+        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+        print(f'For the chosen value of eta, which is {eta}, '
+              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+    return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+    """
+    Evaluate a function without caching intermediate activations, allowing for
+    reduced memory at the expense of extra compute in the backward pass.
+    :param func: the function to evaluate.
+    :param inputs: the argument sequence to pass to `func`.
+    :param params: a sequence of parameters `func` depends on but does not
+                   explicitly take as arguments.
+    :param flag: if False, disable gradient checkpointing.
+    """
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        return CheckpointFunction.apply(func, len(inputs), *args)
+    else:
+        return func(*inputs)
+
+
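+# Hedged usage sketch: wrap a block's forward pass so its activations are
+# recomputed in the backward pass whenever `use_checkpoint` is True
+# (e.g. pass `self.use_checkpoint` as the flag).
+def _example_checkpointed_forward(block, x, use_checkpoint=True):
+    # `block` is any nn.Module whose forward takes a single tensor argument.
+    return checkpoint(block.forward, (x,), block.parameters(), use_checkpoint)
+
+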
+class CheckpointFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+                                   "dtype": torch.get_autocast_gpu_dtype(),
+                                   "cache_enabled": torch.is_autocast_cache_enabled()}
+        with torch.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        with torch.enable_grad(), \
+                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+            # Fixes a bug where the first op in run_function modifies the
+            # Tensor storage in place, which is not allowed for detach()'d
+            # Tensors.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = torch.autograd.grad(
+            output_tensors,
+            ctx.input_tensors + ctx.input_params,
+            output_grads,
+            allow_unused=True,
+        )
+        del ctx.input_tensors
+        del ctx.input_params
+        del output_tensors
+        return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=timesteps.device)
+        args = timesteps[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    else:
+        embedding = repeat(timesteps, 'b -> b d', d=dim)
+    return embedding
+
+
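+# Illustrative sketch: embed a batch of four integer timesteps into 320-dim
+# sinusoidal codes, as fed to the UNet's time-embedding MLP.
+def _example_timestep_embedding():
+    t = torch.tensor([0, 10, 500, 999])
+    emb = timestep_embedding(t, dim=320)
+    assert emb.shape == (4, 320)
+    return emb
+
+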
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def scale_module(module, scale):
+    """
+    Scale the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().mul_(scale)
+    return module
+
+
+def mean_flat(tensor):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+    """
+    Make a standard normalization layer.
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+    def forward(self, x):
+        return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    """
+    Create a linear module.
+    """
+    return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+    def __init__(self, c_concat_config, c_crossattn_config):
+        super().__init__()
+        self.concat_conditioner = instantiate_from_config(c_concat_config)
+        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+    def forward(self, c_concat, c_crossattn):
+        c_concat = self.concat_conditioner(c_concat)
+        c_crossattn = self.crossattn_conditioner(c_crossattn)
+        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+    def sample(self):
+        raise NotImplementedError()
+
+    def mode(self):
+        raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+    def __init__(self, value):
+        self.value = value
+
+    def sample(self):
+        return self.value
+
+    def mode(self):
+        return self.value
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+    def sample(self):
+        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+        return x
+
+    def kl(self, other=None):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        else:
+            if other is None:
+                return 0.5 * torch.sum(torch.pow(self.mean, 2)
+                                       + self.var - 1.0 - self.logvar,
+                                       dim=[1, 2, 3])
+            else:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
+                    dim=[1, 2, 3])
+
+    def nll(self, sample, dims=[1,2,3]):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims)
+
+    def mode(self):
+        return self.mean
+
+
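+# Hedged sketch: `parameters` packs mean and log-variance along dim 1, so an
+# encoder emitting 8 channels yields a 4-channel latent sample.
+def _example_diagonal_gaussian():
+    params = torch.randn(2, 8, 16, 16)   # [B, 2*C, H, W]
+    posterior = DiagonalGaussianDistribution(params)
+    z = posterior.sample()
+    assert z.shape == (2, 4, 16, 16)
+    return z, posterior.kl()
+
+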
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+    Compute the KL divergence between two gaussians.
+    Shapes are automatically broadcasted, so batches can be compared to
+    scalars, among other use cases.
+    """
+    tensor = None
+    for obj in (mean1, logvar1, mean2, logvar2):
+        if isinstance(obj, torch.Tensor):
+            tensor = obj
+            break
+    assert tensor is not None, "at least one argument must be a Tensor"
+
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for torch.exp().
+    logvar1, logvar2 = [
+        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+        for x in (logvar1, logvar2)
+    ]
+
+    return 0.5 * (
+        -1.0
+        + logvar2
+        - logvar1
+        + torch.exp(logvar1 - logvar2)
+        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+    )
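+
+
+# Quick sanity sketch: identical Gaussians give zero KL, and scalar parameters
+# broadcast against batched tensors.
+def _example_normal_kl():
+    kl_same = normal_kl(torch.zeros(4), torch.zeros(4), 0.0, 0.0)
+    assert torch.allclose(kl_same, torch.zeros(4))
+    return kl_same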
diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4
--- /dev/null
+++ b/ldm/modules/ema.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+    def __init__(self, model, decay=0.9999, use_num_upates=True):
+        super().__init__()
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError('Decay must be between 0 and 1')
+
+        self.m_name2s_name = {}
+        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+        else torch.tensor(-1, dtype=torch.int))
+
+        for name, p in model.named_parameters():
+            if p.requires_grad:
+                # remove as '.'-character is not allowed in buffers
+                s_name = name.replace('.', '')
+                self.m_name2s_name.update({name: s_name})
+                self.register_buffer(s_name, p.clone().detach().data)
+
+        self.collected_params = []
+
+    def reset_num_updates(self):
+        del self.num_updates
+        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+    def forward(self, model):
+        decay = self.decay
+
+        if self.num_updates >= 0:
+            self.num_updates += 1
+            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+        one_minus_decay = 1.0 - decay
+
+        with torch.no_grad():
+            m_param = dict(model.named_parameters())
+            shadow_params = dict(self.named_buffers())
+
+            for key in m_param:
+                if m_param[key].requires_grad:
+                    sname = self.m_name2s_name[key]
+                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+                else:
+                    assert key not in self.m_name2s_name
+
+    def copy_to(self, model):
+        m_param = dict(model.named_parameters())
+        shadow_params = dict(self.named_buffers())
+        for key in m_param:
+            if m_param[key].requires_grad:
+                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+            else:
+                assert key not in self.m_name2s_name
+
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            temporarily stored.
+        """
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
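+
+
+# Typical EMA evaluation pattern (hedged sketch; `model` is any nn.Module whose
+# trainable weights are tracked by `ema = LitEma(model)`):
+#     ema(model)                       # update the shadow weights after an optimizer step
+#     ema.store(model.parameters())    # stash the current training weights
+#     ema.copy_to(model)               # swap in the EMA weights for validation / saving
+#     ...
+#     ema.restore(model.parameters())  # put the training weights back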
diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4edd5496b9e668ea72a5be39db9cca94b6a42f9b
--- /dev/null
+++ b/ldm/modules/encoders/modules.py
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+
+import open_clip
+from ldm.util import default, count_params
+
+
+class AbstractEncoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def encode(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+
+    def encode(self, x):
+        return x
+
+
+class ClassEmbedder(nn.Module):
+    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
+        super().__init__()
+        self.key = key
+        self.embedding = nn.Embedding(n_classes, embed_dim)
+        self.n_classes = n_classes
+        self.ucg_rate = ucg_rate
+
+    def forward(self, batch, key=None, disable_dropout=False):
+        if key is None:
+            key = self.key
+        # this is for use in crossattn
+        c = batch[key][:, None]
+        if self.ucg_rate > 0. and not disable_dropout:
+            mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+            c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+            c = c.long()
+        c = self.embedding(c)
+        return c
+
+    def get_unconditional_conditioning(self, bs, device="cuda"):
+        uc_class = self.n_classes - 1  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+        uc = torch.ones((bs,), device=device) * uc_class
+        uc = {self.key: uc}
+        return uc
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+    """Uses the T5 transformer encoder for text"""
+    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+        super().__init__()
+        self.tokenizer = T5Tokenizer.from_pretrained(version)
+        self.transformer = T5EncoderModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length   # TODO: typical value?
+        if freeze:
+            self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        #self.train = disabled_train
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(input_ids=tokens)
+
+        z = outputs.last_hidden_state
+        return z
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+    """Uses the CLIP transformer encoder for text (from huggingface)"""
+    LAYERS = [
+        "last",
+        "pooled",
+        "hidden"
+    ]
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
+                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
+        super().__init__()
+        assert layer in self.LAYERS
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        self.layer_idx = layer_idx
+        if layer == "hidden":
+            assert layer_idx is not None
+            assert 0 <= abs(layer_idx) <= 12
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        #self.train = disabled_train
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+        if self.layer == "last":
+            z = outputs.last_hidden_state
+        elif self.layer == "pooled":
+            z = outputs.pooler_output[:, None, :]
+        else:
+            z = outputs.hidden_states[self.layer_idx]
+        return z
+
+    def encode(self, text):
+        return self(text)
+
+
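+# Hedged usage sketch (downloads the weights from the Hugging Face hub on
+# first use):
+#     embedder = FrozenCLIPEmbedder(device="cpu")
+#     z = embedder.encode(["a photo of a cat"])   # -> [1, 77, 768] token features
+
+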
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+    """
+    Uses the OpenCLIP transformer encoder for text
+    """
+    LAYERS = [
+        #"pooled",
+        "last",
+        "penultimate"
+    ]
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+                 freeze=True, layer="last"):
+        super().__init__()
+        assert layer in self.LAYERS
+        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+        del model.visual
+        self.model = model
+
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        if self.layer == "last":
+            self.layer_idx = 0
+        elif self.layer == "penultimate":
+            self.layer_idx = 1
+        else:
+            raise NotImplementedError()
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        tokens = open_clip.tokenize(text)
+        z = self.encode_with_transformer(tokens.to(self.device))
+        return z
+
+    def encode_with_transformer(self, text):
+        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
+        x = x + self.model.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.model.ln_final(x)
+        return x
+
+    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+        for i, r in enumerate(self.model.transformer.resblocks):
+            if i == len(self.model.transformer.resblocks) - self.layer_idx:
+                break
+            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(r, x, attn_mask)
+            else:
+                x = r(x, attn_mask=attn_mask)
+        return x
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
+                 clip_max_length=77, t5_max_length=77):
+        super().__init__()
+        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
+        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
+              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+
+    def encode(self, text):
+        return self(text)
+
+    def forward(self, text):
+        clip_z = self.clip_encoder.encode(text)
+        t5_z = self.t5_encoder.encode(text)
+        return [clip_z, t5_z]
+
+
diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+    '''
+    Args:
+        img: numpy image, WxH or WxHxC
+        sf: scale factor
+    Return:
+        cropped image
+    '''
+    w, h = img.shape[:2]
+    im = np.copy(img)
+    return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+    k_size = k.shape[0]
+    # Calculate the big kernels size
+    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+    # Loop over the small kernel to fill the big one
+    for r in range(k_size):
+        for c in range(k_size):
+            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+    crop = k_size // 2
+    cropped_big_k = big_k[crop:-crop, crop:-crop]
+    # Normalize to 1
+    return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+    """ generate an anisotropic Gaussian kernel
+    Args:
+        ksize : e.g., 15, kernel size
+        theta : [0,  pi], rotation angle range
+        l1    : [0.1,50], scaling of eigenvalues
+        l2    : [0.1,l1], scaling of eigenvalues
+        If l1 = l2, will get an isotropic Gaussian kernel.
+    Returns:
+        k     : kernel
+    """
+
+    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+    D = np.array([[l1, 0], [0, l2]])
+    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+    return k
+
+
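+# Illustrative call (hedged sketch): a 15x15 kernel elongated along a
+# 45-degree axis; the kernel is normalized to sum to 1.
+#     k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=1)
+#     assert k.shape == (15, 15) and abs(k.sum() - 1.0) < 1e-6
+
+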
+def gm_blur_kernel(mean, cov, size=15):
+    center = size / 2.0 + 0.5
+    k = np.zeros([size, size])
+    for y in range(size):
+        for x in range(size):
+            cy = y - center + 1
+            cx = x - center + 1
+            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+    k = k / np.sum(k)
+    return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+    """shift pixel for super-resolution with different scale factors
+    Args:
+        x: WxHxC or WxH
+        sf: scale factor
+        upper_left: shift direction
+    """
+    h, w = x.shape[:2]
+    shift = (sf - 1) * 0.5
+    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+    if upper_left:
+        x1 = xv + shift
+        y1 = yv + shift
+    else:
+        x1 = xv - shift
+        y1 = yv - shift
+
+    x1 = np.clip(x1, 0, w - 1)
+    y1 = np.clip(y1, 0, h - 1)
+
+    if x.ndim == 2:
+        x = interp2d(xv, yv, x)(x1, y1)
+    if x.ndim == 3:
+        for i in range(x.shape[-1]):
+            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+    return x
+
+
+def blur(x, k):
+    '''
+    x: image, NxcxHxW
+    k: kernel, Nx1xhxw
+    '''
+    n, c = x.shape[:2]
+    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+    k = k.repeat(1, c, 1, 1)
+    k = k.view(-1, 1, k.shape[2], k.shape[3])
+    x = x.view(1, -1, x.shape[2], x.shape[3])
+    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+    x = x.view(n, c, x.shape[2], x.shape[3])
+
+    return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+    """"
+    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+    # Kai Zhang
+    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
+    # max_var = 2.5 * sf
+    """
+    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+    theta = np.random.rand() * np.pi  # random theta
+    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+    # Set COV matrix using Lambdas and Theta
+    LAMBDA = np.diag([lambda_1, lambda_2])
+    Q = np.array([[np.cos(theta), -np.sin(theta)],
+                  [np.sin(theta), np.cos(theta)]])
+    SIGMA = Q @ LAMBDA @ Q.T
+    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+    # Set expectation position (shifting kernel for aligned image)
+    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
+    MU = MU[None, None, :, None]
+
+    # Create meshgrid for Gaussian
+    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+    Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate Gaussian for every pixel of the kernel
+    ZZ = Z - MU
+    ZZ_t = ZZ.transpose(0, 1, 3, 2)
+    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+    # shift the kernel so it will be centered
+    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+    # Normalize the kernel and return
+    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+    kernel = raw_kernel / np.sum(raw_kernel)
+    return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+    hsize = [hsize, hsize]
+    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+    std = sigma
+    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+    arg = -(x * x + y * y) / (2 * std * std)
+    h = np.exp(arg)
+    h[h < scipy.finfo(float).eps * h.max()] = 0
+    sumh = h.sum()
+    if sumh != 0:
+        h = h / sumh
+    return h
+
+
+def fspecial_laplacian(alpha):
+    alpha = max([0, min([alpha, 1])])
+    h1 = alpha / (alpha + 1)
+    h2 = (1 - alpha) / (alpha + 1)
+    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+    h = np.array(h)
+    return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+    '''
+    python code from:
+    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+    '''
+    if filter_type == 'gaussian':
+        return fspecial_gaussian(*args, **kwargs)
+    if filter_type == 'laplacian':
+        return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+    '''
+    Args:
+        x: HxWxC image, [0, 1]
+        sf: down-scale factor
+    Return:
+        bicubicly downsampled LR image
+    '''
+    x = util.imresize_np(x, scale=1 / sf)
+    return x
+
+
+def srmd_degradation(x, k, sf=3):
+    ''' blur + bicubic downsampling
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2018learning,
+          title={Learning a single convolutional super-resolution network for multiple degradations},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={3262--3271},
+          year={2018}
+        }
+    '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
+    x = bicubic_degradation(x, sf=sf)
+    return x
+
+
+def dpsr_degradation(x, k, sf=3):
+    ''' bicubic downsampling + blur
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2019deep,
+          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={1671--1681},
+          year={2019}
+        }
+    '''
+    x = bicubic_degradation(x, sf=sf)
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    return x
+
+
+def classical_degradation(x, k, sf=3):
+    ''' blur + downsampling
+    Args:
+        x: HxWxC image, [0, 1]/[0, 255]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+    st = 0
+    return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening. borrowed from real-ESRGAN
+    Input image: I; Blurry image: B.
+    1. K = I + weight * (I - B)
+    2. Mask = 1 if abs(I - B) > threshold, else: 0
+    3. Blur mask:
+    4. Out = Mask * K + (1 - Mask) * I
+    Args:
+        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 0.5.
+        radius (float): Kernel size of the Gaussian blur. Default: 50.
+        threshold (int): Residual threshold (on the 0-255 scale) used to build the sharpening mask. Default: 10.
+    """
+    if radius % 2 == 0:
+        radius += 1
+    blur = cv2.GaussianBlur(img, (radius, radius), 0)
+    residual = img - blur
+    mask = np.abs(residual) * 255 > threshold
+    mask = mask.astype('float32')
+    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+    K = img + weight * residual
+    K = np.clip(K, 0, 1)
+    return soft_mask * K + (1 - soft_mask) * img
+
+
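+# The USM recipe in the docstring above can be hard to parse from the step list alone, so here is a
+# small usage sketch (ours, not part of the original BSRGAN code): it builds a synthetic float32
+# image in [0, 1] and sharpens it with the defaults. The helper name `_usm_sharpen_demo` is ours
+# and the function is never called at import time.
+def _usm_sharpen_demo():
+    # smooth gradient with a sharp vertical edge in the middle
+    demo = np.tile(np.linspace(0.2, 0.8, 128, dtype=np.float32), (128, 1))
+    demo[:, 64:] += 0.15
+    demo = np.clip(demo, 0, 1)
+    demo = np.repeat(demo[:, :, None], 3, axis=2)  # HxWx3, float32, [0, 1]
+    sharpened = add_sharpening(demo, weight=0.5, radius=51, threshold=10)
+    # the edge region gets the residual boost; the smooth gradient is left mostly untouched
+    return sharpened
+
+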
+def add_blur(img, sf=4):
+    wd2 = 4.0 + sf
+    wd = 2.0 + 0.2 * sf
+    if random.random() < 0.5:
+        l1 = wd2 * random.random()
+        l2 = wd2 * random.random()
+        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+    else:
+        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+    return img
+
+
+def add_resize(img, sf=4):
+    rnum = np.random.rand()
+    if rnum > 0.8:  # up
+        sf1 = random.uniform(1, 2)
+    elif rnum < 0.7:  # down
+        sf1 = random.uniform(0.5 / sf, 1)
+    else:
+        sf1 = 1.0
+    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+    img = np.clip(img, 0.0, 1.0)
+
+    return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+#     noise_level = random.randint(noise_level1, noise_level2)
+#     rnum = np.random.rand()
+#     if rnum > 0.6:  # add color Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+#     elif rnum < 0.4:  # add grayscale Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+#     else:  # add  noise
+#         L = noise_level2 / 255.
+#         D = np.diag(np.random.rand(3))
+#         U = orth(np.random.rand(3, 3))
+#         conv = np.dot(np.dot(np.transpose(U), D), U)
+#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+#     img = np.clip(img, 0.0, 1.0)
+#     return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    rnum = np.random.rand()
+    if rnum > 0.6:  # add color Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:  # add grayscale Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:  # add  noise
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    img = np.clip(img, 0.0, 1.0)
+    rnum = random.random()
+    if rnum > 0.6:
+        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:
+        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_Poisson_noise(img):
+    img = np.clip((img * 255.0).round(), 0, 255) / 255.
+    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
+    if random.random() < 0.5:
+        img = np.random.poisson(img * vals).astype(np.float32) / vals
+    else:
+        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+        img += noise_gray[:, :, np.newaxis]
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
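+# Minimal sketch (ours) of how the noise injectors above are meant to be combined in the
+# [0, 1] float domain; each call draws its own random noise level / branch internally.
+# The helper name `_noise_stack_demo` is ours and is never called at import time.
+def _noise_stack_demo():
+    img = np.random.rand(64, 64, 3).astype(np.float32)  # stand-in HR crop in [0, 1]
+    img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+    img = add_speckle_noise(img)
+    img = add_Poisson_noise(img)
+    return np.clip(img, 0.0, 1.0)
+
+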
+def add_JPEG_noise(img):
+    quality_factor = random.randint(30, 95)
+    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+    img = cv2.imdecode(encimg, 1)
+    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+    return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+    h, w = lq.shape[:2]
+    rnd_h = random.randint(0, h - lq_patchsize)
+    rnd_w = random.randint(0, w - lq_patchsize)
+    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+    return lq, hq
+
+
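+# random_crop keeps the LQ/HQ pair aligned: the HQ window is simply the LQ window scaled by sf.
+# A minimal sketch (ours, not part of the upstream code) with toy arrays:
+def _random_crop_demo():
+    sf, lq_patchsize = 4, 16
+    lq = np.zeros((32, 48, 3), dtype=np.float32)
+    hq = np.zeros((32 * sf, 48 * sf, 3), dtype=np.float32)
+    lq_patch, hq_patch = random_crop(lq, hq, sf=sf, lq_patchsize=lq_patchsize)
+    assert lq_patch.shape[:2] == (lq_patchsize, lq_patchsize)
+    assert hq_patch.shape[:2] == (lq_patchsize * sf, lq_patchsize * sf)
+    return lq_patch, hq_patch
+
+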
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+    h, w = img.shape[:2]
+
+    if h < lq_patchsize * sf or w < lq_patchsize * sf:
+        raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+    hq = img.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+                             interpolation=random.choice([1, 2, 3]))
+        else:
+            img = util.imresize_np(img, 1 / 2, True)
+        img = np.clip(img, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            img = add_blur(img, sf=sf)
+
+        elif i == 1:
+            img = add_blur(img, sf=sf)
+
+        elif i == 2:
+            a, b = img.shape[1], img.shape[0]
+            # downsample2
+            if random.random() < 0.75:
+                sf1 = random.uniform(1, 2 * sf)
+                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+                                 interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                img = img[0::sf, 0::sf, ...]  # nearest downsampling
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                img = add_JPEG_noise(img)
+
+        elif i == 6:
+            # add processed camera sensor noise
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    img = add_JPEG_noise(img)
+
+    # random crop
+    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+    return img, hq
+
+
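+# Usage sketch for degradation_bsrgan (ours, for illustration only): the function expects a
+# float image in [0, 1] whose sides are at least lq_patchsize * sf, and returns an aligned
+# (LQ, HQ) patch pair. The file path below is hypothetical, and the helper is never called at import.
+def _degradation_bsrgan_demo(path='some_hr_image.png'):
+    hr = util.uint2single(util.imread_uint(path, 3))  # HxWx3, float32, [0, 1]
+    lq, hq = degradation_bsrgan(hr, sf=4, lq_patchsize=72)
+    # lq: 72x72x3 in [0, 1]; hq: 288x288x3 in [0, 1]
+    return lq, hq
+
+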
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    image: HxWxC, uint8
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    example: dict with key "image" holding the degraded low-quality image, HxWxC, uint8
+    """
+    image = util.uint2single(image)
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = image.shape[:2]
+    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+    h, w = image.shape[:2]
+
+    hq = image.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+                               interpolation=random.choice([1, 2, 3]))
+        else:
+            image = util.imresize_np(image, 1 / 2, True)
+        image = np.clip(image, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            image = add_blur(image, sf=sf)
+
+        elif i == 1:
+            image = add_blur(image, sf=sf)
+
+        elif i == 2:
+            a, b = image.shape[1], image.shape[0]
+            # downsample2
+            if random.random() < 0.75:
+                sf1 = random.uniform(1, 2 * sf)
+                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+                                   interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                image = image[0::sf, 0::sf, ...]  # nearest downsampling
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                image = add_JPEG_noise(image)
+
+        # elif i == 6:
+        #     # add processed camera sensor noise
+        #     if random.random() < isp_prob and isp_model is not None:
+        #         with torch.no_grad():
+        #             img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    image = add_JPEG_noise(image)
+    image = util.single2uint(image)
+    example = {"image":image}
+    return example
+
+
+# TODO: in case there is a pickle error, replace `a += x` with `a = a + x` in add_speckle_noise etc.
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+    """
+    This is an extended degradation model by combining
+    the degradation models of BSRGAN and Real-ESRGAN
+    ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+    sf: scale factor
+    shuffle_prob: probability of fully shuffling the degradation order
+    use_sharp: whether to apply USM sharpening to the input before degradation
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+
+    h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+    h, w = img.shape[:2]
+
+    if h < lq_patchsize * sf or w < lq_patchsize * sf:
+        raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+    if use_sharp:
+        img = add_sharpening(img)
+    hq = img.copy()
+
+    if random.random() < shuffle_prob:
+        shuffle_order = random.sample(range(13), 13)
+    else:
+        shuffle_order = list(range(13))
+        # local shuffle for noise, JPEG is always the last one
+        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+    for i in shuffle_order:
+        if i == 0:
+            img = add_blur(img, sf=sf)
+        elif i == 1:
+            img = add_resize(img, sf=sf)
+        elif i == 2:
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+        elif i == 3:
+            if random.random() < poisson_prob:
+                img = add_Poisson_noise(img)
+        elif i == 4:
+            if random.random() < speckle_prob:
+                img = add_speckle_noise(img)
+        elif i == 5:
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+        elif i == 6:
+            img = add_JPEG_noise(img)
+        elif i == 7:
+            img = add_blur(img, sf=sf)
+        elif i == 8:
+            img = add_resize(img, sf=sf)
+        elif i == 9:
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+        elif i == 10:
+            if random.random() < poisson_prob:
+                img = add_Poisson_noise(img)
+        elif i == 11:
+            if random.random() < speckle_prob:
+                img = add_speckle_noise(img)
+        elif i == 12:
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+        else:
+            print('check the shuffle!')
+
+    # resize to desired size
+    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+                     interpolation=random.choice([1, 2, 3]))
+
+    # add final JPEG compression noise
+    img = add_JPEG_noise(img)
+
+    # random crop
+    img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+    return img, hq
+
+
+if __name__ == '__main__':
+    print("hey")
+    img = util.imread_uint('utils/test.png', 3)
+    print(img)
+    img = img[:448, :448]
+    h = img.shape[0] // 4
+    print("resizing to", h)
+    sf = 4
+    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+    for i in range(20):
+        print(i)
+        img_hq = img
+        img_lq = deg_fn(img)["image"]
+        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+        print(img_lq)
+        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+        print(img_lq.shape)
+        print("bicubic", img_lq_bicubic.shape)
+        print(img_hq.shape)
+        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                interpolation=0)
+        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                        interpolation=0)
+        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+        util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+def modcrop_np(img, sf):
+    '''
+    Args:
+        img: numpy image, WxH or WxHxC
+        sf: scale factor
+    Return:
+        cropped image
+    '''
+    w, h = img.shape[:2]
+    im = np.copy(img)
+    return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+    k_size = k.shape[0]
+    # Calculate the big kernel's size
+    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+    # Loop over the small kernel to fill the big one
+    for r in range(k_size):
+        for c in range(k_size):
+            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+    crop = k_size // 2
+    cropped_big_k = big_k[crop:-crop, crop:-crop]
+    # Normalize to 1
+    return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+    """ generate an anisotropic Gaussian kernel
+    Args:
+        ksize : e.g., 15, kernel size
+        theta : [0,  pi], rotation angle range
+        l1    : [0.1,50], scaling of eigenvalues
+        l2    : [0.1,l1], scaling of eigenvalues
+        If l1 = l2, will get an isotropic Gaussian kernel.
+    Returns:
+        k     : kernel
+    """
+
+    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+    D = np.array([[l1, 0], [0, l2]])
+    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+    return k
+
+
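+# Quick sketch (ours, not part of the upstream code) of what anisotropic_Gaussian produces:
+# a ksize x ksize kernel that sums to 1, elongated along the direction given by theta when l1 != l2.
+def _anisotropic_kernel_demo():
+    k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=1)
+    assert k.shape == (15, 15)
+    assert abs(k.sum() - 1.0) < 1e-6
+    return k
+
+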
+def gm_blur_kernel(mean, cov, size=15):
+    center = size / 2.0 + 0.5
+    k = np.zeros([size, size])
+    for y in range(size):
+        for x in range(size):
+            cy = y - center + 1
+            cx = x - center + 1
+            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+    k = k / np.sum(k)
+    return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+    """shift pixel for super-resolution with different scale factors
+    Args:
+        x: WxHxC or WxH
+        sf: scale factor
+        upper_left: shift direction
+    """
+    h, w = x.shape[:2]
+    shift = (sf - 1) * 0.5
+    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+    if upper_left:
+        x1 = xv + shift
+        y1 = yv + shift
+    else:
+        x1 = xv - shift
+        y1 = yv - shift
+
+    x1 = np.clip(x1, 0, w - 1)
+    y1 = np.clip(y1, 0, h - 1)
+
+    if x.ndim == 2:
+        x = interp2d(xv, yv, x)(x1, y1)
+    if x.ndim == 3:
+        for i in range(x.shape[-1]):
+            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+    return x
+
+
+def blur(x, k):
+    '''
+    x: image, NxcxHxW
+    k: kernel, Nx1xhxw
+    '''
+    n, c = x.shape[:2]
+    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+    k = k.repeat(1, c, 1, 1)
+    k = k.view(-1, 1, k.shape[2], k.shape[3])
+    x = x.view(1, -1, x.shape[2], x.shape[3])
+    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+    x = x.view(n, c, x.shape[2], x.shape[3])
+
+    return x
+
+
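+# Shape sketch for blur() (ours): it applies one kernel per batch element to all channels via a
+# grouped convolution, with replicate padding sized so that, for odd kernels, the output keeps
+# the input spatial size. The helper is never called at import time.
+def _blur_shape_demo():
+    x = torch.rand(2, 3, 32, 32)          # N x C x H x W
+    k = torch.ones(2, 1, 7, 7) / 49.0     # N x 1 x h x w, one box kernel per image
+    y = blur(x, k)
+    assert y.shape == (2, 3, 32, 32)
+    return y
+
+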
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+    """"
+    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+    # Kai Zhang
+    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
+    # max_var = 2.5 * sf
+    """
+    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+    theta = np.random.rand() * np.pi  # random theta
+    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+    # Set COV matrix using Lambdas and Theta
+    LAMBDA = np.diag([lambda_1, lambda_2])
+    Q = np.array([[np.cos(theta), -np.sin(theta)],
+                  [np.sin(theta), np.cos(theta)]])
+    SIGMA = Q @ LAMBDA @ Q.T
+    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+    # Set expectation position (shifting kernel for aligned image)
+    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
+    MU = MU[None, None, :, None]
+
+    # Create meshgrid for Gaussian
+    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+    Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate Gaussian for every pixel of the kernel
+    ZZ = Z - MU
+    ZZ_t = ZZ.transpose(0, 1, 3, 2)
+    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+    # shift the kernel so it will be centered
+    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+    # Normalize the kernel and return
+    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+    kernel = raw_kernel / np.sum(raw_kernel)
+    return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+    hsize = [hsize, hsize]
+    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+    std = sigma
+    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+    arg = -(x * x + y * y) / (2 * std * std)
+    h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0
+    sumh = h.sum()
+    if sumh != 0:
+        h = h / sumh
+    return h
+
+
+def fspecial_laplacian(alpha):
+    alpha = max([0, min([alpha, 1])])
+    h1 = alpha / (alpha + 1)
+    h2 = (1 - alpha) / (alpha + 1)
+    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+    h = np.array(h)
+    return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+    '''
+    python code from:
+    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+    '''
+    if filter_type == 'gaussian':
+        return fspecial_gaussian(*args, **kwargs)
+    if filter_type == 'laplacian':
+        return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+    '''
+    Args:
+        x: HxWxC image, [0, 1]
+        sf: down-scale factor
+    Return:
+        bicubically downsampled LR image
+    '''
+    x = util.imresize_np(x, scale=1 / sf)
+    return x
+
+
+def srmd_degradation(x, k, sf=3):
+    ''' blur + bicubic downsampling
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2018learning,
+          title={Learning a single convolutional super-resolution network for multiple degradations},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={3262--3271},
+          year={2018}
+        }
+    '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
+    x = bicubic_degradation(x, sf=sf)
+    return x
+
+
+def dpsr_degradation(x, k, sf=3):
+    ''' bicubic downsampling + blur
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2019deep,
+          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={1671--1681},
+          year={2019}
+        }
+    '''
+    x = bicubic_degradation(x, sf=sf)
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    return x
+
+
+def classical_degradation(x, k, sf=3):
+    ''' blur + downsampling
+    Args:
+        x: HxWxC image, [0, 1]/[0, 255]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+    st = 0
+    return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening. borrowed from real-ESRGAN
+    Input image: I; Blurry image: B.
+    1. K = I + weight * (I - B)
+    2. Mask = 1 if abs(I - B) > threshold, else: 0
+    3. Blur mask:
+    4. Out = Mask * K + (1 - Mask) * I
+    Args:
+        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 0.5.
+        radius (float): Kernel size of the Gaussian blur. Default: 50.
+        threshold (int): Residual threshold (on the 0-255 scale) used to build the sharpening mask. Default: 10.
+    """
+    if radius % 2 == 0:
+        radius += 1
+    blur = cv2.GaussianBlur(img, (radius, radius), 0)
+    residual = img - blur
+    mask = np.abs(residual) * 255 > threshold
+    mask = mask.astype('float32')
+    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+    K = img + weight * residual
+    K = np.clip(K, 0, 1)
+    return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+    wd2 = 4.0 + sf
+    wd = 2.0 + 0.2 * sf
+
+    wd2 = wd2/4
+    wd = wd/4
+
+    if random.random() < 0.5:
+        l1 = wd2 * random.random()
+        l2 = wd2 * random.random()
+        k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+    else:
+        k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+    return img
+
+
+def add_resize(img, sf=4):
+    rnum = np.random.rand()
+    if rnum > 0.8:  # up
+        sf1 = random.uniform(1, 2)
+    elif rnum < 0.7:  # down
+        sf1 = random.uniform(0.5 / sf, 1)
+    else:
+        sf1 = 1.0
+    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+    img = np.clip(img, 0.0, 1.0)
+
+    return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+#     noise_level = random.randint(noise_level1, noise_level2)
+#     rnum = np.random.rand()
+#     if rnum > 0.6:  # add color Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+#     elif rnum < 0.4:  # add grayscale Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+#     else:  # add  noise
+#         L = noise_level2 / 255.
+#         D = np.diag(np.random.rand(3))
+#         U = orth(np.random.rand(3, 3))
+#         conv = np.dot(np.dot(np.transpose(U), D), U)
+#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+#     img = np.clip(img, 0.0, 1.0)
+#     return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    rnum = np.random.rand()
+    if rnum > 0.6:  # add color Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:  # add grayscale Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:  # add  noise
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    img = np.clip(img, 0.0, 1.0)
+    rnum = random.random()
+    if rnum > 0.6:
+        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:
+        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_Poisson_noise(img):
+    img = np.clip((img * 255.0).round(), 0, 255) / 255.
+    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
+    if random.random() < 0.5:
+        img = np.random.poisson(img * vals).astype(np.float32) / vals
+    else:
+        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+        img += noise_gray[:, :, np.newaxis]
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_JPEG_noise(img):
+    quality_factor = random.randint(80, 95)
+    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+    img = cv2.imdecode(encimg, 1)
+    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+    return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+    h, w = lq.shape[:2]
+    rnd_h = random.randint(0, h - lq_patchsize)
+    rnd_w = random.randint(0, w - lq_patchsize)
+    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+    return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+    h, w = img.shape[:2]
+
+    if h < lq_patchsize * sf or w < lq_patchsize * sf:
+        raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+    hq = img.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+                             interpolation=random.choice([1, 2, 3]))
+        else:
+            img = util.imresize_np(img, 1 / 2, True)
+        img = np.clip(img, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            img = add_blur(img, sf=sf)
+
+        elif i == 1:
+            img = add_blur(img, sf=sf)
+
+        elif i == 2:
+            a, b = img.shape[1], img.shape[0]
+            # downsample2
+            if random.random() < 0.75:
+                sf1 = random.uniform(1, 2 * sf)
+                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+                                 interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                img = img[0::sf, 0::sf, ...]  # nearest downsampling
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                img = add_JPEG_noise(img)
+
+        elif i == 6:
+            # add processed camera sensor noise
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    img = add_JPEG_noise(img)
+
+    # random crop
+    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+    return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    image: HxWxC, uint8
+    sf: scale factor
+    isp_model: camera ISP model
+    up: if True, resize the degraded image back to the input resolution (bicubic)
+    Returns
+    -------
+    example: dict with key "image" holding the degraded low-quality image, HxWxC, uint8
+    """
+    image = util.uint2single(image)
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = image.shape[:2]
+    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+    h, w = image.shape[:2]
+
+    hq = image.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+                               interpolation=random.choice([1, 2, 3]))
+        else:
+            image = util.imresize_np(image, 1 / 2, True)
+        image = np.clip(image, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            image = add_blur(image, sf=sf)
+
+        # elif i == 1:
+        #     image = add_blur(image, sf=sf)
+
+        elif i == 2:
+            a, b = image.shape[1], image.shape[0]
+            # downsample2
+            if random.random() < 0.8:
+                sf1 = random.uniform(1, 2 * sf)
+                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+                                   interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                image = image[0::sf, 0::sf, ...]  # nearest downsampling
+
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                image = add_JPEG_noise(image)
+        #
+        # elif i == 6:
+        #     # add processed camera sensor noise
+        #     if random.random() < isp_prob and isp_model is not None:
+        #         with torch.no_grad():
+        #             img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    image = add_JPEG_noise(image)
+    image = util.single2uint(image)
+    if up:
+        image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC)  # todo: random, as above? want to condition on it then
+    example = {"image": image}
+    return example
+
+
+
+
+if __name__ == '__main__':
+    print("hey")
+    img = util.imread_uint('utils/test.png', 3)
+    img = img[:448, :448]
+    h = img.shape[0] // 4
+    print("resizing to", h)
+    sf = 4
+    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+    for i in range(20):
+        print(i)
+        img_hq = img
+        img_lq = deg_fn(img)["image"]
+        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+        print(img_lq)
+        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+        print(img_lq.shape)
+        print("bicubic", img_lq_bicubic.shape)
+        print(img_hq.shape)
+        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                interpolation=0)
+        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                        interpolation=0)
+        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+        util.imsave(img_concat, str(i) + '.png')
diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6
Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ
diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt   # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+    return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+    import matplotlib.pyplot as plt  # local import; the module-level import is commented out above
+    plt.figure(figsize=figsize)
+    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+    if title:
+        plt.title(title)
+    if cbar:
+        plt.colorbar()
+    plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+    import matplotlib.pyplot as plt  # local import; the module-level import is commented out above
+    plt.figure(figsize=figsize)
+    ax3 = plt.axes(projection='3d')
+
+    w, h = Z.shape[:2]
+    xx = np.arange(0,w,1)
+    yy = np.arange(0,h,1)
+    X, Y = np.meshgrid(xx, yy)
+    ax3.plot_surface(X,Y,Z,cmap=cmap)
+    #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+    plt.show()
+
+
+'''
+# --------------------------------------------
+# get image paths
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+    paths = None  # return None if dataroot is None
+    if dataroot is not None:
+        paths = sorted(_get_paths_from_images(dataroot))
+    return paths
+
+
+def _get_paths_from_images(path):
+    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+    images = []
+    for dirpath, _, fnames in sorted(os.walk(path)):
+        for fname in sorted(fnames):
+            if is_image_file(fname):
+                img_path = os.path.join(dirpath, fname)
+                images.append(img_path)
+    assert images, '{:s} has no valid image file'.format(path)
+    return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images 
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+    w, h = img.shape[:2]
+    patches = []
+    if w > p_max and h > p_max:
+        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int))
+        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))
+        w1.append(w-p_size)
+        h1.append(h-p_size)
+#        print(w1)
+#        print(h1)
+        for i in w1:
+            for j in h1:
+                patches.append(img[i:i+p_size, j:j+p_size,:])
+    else:
+        patches.append(img)
+
+    return patches
+
+
+def imssave(imgs, img_path):
+    """
+    imgs: list, N images of size WxHxC
+    """
+    img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+    for i, img in enumerate(imgs):
+        if img.ndim == 3:
+            img = img[:, :, [2, 1, 0]]
+        new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+        cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, target_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+    """
+    Split the large images from original_dataroot into small overlapped images of size (p_size)x(p_size)
+    and save them into target_dataroot; only images larger than (p_max)x(p_max) are split.
+    Args:
+        original_dataroot: directory containing the source images
+        target_dataroot: directory the patches are written to
+        p_size: size of the small images
+        p_overlap: overlap between adjacent patches (the training patch size is a good choice)
+        p_max: images smaller than (p_max)x(p_max) are kept unchanged.
+    """
+    paths = get_image_paths(original_dataroot)
+    for img_path in paths:
+        # img_name, ext = os.path.splitext(os.path.basename(img_path))
+        img = imread_uint(img_path, n_channels=n_channels)
+        patches = patches_from_image(img, p_size, p_overlap, p_max)
+        imssave(patches, os.path.join(target_dataroot, os.path.basename(img_path)))
+        #if original_dataroot == target_dataroot:
+        #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def mkdirs(paths):
+    if isinstance(paths, str):
+        mkdir(paths)
+    else:
+        for path in paths:
+            mkdir(path)
+
+
+def mkdir_and_rename(path):
+    if os.path.exists(path):
+        new_name = path + '_archived_' + get_timestamp()
+        print('Path already exists. Rename it to [{:s}]'.format(new_name))
+        os.rename(path, new_name)
+    os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+    #  input: path
+    # output: HxWx3(RGB or GGG), or HxWx1 (G)
+    if n_channels == 1:
+        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
+        img = np.expand_dims(img, axis=2)  # HxWx1
+    elif n_channels == 3:
+        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
+        if img.ndim == 2:
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
+        else:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
+    return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+    img = np.squeeze(img)
+    if img.ndim == 3:
+        img = img[:, :, [2, 1, 0]]
+    cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+    img = np.squeeze(img)
+    if img.ndim == 3:
+        img = img[:, :, [2, 1, 0]]
+    cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
+def read_img(path):
+    # read image by cv2
+    # return: Numpy float32, HWC, BGR, [0,1]
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    # some images have 4 channels
+    if img.shape[2] > 3:
+        img = img[:, :, :3]
+    return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <--->  numpy(uint)
+# numpy(single) <--->  tensor
+# numpy(uint)   <--->  tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <--->  numpy(uint)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+    return np.float32(img/255.)
+
+
+def single2uint(img):
+
+    return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+    return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+    return np.uint16((img.clip(0, 1)*65535.).round())
+
+
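+# Round-trip sketch (ours, for illustration): uint8 [0, 255] <-> float32 [0, 1]. single2uint clips
+# before scaling, so values pushed outside [0, 1] by the degradations are saturated rather than wrapped.
+def _uint_single_roundtrip_demo():
+    img_uint = np.array([[0, 128, 255]], dtype=np.uint8)
+    img_single = uint2single(img_uint)          # ~[0.0, 0.502, 1.0]
+    back = single2uint(img_single + 0.5)        # overflow is clipped to 255, not wrapped
+    assert back.dtype == np.uint8
+    return img_single, back
+
+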
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <--->  tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+    return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <--->  tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+    img = img.data.squeeze().float().cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+
+    return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+    img = img.data.squeeze().float().cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+    elif img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return img
+
+
+def single2tensor5(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+    '''
+    Converts a torch Tensor into an image Numpy array of BGR channel order
+    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+    '''
+    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
+    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
+    n_dim = tensor.dim()
+    if n_dim == 4:
+        n_img = len(tensor)
+        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+    elif n_dim == 3:
+        img_np = tensor.numpy()
+        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+    elif n_dim == 2:
+        img_np = tensor.numpy()
+    else:
+        raise TypeError(
+            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+    if out_type == np.uint8:
+        img_np = (img_np * 255.0).round()
+        # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+    return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return np.flipud(np.rot90(img))
+    elif mode == 2:
+        return np.flipud(img)
+    elif mode == 3:
+        return np.rot90(img, k=3)
+    elif mode == 4:
+        return np.flipud(np.rot90(img, k=2))
+    elif mode == 5:
+        return np.rot90(img)
+    elif mode == 6:
+        return np.rot90(img, k=2)
+    elif mode == 7:
+        return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return img.rot90(1, [2, 3]).flip([2])
+    elif mode == 2:
+        return img.flip([2])
+    elif mode == 3:
+        return img.rot90(3, [2, 3])
+    elif mode == 4:
+        return img.rot90(2, [2, 3]).flip([2])
+    elif mode == 5:
+        return img.rot90(1, [2, 3])
+    elif mode == 6:
+        return img.rot90(2, [2, 3])
+    elif mode == 7:
+        return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    img_size = img.size()
+    img_np = img.data.cpu().numpy()
+    if len(img_size) == 3:
+        img_np = np.transpose(img_np, (1, 2, 0))
+    elif len(img_size) == 4:
+        img_np = np.transpose(img_np, (2, 3, 1, 0))
+    img_np = augment_img(img_np, mode=mode)
+    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+    if len(img_size) == 3:
+        img_tensor = img_tensor.permute(2, 0, 1)
+    elif len(img_size) == 4:
+        img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+    return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return img.transpose(1, 0, 2)
+    elif mode == 2:
+        return img[::-1, :, :]
+    elif mode == 3:
+        img = img[::-1, :, :]
+        img = img.transpose(1, 0, 2)
+        return img
+    elif mode == 4:
+        return img[:, ::-1, :]
+    elif mode == 5:
+        img = img[:, ::-1, :]
+        img = img.transpose(1, 0, 2)
+        return img
+    elif mode == 6:
+        img = img[:, ::-1, :]
+        img = img[::-1, :, :]
+        return img
+    elif mode == 7:
+        img = img[:, ::-1, :]
+        img = img[::-1, :, :]
+        img = img.transpose(1, 0, 2)
+        return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+    # horizontal flip OR rotate
+    hflip = hflip and random.random() < 0.5
+    vflip = rot and random.random() < 0.5
+    rot90 = rot and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:
+            img = img[:, ::-1, :]
+        if vflip:
+            img = img[::-1, :, :]
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+    # img_in: Numpy, HWC or HW
+    img = np.copy(img_in)
+    if img.ndim == 2:
+        H, W = img.shape
+        H_r, W_r = H % scale, W % scale
+        img = img[:H - H_r, :W - W_r]
+    elif img.ndim == 3:
+        H, W, C = img.shape
+        H_r, W_r = H % scale, W % scale
+        img = img[:H - H_r, :W - W_r, :]
+    else:
+        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+    return img
+
+
+def shave(img_in, border=0):
+    # img_in: Numpy, HWC or HW
+    img = np.copy(img_in)
+    h, w = img.shape[:2]
+    img = img[border:h-border, border:w-border]
+    return img
+
+
+'''
+# --------------------------------------------
+# image processing routines on numpy images
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+    '''same as matlab rgb2ycbcr
+    only_y: only return Y channel
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    if only_y:
+        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+    else:
+        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
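+
+# Illustrative example: the input dtype convention is preserved, so a uint8 RGB
+# image yields a uint8 Y channel in roughly [16, 235], while a float image in
+# [0, 1] yields a float Y channel in [16/255, 235/255]:
+#
+#   y = rgb2ycbcr(img_uint8, only_y=True)   # img_uint8: (H, W, 3) -> y: (H, W)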
+
+
+def ycbcr2rgb(img):
+    '''same as matlab ycbcr2rgb
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+    '''bgr version of rgb2ycbcr
+    only_y: only return Y channel
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    if only_y:
+        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+    else:
+        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+    # conversion among BGR, gray and y
+    if in_c == 3 and tar_type == 'gray':  # BGR to gray
+        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+        return [np.expand_dims(img, axis=2) for img in gray_list]
+    elif in_c == 3 and tar_type == 'y':  # BGR to y
+        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+        return [np.expand_dims(img, axis=2) for img in y_list]
+    elif in_c == 1 and tar_type == 'RGB':  # gray/y to 3-channel (GRAY2BGR replicates the single channel)
+        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+    else:
+        return img_list
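+
+# Illustrative usage sketch: convert a list of BGR patches (as loaded by cv2)
+# to single-channel Y images before computing Y-channel metrics (`patches` is
+# an assumed list of HxWx3 arrays):
+#
+#   y_patches = channel_convert(in_c=3, tar_type='y', img_list=patches)
+#   # each element now has shape (H, W, 1)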
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+    # img1 and img2 have range [0, 255]
+    #img1 = img1.squeeze()
+    #img2 = img2.squeeze()
+    if not img1.shape == img2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+    h, w = img1.shape[:2]
+    img1 = img1[border:h-border, border:w-border]
+    img2 = img2[border:h-border, border:w-border]
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    mse = np.mean((img1 - img2)**2)
+    if mse == 0:
+        return float('inf')
+    return 20 * math.log10(255.0 / math.sqrt(mse))
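+
+# Illustrative usage sketch (`sr` and `gt` are assumed images in [0, 255] with
+# identical shape); a border equal to the upscaling factor is commonly excluded:
+#
+#   psnr = calculate_psnr(sr, gt, border=4)   # in dB, higher is better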
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+    '''calculate SSIM
+    the same outputs as MATLAB's
+    img1, img2: [0, 255]
+    '''
+    #img1 = img1.squeeze()
+    #img2 = img2.squeeze()
+    if not img1.shape == img2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+    h, w = img1.shape[:2]
+    img1 = img1[border:h-border, border:w-border]
+    img2 = img2[border:h-border, border:w-border]
+
+    if img1.ndim == 2:
+        return ssim(img1, img2)
+    elif img1.ndim == 3:
+        if img1.shape[2] == 3:
+            ssims = []
+            for i in range(3):
+                ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+            return np.array(ssims).mean()
+        elif img1.shape[2] == 1:
+            return ssim(np.squeeze(img1), np.squeeze(img2))
+    else:
+        raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+    C1 = (0.01 * 255)**2
+    C2 = (0.03 * 255)**2
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                                            (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
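+
+# Illustrative usage sketch: SSIM is computed per channel and averaged for
+# 3-channel inputs (`sr` and `gt` are assumed [0, 255] images of equal shape):
+#
+#   score = calculate_ssim(sr, gt, border=4)  # 1.0 means identical images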
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+    absx = torch.abs(x)
+    absx2 = absx**2
+    absx3 = absx**3
+    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
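+
+# Note: this is the Keys cubic convolution kernel with a = -0.5 (the kernel used
+# by MATLAB's bicubic imresize); cubic(0) = 1, cubic(±1) = cubic(±2) = 0, and the
+# kernel vanishes outside |x| <= 2.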
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+    if (scale < 1) and (antialiasing):
+        # Use a modified kernel to simultaneously interpolate and antialias: larger kernel width
+        kernel_width = kernel_width / scale
+
+    # Output-space coordinates
+    x = torch.linspace(1, out_length, out_length)
+
+    # Input-space coordinates. Calculate the inverse mapping such that 0.5
+    # in output space maps to 0.5 in input space, and 0.5+scale in output
+    # space maps to 1.5 in input space.
+    u = x / scale + 0.5 * (1 - 1 / scale)
+
+    # What is the left-most pixel that can be involved in the computation?
+    left = torch.floor(u - kernel_width / 2)
+
+    # What is the maximum number of pixels that can be involved in the
+    # computation?  Note: it's OK to use an extra pixel here; if the
+    # corresponding weights are all zero, it will be eliminated at the end
+    # of this function.
+    P = math.ceil(kernel_width) + 2
+
+    # The indices of the input pixels involved in computing the k-th output
+    # pixel are in row k of the indices matrix.
+    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+        1, P).expand(out_length, P)
+
+    # The weights used to compute the k-th output pixel are in row k of the
+    # weights matrix.
+    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+    # apply cubic kernel
+    if (scale < 1) and (antialiasing):
+        weights = scale * cubic(distance_to_center * scale)
+    else:
+        weights = cubic(distance_to_center)
+    # Normalize the weights matrix so that each row sums to 1.
+    weights_sum = torch.sum(weights, 1).view(out_length, 1)
+    weights = weights / weights_sum.expand(out_length, P)
+
+    # If a column in weights is all zero, get rid of it. only consider the first and last column.
+    weights_zero_tmp = torch.sum((weights == 0), 0)
+    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 1, P - 2)
+        weights = weights.narrow(1, 1, P - 2)
+    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 0, P - 2)
+        weights = weights.narrow(1, 0, P - 2)
+    weights = weights.contiguous()
+    indices = indices.contiguous()
+    sym_len_s = -indices.min() + 1
+    sym_len_e = indices.max() - in_length
+    indices = indices + sym_len_s - 1
+    return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+    # Now the scale should be the same for H and W
+    # input: img: pytorch tensor, CHW or HW [0,1]
+    # output: CHW or HW [0,1] w/o round
+    need_squeeze = True if img.dim() == 2 else False
+    if need_squeeze:
+        img.unsqueeze_(0)
+    in_C, in_H, in_W = img.size()
+    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+    kernel_width = 4
+    kernel = 'cubic'
+
+    # Return the desired dimension order for performing the resize.  The
+    # strategy is to perform the resize first along the dimension with the
+    # smallest scale factor.
+    # Now we do not support this.
+
+    # get weights and indices
+    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+        in_H, out_H, scale, kernel, kernel_width, antialiasing)
+    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+        in_W, out_W, scale, kernel, kernel_width, antialiasing)
+    # process H dimension
+    # symmetric copying
+    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+    sym_patch = img[:, :sym_len_Hs, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+    sym_patch = img[:, -sym_len_He:, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+    out_1 = torch.FloatTensor(in_C, out_H, in_W)
+    kernel_width = weights_H.size(1)
+    for i in range(out_H):
+        idx = int(indices_H[i][0])
+        for j in range(out_C):
+            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+    # process W dimension
+    # symmetric copying
+    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+    sym_patch = out_1[:, :, :sym_len_Ws]
+    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(2, inv_idx)
+    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+    sym_patch = out_1[:, :, -sym_len_We:]
+    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(2, inv_idx)
+    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+    out_2 = torch.FloatTensor(in_C, out_H, out_W)
+    kernel_width = weights_W.size(1)
+    for i in range(out_W):
+        idx = int(indices_W[i][0])
+        for j in range(out_C):
+            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+    if need_squeeze:
+        out_2.squeeze_()
+    return out_2
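+
+# Illustrative usage sketch (`x` is an assumed CHW float tensor in [0, 1]):
+#
+#   lr = imresize(x, 1/4)   # antialiased bicubic x4 downscale
+#   hr = imresize(lr, 4.0)  # bicubic x4 upscale
+#
+# The function allocates new CPU FloatTensors, so it is intended for data
+# preparation rather than for use inside an autograd graph.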
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+    # Now the scale should be the same for H and W
+    # input: img: Numpy, HWC or HW [0,1]
+    # output: HWC or HW [0,1] w/o round
+    img = torch.from_numpy(img)
+    need_squeeze = True if img.dim() == 2 else False
+    if need_squeeze:
+        img.unsqueeze_(2)
+
+    in_H, in_W, in_C = img.size()
+    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+    kernel_width = 4
+    kernel = 'cubic'
+
+    # Return the desired dimension order for performing the resize.  The
+    # strategy is to perform the resize first along the dimension with the
+    # smallest scale factor.
+    # Now we do not support this.
+
+    # get weights and indices
+    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+        in_H, out_H, scale, kernel, kernel_width, antialiasing)
+    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+        in_W, out_W, scale, kernel, kernel_width, antialiasing)
+    # process H dimension
+    # symmetric copying
+    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+    sym_patch = img[:sym_len_Hs, :, :]
+    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(0, inv_idx)
+    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+    sym_patch = img[-sym_len_He:, :, :]
+    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(0, inv_idx)
+    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+    out_1 = torch.FloatTensor(out_H, in_W, in_C)
+    kernel_width = weights_H.size(1)
+    for i in range(out_H):
+        idx = int(indices_H[i][0])
+        for j in range(out_C):
+            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+    # process W dimension
+    # symmetric copying
+    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+    sym_patch = out_1[:, :sym_len_Ws, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+    sym_patch = out_1[:, -sym_len_We:, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+    out_2 = torch.FloatTensor(out_H, out_W, in_C)
+    kernel_width = weights_W.size(1)
+    for i in range(out_W):
+        idx = int(indices_W[i][0])
+        for j in range(out_C):
+            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+    if need_squeeze:
+        out_2.squeeze_()
+
+    return out_2.numpy()
+
+
+if __name__ == '__main__':
+    print('---')
+#    img = imread_uint('test.bmp', 3)
+#    img = uint2single(img)
+#    img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/ldm/util.py b/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..839beb87df984382121831f935e1162fecb3839f
--- /dev/null
+++ b/ldm/util.py
@@ -0,0 +1,197 @@
+import importlib
+
+import torch
+from torch import optim
+import numpy as np
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+    # wh a tuple of (width, height)
+    # xc a list of captions to plot
+    b = len(xc)
+    txts = list()
+    for bi in range(b):
+        txt = Image.new("RGB", wh, color="white")
+        draw = ImageDraw.Draw(txt)
+        font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
+        nc = int(25 * (wh[0] / 256))
+        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+        try:
+            draw.text((0, 0), lines, fill="black", font=font)
+        except UnicodeEncodeError:
+            print("Cant encode string for logging. Skipping.")
+
+        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+        txts.append(txt)
+    txts = np.stack(txts)
+    txts = torch.tensor(txts)
+    return txts
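+
+# Illustrative usage sketch: render a batch of prompts into a (B, 3, H, W)
+# tensor in [-1, 1] for the image logger (font/DejaVuSans.ttf must exist
+# relative to the working directory):
+#
+#   captions = ["a photo of a cat", "a photo of a dog"]
+#   grid = log_txt_as_img((512, 512), captions, size=16)  # shape (2, 3, 512, 512)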
+
+
+def ismap(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+    if not isinstance(x,torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+    """
+    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+    total_params = sum(p.numel() for p in model.parameters())
+    if verbose:
+        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+    return total_params
+
+
+def instantiate_from_config(config):
+    if not "target" in config:
+        if config == '__is_first_stage__':
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
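+
+# Illustrative example: `instantiate_from_config` expects a dict (or OmegaConf
+# node) with a dotted `target` path and optional `params`, e.g.
+#
+#   cfg = {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 8}}
+#   layer = instantiate_from_config(cfg)       # equivalent to torch.nn.Linear(4, 8)
+#   cls = get_obj_from_str("torch.nn.Linear")  # resolves the class without instantiating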
+
+
+class AdamWwithEMAandWings(optim.Optimizer):
+    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
+                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,   # ema decay to match previous code
+                 ema_power=1., param_names=()):
+        """AdamW that saves EMA versions of the parameters."""
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= ema_decay <= 1.0:
+            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+                        ema_power=ema_power, param_names=param_names)
+        super().__init__(params, defaults)
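+
+    # Illustrative usage sketch (`model` is hypothetical): the optimizer keeps an
+    # exponential moving average of every parameter in
+    # self.state[p]['param_exp_avg'], which callers can swap in at evaluation time.
+    #
+    #   opt = AdamWwithEMAandWings(model.parameters(), lr=1e-4, ema_decay=0.9999)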
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            ema_params_with_grad = []
+            state_sums = []
+            max_exp_avg_sqs = []
+            state_steps = []
+            amsgrad = group['amsgrad']
+            beta1, beta2 = group['betas']
+            ema_decay = group['ema_decay']
+            ema_power = group['ema_power']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
+                    raise RuntimeError('AdamW does not support sparse gradients')
+                grads.append(p.grad)
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of parameter values
+                    state['param_exp_avg'] = p.detach().float().clone()
+
+                exp_avgs.append(state['exp_avg'])
+                exp_avg_sqs.append(state['exp_avg_sq'])
+                ema_params_with_grad.append(state['param_exp_avg'])
+
+                if amsgrad:
+                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+                # update the steps for each param group update
+                state['step'] += 1
+                # record the step after step update
+                state_steps.append(state['step'])
+
+            optim._functional.adamw(params_with_grad,
+                    grads,
+                    exp_avgs,
+                    exp_avg_sqs,
+                    max_exp_avg_sqs,
+                    state_steps,
+                    amsgrad=amsgrad,
+                    beta1=beta1,
+                    beta2=beta2,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    eps=group['eps'],
+                    maximize=False)
+
+            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+        return loss
diff --git a/models/anycontrol.py b/models/anycontrol.py
new file mode 100644
index 0000000000000000000000000000000000000000..76764c6a0504bd9aff7d249584b83f2022f76490
--- /dev/null
+++ b/models/anycontrol.py
@@ -0,0 +1,273 @@
+import einops
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import cv2
+
+from PIL import Image
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.util import log_txt_as_img, instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+
+from models.q_formers import load_qformer_model
+
+
+class AnyControlNet(LatentDiffusion):
+
+    def __init__(self, mode, qformer_config=None, local_control_config=None, global_control_config=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert mode in ['local', 'uni']
+        self.mode = mode
+        self.qformer_config = qformer_config
+        self.local_control_config = local_control_config
+        self.global_control_config = global_control_config
+
+        self.model.diffusion_model.requires_grad_(False)
+
+        q_former, (vis_processor, txt_processor) = load_qformer_model(qformer_config)
+        self.q_former = q_former
+        self.qformer_vis_processor = vis_processor
+        self.qformer_txt_processor = txt_processor
+
+        self.local_adapter = instantiate_from_config(local_control_config)
+        self.local_control_scales = [1.0] * 13
+        self.global_adapter = instantiate_from_config(global_control_config) if self.mode == 'uni' else None
+        self.clip_embeddings_dim = global_control_config.params.clip_embeddings_dim
+        self.color_in_dim = global_control_config.params.color_in_dim
+
+    @torch.no_grad()
+    def get_input(self, batch, k, bs=None, *args, **kwargs):
+        # latent and text 
+        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
+        bs = bs or x.size(0)
+        shape = self.get_shape(batch, bs)
+        
+        local_control = self.get_local_conditions_for_vision_encoder(batch, bs)
+        local_control = local_control.to(memory_format=torch.contiguous_format).float()
+
+        global_control = {}
+        global_conditions = batch['global_conditions'][:bs]
+        for key in batch['global_conditions'][0].data.keys():
+            global_cond = torch.stack([torch.Tensor(dc.data[key]) for dc in global_conditions])
+            global_cond = global_cond.to(self.device).to(memory_format=torch.contiguous_format).float()
+            global_control[key] = global_cond
+
+        conditions = dict(
+            text=[batch['txt']], 
+            c_crossattn=[c], 
+            local_control=[local_control], 
+            global_control=[global_control],
+        ) 
+        return x, conditions 
+
+    def apply_model(self, x_noisy, t, cond, local_strength=1.0, content_strength=1.0, color_strength=1.0, *args, **kwargs):
+        assert isinstance(cond, dict)
+        diffusion_model = self.model.diffusion_model
+        cond_txt = torch.cat(cond['c_crossattn'], 1)
+        text = cond['text'][0]
+        bs = x_noisy.shape[0]
+
+        # extract global control
+        if self.mode in ['uni']:
+            content_control, color_control = self.global_adapter(
+                cond['global_control'][0]['clipembedding'], cond['global_control'][0]['color'])
+        else:
+            content_control = torch.zeros(bs, self.clip_embeddings_dim).to(self.device).to(memory_format=torch.contiguous_format).float()
+            color_control = torch.zeros(bs, self.color_in_dim).to(self.device).to(memory_format=torch.contiguous_format).float()
+
+        # extract local control
+        if self.mode in ['local', 'uni']:
+            local_features = self.local_adapter.extract_local_features(self.q_former, text, cond['local_control'][0])
+            local_control = self.local_adapter(x=x_noisy, timesteps=t, context=cond_txt, local_features=local_features)
+            local_control = [c * scale for c, scale in zip(local_control, self.local_control_scales)]
+
+        eps = diffusion_model(
+            x=x_noisy, timesteps=t, context=cond_txt, 
+            local_control=local_control, local_w=local_strength,
+            content_control=content_control, extra_w=content_strength, 
+            color_control=color_control, color_w=color_strength)
+        return eps
+
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, N):
+        return self.get_learned_conditioning([""] * N)
+
+    @torch.no_grad()
+    def get_unconditional_global_conditioning(self, c):
+        if isinstance(c, dict):
+            return {k:torch.zeros_like(v) for k,v in c.items()}
+        elif isinstance(c, list):
+            return [torch.zeros_like(v) for v in c]
+        else:
+            return torch.zeros_like(c) 
+
+    @torch.no_grad()
+    def get_shape(self, batch, N):
+        return [dc.data[0].shape[:2] for dc in batch['local_conditions'][:N]]
+
+    @torch.no_grad()
+    def get_local_conditions_for_vision_encoder(self, batch, N):
+        # return: local_conditions, shape (bs, max_num_conds * 3, h, w); shorter condition lists are zero-padded
+        local_conditions = []
+        max_len = max([len(dc.data) for dc in batch['local_conditions'][:N]])
+        for dc in batch['local_conditions'][:N]:
+            conds = torch.cat([self.qformer_vis_processor['eval'](Image.fromarray(img)).unsqueeze(0) for img in dc.data], dim=1)
+            local_conditions.append(conds)
+        local_conditions = [F.pad(cond, (0,0,0,0,0,max_len*3-cond.shape[1],0,0)) for cond in local_conditions]
+        local_conditions = torch.cat(local_conditions, dim=0).to(self.device) 
+        return local_conditions
+
+    @torch.no_grad()
+    def get_local_conditions_for_logging(self, batch, N):
+        local_conditions = []
+        max_len = max([len(dc.data) for dc in batch['local_conditions'][:N]])
+        for dc in batch['local_conditions'][:N]:
+            conds = torch.stack([torch.Tensor(img).permute(2,0,1) for img in dc.data], dim=0) # (n, c, h, w)
+            conds = conds.float() / 255.
+            conds = conds * 2.0 - 1.0
+            local_conditions.append(conds)
+        local_conditions = [F.pad(cond, (0,0,0,0,0,0,0,max_len-cond.shape[0])) for cond in local_conditions]
+        local_conditions = torch.stack(local_conditions, dim=0).to(self.device) # (bs, n, c, h, w)
+        local_conditions = local_conditions.flatten(1,2)
+        return local_conditions
+
+    def clip_batch(self, batch, key, N, flag=True):
+        if isinstance(batch, torch.Tensor):
+            return batch[:N] 
+        elif isinstance(batch, list):
+            return batch[:N] 
+        batch = batch[key][0] if flag else batch[key]
+        if isinstance(batch, torch.Tensor):
+            return batch[:N] 
+        elif isinstance(batch, list):
+            return batch[:N] 
+        elif isinstance(batch, dict):
+            return {k:self.clip_batch(v,'',N,flag=False) for k,v in batch.items()}
+        else:
+            raise ValueError(f'Unsupported type {type(batch)}')
+
+    @torch.no_grad()
+    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, 
+                   plot_denoise_rows=False, plot_diffusion_rows=False, unconditional_guidance_scale=9.0, **kwargs):
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c = self.get_input(batch, self.first_stage_key, bs=N)
+
+        shape = self.get_shape(batch, N)
+        c_local = self.clip_batch(c, "local_control", N)
+        c_global = self.clip_batch(c, "global_control", N)
+        c_context = self.clip_batch(c, "c_crossattn", N)
+        c_text = self.clip_batch(batch, self.cond_stage_key, N, False)
+        
+        N = min(z.shape[0], N)
+        n_row = min(z.shape[0], n_row)
+        log["reconstruction"] = self.decode_first_stage(z)
+        log["conditioning"] = log_txt_as_img((512, 512), c_text, size=16)
+        log["local_control"] = self.get_local_conditions_for_logging(batch, N)
+
+        if plot_diffusion_rows:
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        cond_dict = dict(
+            local_control=[c_local],
+            global_control=[c_global],
+            c_crossattn=[c_context],
+            text=[c_text],
+            shape=[shape],
+        )
+
+        if sample:
+            samples, z_denoise_row = self.sample_log(cond=cond_dict,
+                                                     batch_size=N, ddim=use_ddim,
+                                                     ddim_steps=ddim_steps, eta=ddim_eta,
+                                                     log_every_t=self.log_every_t * 0.05)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                if isinstance(z_denoise_row, dict):
+                    for key in ['pred_x0', 'x_inter']:
+                       z_denoise_row_key = z_denoise_row[key]
+                       denoise_grid = self._get_denoise_row_from_list(z_denoise_row_key)
+                       log[f"denoise_row_{key}"] = denoise_grid
+                else:
+                    denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                    log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_context = self.get_unconditional_conditioning(N)
+            uc_global = self.get_unconditional_global_conditioning(c_global)
+            uc_local = c_local 
+            uc_text = c_text 
+
+            uncond_dict = dict(
+                local_control=[uc_local],
+                global_control=[uc_global],
+                c_crossattn=[uc_context],
+                text=[uc_text],
+                shape=[shape]
+            )
+
+            samples_cfg, _ = self.sample_log(cond=cond_dict,
+                                             batch_size=N, ddim=use_ddim,
+                                             ddim_steps=ddim_steps, eta=ddim_eta,
+                                             unconditional_guidance_scale=unconditional_guidance_scale,
+                                             unconditional_conditioning=uncond_dict,
+                                             )
+            x_samples_cfg = self.decode_first_stage(samples_cfg)
+            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        return log
+
+    @torch.no_grad()
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+        ddim_sampler = DDIMSampler(self)
+        if cond['shape'] is None:
+            h, w = 512, 512
+        else:
+            h, w = cond["shape"][0][0]
+        shape = (self.channels, h // 8, w // 8)
+        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
+        return samples, intermediates
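+
+    # Note: sampling runs in latent space; the spatial size is divided by 8
+    # (e.g. a 512x512 image corresponds to a (self.channels, 64, 64) latent),
+    # matching the first-stage autoencoder's downsampling factor.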
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.q_former.parameters()) + list(self.local_adapter.parameters())
+        if not self.sd_locked:
+            params += list(self.model.diffusion_model.output_blocks.parameters())
+            params += list(self.model.diffusion_model.out.parameters())
+
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+    def low_vram_shift(self, is_diffusing):
+        if is_diffusing:
+            self.model = self.model.cuda()
+            self.local_adapter = self.local_adapter.cuda()
+            self.first_stage_model = self.first_stage_model.cpu()
+            self.cond_stage_model = self.cond_stage_model.cpu()
+        else:
+            self.model = self.model.cpu()
+            self.local_adapter = self.local_adapter.cpu()
+            self.first_stage_model = self.first_stage_model.cuda()
+            self.cond_stage_model = self.cond_stage_model.cuda()
diff --git a/models/ddim_hacked.py b/models/ddim_hacked.py
new file mode 100644
index 0000000000000000000000000000000000000000..2503b96b93f884231749edd41940d25983572083
--- /dev/null
+++ b/models/ddim_hacked.py
@@ -0,0 +1,322 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta,verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               ucg_schedule=None,
+               global_strength=1,
+               color_strength=1,
+               local_strength=1,
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+        samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    ucg_schedule=ucg_schedule,
+                                                    global_strength=global_strength,
+                                                    color_strength=color_strength,
+                                                    local_strength=local_strength
+                                                    )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def ddim_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                      ucg_schedule=None,global_strength=1,color_strength=1,local_strength=1):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold,global_strength=global_strength,color_strength=color_strength,local_strength=local_strength)
+            img, pred_x0 = outs
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None,global_strength=1,color_strength=1,local_strength=1):
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            model_t = self.model.apply_model(x, t, c, global_strength, color_strength, local_strength)
+            model_uncond = self.model.apply_model(x, t, unconditional_conditioning, global_strength, color_strength, local_strength)
+            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", 'not implemented'
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc='Encoding Image'):
+            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+            if unconditional_guidance_scale == 1.:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                           torch.cat((unconditional_conditioning, c))), 2)
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = alphas_next[i].sqrt() * (
+                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+            x_next = xt_weighted + weighted_noise_pred
+            if return_intermediates and i % (
+                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback: callback(i)
+
+        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+        if return_intermediates:
+            out.update({'intermediates': intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
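+
+    # Note: this is the closed-form forward process
+    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
+    # i.e. x0 is noised directly to step t in a single draw instead of iterating
+    # the Markov chain.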
+
+    @torch.no_grad()
+    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+               use_original_steps=False, callback=None):
+
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                          unconditional_guidance_scale=unconditional_guidance_scale,
+                                          unconditional_conditioning=unconditional_conditioning)
+            if callback: callback(i)
+        return x_dec
diff --git a/models/global_adapter.py b/models/global_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a999a76c19aeb4ac2bd76f8caada37b0980940f8
--- /dev/null
+++ b/models/global_adapter.py
@@ -0,0 +1,75 @@
+import torch
+from torch import nn
+from einops import rearrange
+
+
+class FourierEmbedder(object):
+    def __init__(self, num_freqs=64, temperature=100):
+        self.num_freqs = num_freqs
+        self.temperature = temperature
+        self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)  
+
+    @torch.no_grad()
+    def __call__(self, x, cat_dim=-1):
+        "x: arbitrary shape of tensor. dim: cat dim"
+        out = []
+        for freq in self.freq_bands:
+            out.append(torch.sin(freq*x))
+            out.append(torch.cos(freq*x))
+        return torch.cat(out, cat_dim)
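+
+    # Illustrative shape note: for an input of shape (B, D) and num_freqs=N the
+    # output has shape (B, D * 2 * N), since sin and cos features are appended
+    # for every frequency band before concatenation along cat_dim.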
+
+
+class FourierColorEmbedder(nn.Module):
+    def __init__(self, in_dim=180, out_dim=768, num_tokens=4, fourier_freqs=4, temperature=100, scale=100):
+        super().__init__()
+        self.in_dim = in_dim  
+        self.out_dim = out_dim
+        self.fourier_freqs = fourier_freqs
+        self.num_tokens = num_tokens
+
+        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs, temperature=temperature)
+        self.in_dim *= (fourier_freqs * 2)
+        self.mlp = nn.Sequential(
+            nn.Linear(self.in_dim, 512),
+            nn.LayerNorm(512),
+            nn.SiLU(),
+            nn.Linear(512, 512),
+            nn.LayerNorm(512),
+            nn.SiLU(),
+            nn.Linear(512, out_dim*self.num_tokens),
+        )
+
+        self.null_features = torch.nn.Parameter(torch.zeros([self.in_dim]))
+        self.scale = scale
+
+    def forward(self, x, mask=None):
+        if x.ndim == 3:
+            assert x.size(1) == 1
+            x = x.squeeze(1)
+        bs = x.shape[0]
+        if mask is None:
+            mask = torch.ones(bs, 1, device=x.device)
+        x = self.fourier_embedder(x * self.scale) 
+        x = mask * x + (1-mask) * self.null_features.view(1,-1)
+        x = self.mlp(x).view(bs, self.num_tokens, self.out_dim)  # B*1*C
+        return x
+
+
+class GlobalAdapter(nn.Module):
+
+    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, context_tokens=4,
+            color_in_dim=180, color_num_tokens=4, color_fourier_freqs=4, color_temperature=100, color_scale=100):
+
+        super().__init__()
+
+        self.cross_attention_dim = cross_attention_dim
+        self.context_tokens = context_tokens
+        self.proj = torch.nn.Linear(clip_embeddings_dim, self.context_tokens * cross_attention_dim)
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+        self.color_embed = FourierColorEmbedder(color_in_dim, cross_attention_dim, color_num_tokens, color_fourier_freqs, color_temperature, color_scale)
+
+    def forward(self, x, x_color, *args, **kwargs):
+        context_tokens = self.proj(x).reshape(-1, self.context_tokens, self.cross_attention_dim)
+        context_tokens = self.norm(context_tokens)
+        color_tokens = self.color_embed(x_color)
+        return context_tokens, color_tokens
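+
+    # Illustrative shape note (default hyper-parameters): x of shape
+    # (B, clip_embeddings_dim=1024) maps to context_tokens of shape (B, 4, 1024),
+    # and x_color of shape (B, 180) maps to color_tokens of shape (B, 4, 1024)
+    # via the Fourier color embedder above.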
diff --git a/models/hack.py b/models/hack.py
new file mode 100644
index 0000000000000000000000000000000000000000..454361e9d036cd1a6a79122c2fd16b489e4767b1
--- /dev/null
+++ b/models/hack.py
@@ -0,0 +1,111 @@
+import torch
+import einops
+
+import ldm.modules.encoders.modules
+import ldm.modules.attention
+
+from transformers import logging
+from ldm.modules.attention import default
+
+
+def disable_verbosity():
+    logging.set_verbosity_error()
+    print('logging improved.')
+    return
+
+
+def enable_sliced_attention():
+    ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
+    print('Enabled sliced_attention.')
+    return
+
+
+def hack_everything(clip_skip=0):
+    disable_verbosity()
+    ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
+    ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
+    print('Enabled clip hacks.')
+    return
+
+
+# Written by Lvmin
+def _hacked_clip_forward(self, text):
+    PAD = self.tokenizer.pad_token_id
+    EOS = self.tokenizer.eos_token_id
+    BOS = self.tokenizer.bos_token_id
+
+    def tokenize(t):
+        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
+
+    def transformer_encode(t):
+        if self.clip_skip > 1:
+            rt = self.transformer(input_ids=t, output_hidden_states=True)
+            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
+        else:
+            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
+
+    def split(x):
+        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
+
+    def pad(x, p, i):
+        return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+    raw_tokens_list = tokenize(text)
+    tokens_list = []
+
+    for raw_tokens in raw_tokens_list:
+        raw_tokens_123 = split(raw_tokens)
+        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
+        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
+        tokens_list.append(raw_tokens_123)
+
+    tokens_list = torch.IntTensor(tokens_list).to(self.device)
+
+    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
+    y = transformer_encode(feed)
+    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
+
+    return z
+
+
+# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
+def _hacked_sliced_attention_forward(self, x, context=None, mask=None):
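+    # Compute attention one (batch*heads) slice at a time and write the results
+    # into a preallocated buffer, trading speed for lower peak memory.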
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+    k = self.to_k(context)
+    v = self.to_v(context)
+    del context, x
+
+    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    limit = k.shape[0]
+    att_step = 1
+    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
+    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
+    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
+
+    q_chunks.reverse()
+    k_chunks.reverse()
+    v_chunks.reverse()
+    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+    del k, q, v
+    for i in range(0, limit, att_step):
+        q_buffer = q_chunks.pop()
+        k_buffer = k_chunks.pop()
+        v_buffer = v_chunks.pop()
+        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+        del k_buffer, q_buffer
+        # attention, what we cannot get enough of, by chunks
+
+        sim_buffer = sim_buffer.softmax(dim=-1)
+
+        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+        del v_buffer
+        sim[i:i + att_step, :, :] = sim_buffer
+
+        del sim_buffer
+    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(sim)
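+
+
+# Illustrative call order (how the rest of the repo wires these up is assumed):
+#   disable_verbosity()          # silence transformers warnings
+#   enable_sliced_attention()    # optional: lower peak attention memory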
diff --git a/models/local_adapter.py b/models/local_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..854701ba833343b543d18b595794de0b1b52e485
--- /dev/null
+++ b/models/local_adapter.py
@@ -0,0 +1,443 @@
+import torch
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+    checkpoint,
+    conv_nd,
+    linear,
+    zero_module,
+    timestep_embedding,
+)
+from ldm.modules.diffusionmodules.openaimodel import (
+    UNetModel, 
+    TimestepBlock, 
+    TimestepEmbedSequential, 
+    ResBlock, 
+    Downsample, 
+    AttentionBlock
+)
+from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
+
+
+def layer_norm(tensor, drop=0.5, eps=1e-6):
+    # Normalize over dimensions 1 and 2 (`drop` is unused and kept only for signature
+    # compatibility). This helper is not referenced elsewhere in this module.
+    mean = tensor.mean(dim=(1, 2), keepdim=True)
+    var = tensor.var(dim=(1, 2), keepdim=True)
+    return (tensor - mean) / (var + eps) ** 0.5
+
+
+class LocalTimestepEmbedSequential(nn.Sequential, TimestepBlock):
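+    # Dispatch the extra conditioning arguments to the layers that accept them:
+    # timestep embeddings to TimestepBlocks, text/content/color control to
+    # SpatialTransformers, and local control features to LocalResBlocks.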
+    def forward(self, x, emb, context=None, local_control=None, content_control=None, color_control=None, content_w=1.0, color_w=1.0):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb)
+            elif isinstance(layer, SpatialTransformer):
+                x = layer(x, context, content_control, color_control, content_w, color_w)
+            elif isinstance(layer, LocalResBlock):
+                x = layer(x, emb, local_control)
+            else:
+                x = layer(x)
+        return x
+
+
+class FDN(nn.Module):
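+    """SPADE-style feature denormalization: a parameter-free GroupNorm whose
+    per-pixel scale (gamma) and shift (beta) are predicted from local condition features."""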
+    def __init__(self, norm_nc, label_nc):
+        super().__init__()
+        ks = 3
+        pw = ks // 2
+        self.param_free_norm = nn.GroupNorm(32, norm_nc, affine=False)
+        self.conv_gamma = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw)
+        self.conv_beta = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw)
+
+    def forward(self, x, local_features):
+        normalized = self.param_free_norm(x)
+        assert local_features.size()[2:] == x.size()[2:]
+        gamma = self.conv_gamma(local_features)
+        beta = self.conv_beta(local_features)
+        out = normalized * (1 + gamma) + beta
+        return out
+
+
+class LocalResBlock(nn.Module):
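+    """Residual block whose input/output normalizations are FDN layers, so the block
+    is modulated by spatial local-control features in addition to the timestep embedding."""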
+    def __init__(
+        self,
+        channels,
+        emb_channels,
+        dropout,
+        out_channels=None,
+        dims=2,
+        use_checkpoint=False,
+        inject_channels=None,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.emb_channels = emb_channels
+        self.dropout = dropout
+        self.out_channels = out_channels or channels
+        self.use_checkpoint = use_checkpoint
+        self.norm_in = FDN(channels, inject_channels)
+        self.norm_out = FDN(self.out_channels, inject_channels)
+
+        self.in_layers = nn.Sequential(
+            nn.Identity(),
+            nn.SiLU(),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1),
+        )
+
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            linear(
+                emb_channels,
+                self.out_channels,
+            ),
+        )
+        self.out_layers = nn.Sequential(
+            nn.Identity(),
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+            ),
+        )
+
+        if self.out_channels == channels:
+            self.skip_connection = nn.Identity()
+        else:
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+    def forward(self, x, emb, local_conditions):
+        return checkpoint(
+            self._forward, (x, emb, local_conditions), self.parameters(), self.use_checkpoint
+        )
+
+    def _forward(self, x, emb, local_conditions):
+        local_conditions = F.interpolate(local_conditions, x.shape[-2:], mode="bilinear")
+        h = self.norm_in(x, local_conditions)
+        h = self.in_layers(h)
+        
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        while len(emb_out.shape) < len(h.shape):
+            emb_out = emb_out[..., None]
+        
+        h = h + emb_out
+        h = self.norm_out(h, local_conditions)
+        h = self.out_layers(h)
+        
+        return self.skip_connection(x) + h
+
+
+class LocalAdapter(nn.Module):
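+    """ControlNet-style control branch: a copy of the UNet encoder whose outputs pass
+    through zero-initialized 1x1 convolutions, with LocalResBlocks substituted at
+    `inject_layers` so projected Q-Former features modulate those levels."""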
+    def __init__(
+            self,
+            in_channels,
+            model_channels,
+            local_channels,
+            inject_channels,
+            inject_layers,
+            query_channels,
+            query_layers,
+            query_scales,
+            num_res_blocks,
+            attention_resolutions,
+            dropout=0,
+            channel_mult=(1, 2, 4, 8),
+            conv_resample=True,
+            dims=2,
+            use_checkpoint=False,
+            use_fp16=False,
+            num_heads=-1,
+            num_head_channels=-1,
+            num_heads_upsample=-1,
+            use_scale_shift_norm=False,
+            resblock_updown=False,
+            use_new_attention_order=False,
+            use_spatial_transformer=False,  # custom transformer support
+            transformer_depth=1,  # custom transformer support
+            context_dim=None,  # custom transformer support
+            n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
+            legacy=True,
+            disable_self_attentions=None,
+            num_attention_blocks=None,
+            disable_middle_self_attn=False,
+            use_linear_in_transformer=False,
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+        if context_dim is not None:
+            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+            from omegaconf.listconfig import ListConfig
+            if type(context_dim) == ListConfig:
+                context_dim = list(context_dim)
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+        if num_head_channels == -1:
+            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+        self.dims = dims
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.inject_layers = inject_layers
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")
+
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+
+        self.query_channels = query_channels
+        self.query_layers = query_layers
+        self.query_scales = query_scales
+        visual_projs = []
+        for query_channel, inject_channel in zip(query_channels, inject_channels):
+            layer_proj = zero_module(linear(query_channel, inject_channel))
+            visual_projs.append(layer_proj)
+        self.visual_projs = nn.ModuleList(visual_projs)
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.input_blocks = nn.ModuleList(
+            [
+                LocalTimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                if (1 + 3*level + nr) in self.inject_layers:
+                    layers = [
+                        LocalResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=mult * model_channels,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            inject_channels=inject_channels[level],
+                        )
+                    ]
+                else:
+                    layers = [
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=mult * model_channels,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                        )
+                    ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
+                self.input_blocks.append(LocalTimestepEmbedSequential(*layers))
+                self.zero_convs.append(self.make_zero_conv(ch))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    LocalTimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                self.zero_convs.append(self.make_zero_conv(ch))
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        self.middle_block = LocalTimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=dim_head,
+                use_new_attention_order=use_new_attention_order,
+            ) if not use_spatial_transformer else SpatialTransformer(
+                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                use_checkpoint=use_checkpoint
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self.middle_block_out = self.make_zero_conv(ch)
+        self._feature_size += ch
+
+    def make_zero_conv(self, channels):
+        return LocalTimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+
+    def extract_local_features(self, q_former, text, local_conditions):
+        # extract local features
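+        # `local_conditions` stacks n RGB condition maps along the channel axis; each
+        # 3-channel slice is encoded separately by the frozen visual encoder, fused with
+        # the text prompt by the Q-Former, and the selected query layers are reshaped and
+        # projected into spatial feature maps for injection.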
+        bs, chn, h, w = local_conditions.shape
+        n = chn // 3
+        image_features_frozen, image_atts = q_former.forward_visual_encoder(local_conditions.view(bs * n, 3, h, w))
+        bs_n, seq_len, v_chn = image_features_frozen[0].shape 
+
+        # with pos embed
+        image_features_frozen = [q_former.crossattn_embeddings(image_feat) for image_feat in image_features_frozen]
+
+        # image_features_frozen: [bs * n, seq_len, c]
+        image_features_frozen = [image_feat.view(bs, n*seq_len, v_chn) for image_feat in image_features_frozen]
+        image_atts = [image_att.view(bs, -1) for image_att in image_atts]
+
+        local_embeddings = q_former.forward_qformer(text, image_features_frozen, image_atts)
+
+        # process qformer features
+        local_features = []
+        for lvl, scale_factor, visual_proj in zip(self.query_layers, self.query_scales, self.visual_projs):
+            local_emb = local_embeddings[lvl]
+            _, seq_len, ndim = local_emb.shape
+            l = int(seq_len ** 0.5)
+            local_emb = F.interpolate(local_emb.transpose(1,2).view(bs, -1, l, l), None, scale_factor=scale_factor, mode="bilinear")
+            local_emb = visual_proj(local_emb.transpose(1,2).transpose(2,3).flatten(1,2))
+            local_emb = local_emb.view(bs, int(l*scale_factor), int(l*scale_factor), -1).transpose(2,3).transpose(1,2)
+            local_features.append(local_emb)
+        return local_features
+
+    def forward(self, x, timesteps, context, local_features, **kwargs):
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+
+        outs = []
+        h = x.type(self.dtype)
+        for layer_idx, (module, zero_conv) in enumerate(zip(self.input_blocks, self.zero_convs)):
+            if layer_idx in self.inject_layers:
+                h = module(h, emb, context, local_control=local_features[self.inject_layers.index(layer_idx)])
+            else:
+                h = module(h, emb, context)
+            outs.append(zero_conv(h, emb, context))
+
+        h = self.middle_block(h, emb, context)
+        outs.append(self.middle_block_out(h, emb, context))
+
+        return outs
+
+
+class LocalControlUNetModel(UNetModel):
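+    # UNet whose encoder and middle blocks run frozen under no_grad; the multi-scale
+    # residuals from the LocalAdapter (`local_control`) are added to the skip
+    # connections and middle features, scaled by `local_w`, before decoding.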
+    def forward(self, x, timesteps=None, context=None, local_control=None, content_control=None, color_control=None, local_w=1.0, content_w=1.0, color_w=1.0, **kwargs):
+        hs = []
+
+        with torch.no_grad():
+            t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+            emb = self.time_embed(t_emb)
+            h = x.type(self.dtype)
+            for module in self.input_blocks:
+                h = module(h, emb, context, content_control=content_control, color_control=color_control, content_w=content_w, color_w=color_w)
+                hs.append(h)
+            h = self.middle_block(h, emb, context, content_control=content_control, color_control=color_control, content_w=content_w, color_w=color_w)
+
+        h += local_w * local_control.pop()
+
+        for module in self.output_blocks:
+            h = torch.cat([h, hs.pop() + local_w * local_control.pop()], dim=1)
+            h = module(h, emb, context, content_control=content_control, color_control=color_control, content_w=content_w, color_w=color_w)
+
+        h = h.type(x.dtype)
+        return self.out(h)
diff --git a/models/logger.py b/models/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a89ce02c9680929bac2a84d0b1148e7259152fb
--- /dev/null
+++ b/models/logger.py
@@ -0,0 +1,97 @@
+import os
+
+import numpy as np
+from PIL import Image
+import torch
+import torchvision
+from pytorch_lightning.callbacks import Callback
+try:
+    from pytorch_lightning.utilities.distributed import rank_zero_only
+except ImportError:
+    from pytorch_lightning.utilities.rank_zero import rank_zero_only
+
+
+class ImageLogger(Callback):
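+    """Periodically save grids of logged images under <save_dir>/image_log/<split>.
+
+    Multi-condition 'local_control' tensors are split into 3-channel slices and saved
+    separately; 'global_control' entries are skipped.
+    """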
+    def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
+                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
+                 log_images_kwargs=None, num_local_conditions=7):
+        super().__init__()
+        self.rescale = rescale
+        self.batch_freq = batch_frequency
+        self.max_images = max_images
+        if not increase_log_steps:
+            self.log_steps = [self.batch_freq]
+        self.clamp = clamp
+        self.disabled = disabled
+        self.log_on_batch_idx = log_on_batch_idx
+        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+        self.log_first_step = log_first_step
+        self.num_local_conditions = num_local_conditions
+
+    @rank_zero_only
+    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
+        root = os.path.join(save_dir, "image_log", split)
+        for k in images:
+            if k == 'local_control':
+                _, chn, h, w = images[k].shape
+                if h == w == 1:
+                    continue
+                for local_idx in range(chn//3):
+                    grid = torchvision.utils.make_grid(images[k][:, 3*local_idx: 3*(local_idx+1), :, : ], nrow=4)
+                    if self.rescale:
+                        grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+                    grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+                    grid = grid.numpy()
+                    grid = (grid * 255).astype(np.uint8)
+                    filename = "gs-{:06}_e-{:06}_b-{:06}_{}_{}.png".format(global_step, current_epoch, batch_idx, k, local_idx)
+                    path = os.path.join(root, filename)
+                    os.makedirs(os.path.split(path)[0], exist_ok=True)
+                    Image.fromarray(grid).save(path)
+            elif k != 'global_control':
+                grid = torchvision.utils.make_grid(images[k], nrow=4)
+                if self.rescale:
+                    grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+                grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+                grid = grid.numpy()
+                grid = (grid * 255).astype(np.uint8)
+                filename = "gs-{:06}_e-{:06}_b-{:06}_{}.png".format(global_step, current_epoch, batch_idx, k)
+                path = os.path.join(root, filename)
+                os.makedirs(os.path.split(path)[0], exist_ok=True)
+                Image.fromarray(grid).save(path)
+
+    def log_img(self, pl_module, batch, batch_idx, split="train"):
+        check_idx = batch_idx  # if self.log_on_batch_idx else pl_module.global_step
+        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
+                hasattr(pl_module, "log_images") and
+                callable(pl_module.log_images) and
+                self.max_images > 0):
+            logger = type(pl_module.logger)
+
+            is_train = pl_module.training
+            if is_train:
+                pl_module.eval()
+
+            with torch.no_grad():
+                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+
+            for k in images:
+                N = min(images[k].shape[0], self.max_images)
+                images[k] = images[k][:N]
+                if isinstance(images[k], torch.Tensor):
+                    images[k] = images[k].detach().cpu()
+                    if self.clamp:
+                        images[k] = torch.clamp(images[k], -1., 1.)
+
+            self.log_local(pl_module.logger.save_dir, split, images,
+                           pl_module.global_step, pl_module.current_epoch, batch_idx)
+
+            if is_train:
+                pl_module.train()
+
+    def check_frequency(self, check_idx):
+        return check_idx % self.batch_freq == 0
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        if not self.disabled:
+            self.log_img(pl_module, batch, batch_idx, split="train")
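+
+
+# Illustrative wiring sketch (trainer construction assumed): the callback is passed
+# to a pytorch_lightning Trainer, e.g.
+#   trainer = pl.Trainer(callbacks=[ImageLogger(batch_frequency=2000)], ...)
+# so image grids are written every `batch_frequency` training batches.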
diff --git a/models/q_formers/Qformer.py b/models/q_formers/Qformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e1a30b61cd558330118dae6ea4a594837dac6db
--- /dev/null
+++ b/models/q_formers/Qformer.py
@@ -0,0 +1,1291 @@
+"""
+ * Copyright (c) 2023, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+"""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+    ModelOutput,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    NextSentencePredictorOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class HeadDropout(nn.Module):
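+    """Zero out entire attention heads with probability `p` during training (used on
+    the cross-attention key/value projections); acts as the identity at eval time."""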
+    def __init__(self, p=0.0):
+        super().__init__()
+        self.p = p
+
+    def forward(self, x):
+        if not self.training:
+            return x
+
+        assert x.ndim == 4
+        bs, nH, l, c = x.shape
+        drop_mask = (torch.rand(nH, device=x.device).view(1, nH, 1, 1) < self.p).to(x.dtype)
+        return x * (1-drop_mask)
+
+
+class BertEmbeddings(nn.Module):
+    """Construct the embeddings from word and position embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(
+            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+        )
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size
+        )
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
+        )
+        self.position_embedding_type = getattr(
+            config, "position_embedding_type", "absolute"
+        )
+
+        self.config = config
+
+    def forward(
+        self,
+        input_ids=None,
+        position_ids=None,
+        query_embeds=None,
+        extra_embeds=None,
+        past_key_values_length=0,
+    ):
+        if input_ids is not None:
+            seq_length = input_ids.size()[1]
+        else:
+            seq_length = 0
+
+        if position_ids is None:
+            position_ids = self.position_ids[
+                :, past_key_values_length : seq_length + past_key_values_length
+            ].clone()
+
+        if input_ids is not None:
+            embeddings = self.word_embeddings(input_ids)
+            if self.position_embedding_type == "absolute":
+                position_embeddings = self.position_embeddings(position_ids)
+                embeddings = embeddings + position_embeddings
+
+            if query_embeds is not None:
+                embeddings = torch.cat((query_embeds, embeddings), dim=1)
+        else:
+            embeddings = query_embeds
+
+        if extra_embeds is not None:
+            embeddings = torch.cat((embeddings, extra_embeds), dim=1)
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class BertSelfAttention(nn.Module):
+    def __init__(self, config, is_cross_attention):
+        super().__init__()
+        self.config = config
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+            config, "embedding_size"
+        ):
+            raise ValueError(
+                "The hidden size (%d) is not a multiple of the number of attention "
+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        if is_cross_attention:
+            self.key = nn.Linear(config.encoder_width, self.all_head_size)
+            self.value = nn.Linear(config.encoder_width, self.all_head_size)
+        else:
+            self.key = nn.Linear(config.hidden_size, self.all_head_size)
+            self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.head_dropout = HeadDropout(config.attention_probs_head_dropout_prob)
+        self.position_embedding_type = getattr(
+            config, "position_embedding_type", "absolute"
+        )
+        if (
+            self.position_embedding_type == "relative_key"
+            or self.position_embedding_type == "relative_key_query"
+        ):
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(
+                2 * config.max_position_embeddings - 1, self.attention_head_size
+            )
+        self.save_attention = False
+
+    def save_attn_gradients(self, attn_gradients):
+        self.attn_gradients = attn_gradients
+
+    def get_attn_gradients(self):
+        return self.attn_gradients
+
+    def save_attention_map(self, attention_map):
+        self.attention_map = attention_map
+
+    def get_attention_map(self):
+        return self.attention_map
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (
+            self.num_attention_heads,
+            self.attention_head_size,
+        )
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        mixed_query_layer = self.query(hidden_states)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        past_key_value = (key_layer, value_layer)
+
+        if is_cross_attention:
+            key_layer = self.head_dropout(key_layer)
+            value_layer = self.head_dropout(value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if (
+            self.position_embedding_type == "relative_key"
+            or self.position_embedding_type == "relative_key_query"
+        ):
+            seq_length = hidden_states.size()[1]
+            position_ids_l = torch.arange(
+                seq_length, dtype=torch.long, device=hidden_states.device
+            ).view(-1, 1)
+            position_ids_r = torch.arange(
+                seq_length, dtype=torch.long, device=hidden_states.device
+            ).view(1, -1)
+            distance = position_ids_l - position_ids_r
+            positional_embedding = self.distance_embedding(
+                distance + self.max_position_embeddings - 1
+            )
+            positional_embedding = positional_embedding.to(
+                dtype=query_layer.dtype
+            )  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum(
+                    "bhld,lrd->bhlr", query_layer, positional_embedding
+                )
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum(
+                    "bhld,lrd->bhlr", query_layer, positional_embedding
+                )
+                relative_position_scores_key = torch.einsum(
+                    "bhrd,lrd->bhlr", key_layer, positional_embedding
+                )
+                attention_scores = (
+                    attention_scores
+                    + relative_position_scores_query
+                    + relative_position_scores_key
+                )
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+        if is_cross_attention and self.save_attention:
+            self.save_attention_map(attention_probs)
+            attention_probs.register_hook(self.save_attn_gradients)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs_dropped = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs_dropped = attention_probs_dropped * head_mask
+
+        context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (
+            (context_layer, attention_probs) if output_attentions else (context_layer,)
+        )
+
+        outputs = outputs + (past_key_value,)
+        return outputs
+
+
+class BertSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertAttention(nn.Module):
+    def __init__(self, config, is_cross_attention=False):
+        super().__init__()
+        self.self = BertSelfAttention(config, is_cross_attention)
+        self.output = BertSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads,
+            self.self.num_attention_heads,
+            self.self.attention_head_size,
+            self.pruned_heads,
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = (
+            self.self.attention_head_size * self.self.num_attention_heads
+        )
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        self_outputs = self.self(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+            output_attentions,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+
+        outputs = (attention_output,) + self_outputs[
+            1:
+        ]  # add attentions if we output them
+        return outputs
+
+
+class BertIntermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+    def __init__(self, config, layer_num):
+        super().__init__()
+        self.config = config
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = BertAttention(config)
+        self.layer_num = layer_num
+        if (
+            self.config.add_cross_attention
+            and layer_num % self.config.cross_attention_freq == 0
+        ):
+            self.crossattention = BertAttention(
+                config, is_cross_attention=self.config.add_cross_attention
+            )
+            self.has_cross_attention = True
+        else:
+            self.has_cross_attention = False
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+        self.intermediate_query = BertIntermediate(config)
+        self.output_query = BertOutput(config)
+
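+    # Add positional embeddings to the learned query tokens only; any text tokens
+    # appended after them are left unchanged.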
+    def with_pos_embed(self, tensor, pos: Optional[Tensor], query_length):
+        if query_length > 0 and pos is not None:
+            return torch.cat((tensor[:, :query_length, :] + pos, tensor[:, query_length:, :]), dim=1)
+        else:
+            return tensor
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        query_pos_embeds=None,
+        past_key_value=None,
+        output_attentions=False,
+        query_length=0,
+    ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = (
+            past_key_value[:2] if past_key_value is not None else None
+        )
+        
+        self_attention_outputs = self.attention(
+            self.with_pos_embed(hidden_states, query_pos_embeds, query_length),
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:-1]
+
+        present_key_value = self_attention_outputs[-1]
+
+        if query_length > 0:
+            query_attention_output = attention_output[:, :query_length, :]
+
+            if self.has_cross_attention:
+                assert (
+                    encoder_hidden_states is not None
+                ), "encoder_hidden_states must be given for cross-attention layers"
+
+                cross_attention_outputs = self.crossattention(
+                    self.with_pos_embed(query_attention_output, query_pos_embeds, query_length),
+                    attention_mask,
+                    head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    output_attentions=output_attentions,
+                )
+                query_attention_output = cross_attention_outputs[0]
+                outputs = (
+                    outputs + cross_attention_outputs[1:-1]
+                )  # add cross attentions if we output attention weights
+
+            layer_output = apply_chunking_to_forward(
+                self.feed_forward_chunk_query,
+                self.chunk_size_feed_forward,
+                self.seq_len_dim,
+                query_attention_output,
+            )
+            if attention_output.shape[1] > query_length:
+                layer_output_text = apply_chunking_to_forward(
+                    self.feed_forward_chunk,
+                    self.chunk_size_feed_forward,
+                    self.seq_len_dim,
+                    attention_output[:, query_length:, :],
+                )
+                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
+        else:
+            layer_output = apply_chunking_to_forward(
+                self.feed_forward_chunk,
+                self.chunk_size_feed_forward,
+                self.seq_len_dim,
+                attention_output,
+            )
+        outputs = (layer_output,) + outputs
+
+        outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+    def feed_forward_chunk_query(self, attention_output):
+        intermediate_output = self.intermediate_query(attention_output)
+        layer_output = self.output_query(intermediate_output, attention_output)
+        return layer_output
+
+
+class BertEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList(
+            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        query_pos_embeds=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        query_length=0,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = (
+            () if output_attentions and self.config.add_cross_attention else None
+        )
+
+        next_decoder_cache = () if use_cache else None
+
+        for i in range(self.config.num_hidden_layers):
+            layer_module = self.layer[i]
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
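+            # When `encoder_hidden_states` is a list, each cross-attention layer consumes
+            # the next set of visual features (and its mask) instead of all layers
+            # sharing a single one.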
+            if layer_module.has_cross_attention and isinstance(encoder_hidden_states, list):
+                encoder_hidden_states_curr = encoder_hidden_states.pop(0)
+                encoder_attention_mask_curr = encoder_attention_mask.pop(0)
+            else:
+                encoder_hidden_states_curr = encoder_hidden_states 
+                encoder_attention_mask_curr = encoder_attention_mask
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(
+                            *inputs, past_key_value, output_attentions, query_length
+                        )
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states_curr,
+                    encoder_attention_mask_curr,
+                    query_pos_embeds,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states_curr,
+                    encoder_attention_mask_curr,
+                    query_pos_embeds,
+                    past_key_value,
+                    output_attentions,
+                    query_length,
+                )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class BertPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = BertConfig
+    base_model_prefix = "bert"
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+    """
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+    input to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=False):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = BertEmbeddings(config)
+
+        self.encoder = BertEncoder(config)
+
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.init_weights()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
+        """
+        Invert an attention mask (e.g., switches 0. and 1.).
+
+        Args:
+            encoder_attention_mask (`torch.Tensor`): An attention mask.
+
+        Returns:
+            `torch.Tensor`: The inverted attention mask.
+        """
+        if encoder_attention_mask.dim() == 3:
+            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+        if encoder_attention_mask.dim() == 2:
+            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
+        # /transformer/transformer_layers.py#L270
+        # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
+        # encoder_extended_attention_mask.transpose(-1, -2))
+        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min
+
+        return encoder_extended_attention_mask
+
+    def get_extended_attention_mask(
+        self,
+        attention_mask: Tensor,
+        input_shape: Tuple[int],
+        device: device,
+        is_decoder: bool,
+        has_query: bool = False,
+    ) -> Tensor:
+        """
+        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (:obj:`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (:obj:`Tuple[int]`):
+                The shape of the input to the model.
+            device (:obj:`torch.device`):
+                The device of the input to the model.
+            is_decoder (:obj:`bool`):
+                Whether a causal mask should be applied on top of the padding mask.
+            has_query (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether learnable query tokens are prepended to the sequence (UniLM-style prefix mask).
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
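+        Example (standalone sketch of the 2-D decoder case with no
+        prefix/query tokens, mirroring the logic below)::
+            >>> import torch
+            >>> attention_mask = torch.tensor([[1, 1, 1, 0]])  # last token is padding
+            >>> seq_ids = torch.arange(4)
+            >>> causal = (seq_ids[None, None, :].repeat(1, 4, 1) <= seq_ids[None, :, None]).float()
+            >>> extended = causal[:, None, :, :] * attention_mask[:, None, None, :].float()
+            >>> extended = (1.0 - extended) * -10000.0  # 0.0 = attend, -10000.0 = mask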
+        """
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if is_decoder:
+                batch_size, seq_length = input_shape
+
+                seq_ids = torch.arange(seq_length, device=device)
+                causal_mask = (
+                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
+                    <= seq_ids[None, :, None]
+                )
+
+                # add a prefix ones mask to the causal mask
+                # causal and attention masks must have same type with pytorch version < 1.3
+                causal_mask = causal_mask.to(attention_mask.dtype)
+
+                if causal_mask.shape[1] < attention_mask.shape[1]:
+                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+                    if has_query:  # UniLM style attention mask
+                        causal_mask = torch.cat(
+                            [
+                                torch.zeros(
+                                    (batch_size, prefix_seq_len, seq_length),
+                                    device=device,
+                                    dtype=causal_mask.dtype,
+                                ),
+                                causal_mask,
+                            ],
+                            axis=1,
+                        )
+                    causal_mask = torch.cat(
+                        [
+                            torch.ones(
+                                (batch_size, causal_mask.shape[1], prefix_seq_len),
+                                device=device,
+                                dtype=causal_mask.dtype,
+                            ),
+                            causal_mask,
+                        ],
+                        axis=-1,
+                    )
+                extended_attention_mask = (
+                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+                )
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                    input_shape, attention_mask.shape
+                )
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(
+            dtype=self.dtype
+        )  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        return extended_attention_mask
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        query_embeds=None,
+        extra_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        query_pos_embeds=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        is_decoder=False,
+    ):
+        r"""
+        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        """
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if input_ids is None:
+            assert (
+                query_embeds is not None
+            ), "You have to specify query_embeds when input_ids is None"
+
+        # past_key_values_length
+        past_key_values_length = (
+            past_key_values[0][0].shape[2] - self.config.query_length
+            if past_key_values is not None
+            else 0
+        )
+
+        query_length = query_embeds.shape[1] if query_embeds is not None else 0
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            query_embeds=query_embeds,
+            extra_embeds=extra_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+
+        input_shape = embedding_output.size()[:-1]
+        batch_size, seq_length = input_shape
+        device = embedding_output.device
+
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length + past_key_values_length), device=device
+            )
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if is_decoder:
+            extended_attention_mask = self.get_extended_attention_mask(
+                attention_mask,
+                input_ids.shape,
+                device,
+                is_decoder,
+                has_query=(query_embeds is not None),
+            )
+        else:
+            extended_attention_mask = self.get_extended_attention_mask(
+                attention_mask, input_shape, device, is_decoder
+            )
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if encoder_hidden_states is not None:
+            if type(encoder_hidden_states) == list:
+                # encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+                encoder_batch_size = encoder_hidden_states[0].size(0)
+                encoder_sequence_length = encoder_hidden_states[0].size(1)
+                encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            else:
+                # encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+                encoder_batch_size = encoder_hidden_states.size(0)
+                encoder_sequence_length = encoder_hidden_states.size(1)
+                encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+            if type(encoder_attention_mask) == list:
+                encoder_extended_attention_mask = [
+                    self.invert_attention_mask(mask) for mask in encoder_attention_mask
+                ]
+            elif encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+                encoder_extended_attention_mask = self.invert_attention_mask(
+                    encoder_attention_mask
+                )
+            else:
+                encoder_extended_attention_mask = self.invert_attention_mask(
+                    encoder_attention_mask
+                )
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            query_pos_embeds=query_pos_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            query_length=query_length,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = (
+            self.pooler(sequence_output) if self.pooler is not None else None
+        )
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.cls = BertOnlyMLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        query_embeds=None,
+        extra_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        query_pos_embeds=None,
+        labels=None,
+        past_key_values=None,
+        use_cache=True,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        return_logits=False,
+        is_decoder=True,
+        reduction="mean",
+    ):
+        r"""
+        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        Returns:
+        Example::
+            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+            >>> import torch
+            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+            >>> config = BertConfig.from_pretrained("bert-base-cased")
+            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+            >>> outputs = model(**inputs)
+            >>> prediction_logits = outputs.logits
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+        if labels is not None:
+            use_cache = False
+        if past_key_values is not None:
+            query_embeds = None
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            query_embeds=query_embeds,
+            extra_embeds=extra_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            query_pos_embeds=query_pos_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            is_decoder=is_decoder,
+        )
+
+        sequence_output = outputs[0]
+        if query_embeds is not None:
+            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+
+        prediction_scores = self.cls(sequence_output)
+
+        if return_logits:
+            return prediction_scores[:, :-1, :].contiguous()
+
+        lm_loss = None
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+            lm_loss = loss_fct(
+                shifted_prediction_scores.view(-1, self.config.vocab_size),
+                labels.view(-1),
+            )
+            if reduction == "none":
+                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
+    ):
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
+        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "query_embeds": query_embeds,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+            "is_decoder": True,
+        }
+
+    def _reorder_cache(self, past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx) for past_state in layer_past
+                ),
+            )
+        return reordered_past
+
+
+class BertForMaskedLM(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.cls = BertOnlyMLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        query_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        return_logits=False,
+        is_decoder=False,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        """
+
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            query_embeds=query_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            is_decoder=is_decoder,
+        )
+
+        sequence_output = outputs[0]
+        if query_embeds is not None:
+            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+        prediction_scores = self.cls(sequence_output)
+
+        if return_logits:
+            return prediction_scores
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()  # -100 index = padding token
+            masked_lm_loss = loss_fct(
+                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+            )
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return (
+                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            )
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/models/q_formers/__init__.py b/models/q_formers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53af87e5fdc4320e60dec3a597e45099d95a633
--- /dev/null
+++ b/models/q_formers/__init__.py
@@ -0,0 +1,51 @@
+import logging
+
+import torch
+
+from omegaconf import OmegaConf
+from lavis.models import registry
+from lavis.models import load_preprocess
+
+from ldm.util import instantiate_from_config
+
+
+def load_blip2_model(cfg, is_eval=False, device="cpu"):
+    model_cls = registry.get_model_class(cfg.model_name)
+
+    # load preprocess
+    default_cfg = OmegaConf.load(model_cls.default_config_path(cfg.model_type))
+    default_cfg.model.pretrained = cfg.pretrained
+
+    if default_cfg.model.image_size != cfg.params.img_size:
+        default_cfg.model.image_size = cfg.params.img_size
+    model = model_cls.from_config(default_cfg.model)
+    model.cfg = default_cfg.model
+
+    if is_eval:
+        model.eval()
+
+    if default_cfg is not None:
+        preprocess_cfg = default_cfg.preprocess
+        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
+    else:
+        vis_processors, txt_processors = None, None
+        logging.info(
+            f"""No default preprocess for model {name} ({model_type}).
+                This can happen if the model is not finetuned on downstream datasets,
+                or it is not intended for direct use without finetuning.
+            """
+        )
+
+    if device == "cpu" or device == torch.device("cpu"):
+        model = model.float()
+
+    return model.to(device), vis_processors, txt_processors
+
+
+def load_qformer_model(cfg):
+    blip2_model, vis_processor, txt_processor = load_blip2_model(cfg) 
+    q_former = instantiate_from_config(cfg)
+    if blip2_model.query_tokens.shape != q_former.query_tokens.shape:
+        blip2_model.query_tokens = q_former.query_tokens
+    model_name = cfg.params.get('model_name', 'bert-base-uncased')
+    if model_name == 'bert-base-uncased':
+        q_former.load_state_dict(blip2_model.state_dict(), strict=False)
+    return q_former, (vis_processor, txt_processor)
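+
+
+# Minimal usage sketch (illustrative only: the config layout below is an
+# assumption inferred from how load_blip2_model / load_qformer_model read the
+# cfg object, not a documented interface):
+#
+#     from omegaconf import OmegaConf
+#     cfg = OmegaConf.create({
+#         "target": "models.q_formers.blip2_qformer.Blip2Qformer",  # for instantiate_from_config
+#         "model_name": "blip2",           # lavis registry name
+#         "model_type": "pretrain",        # selects the default lavis config
+#         "pretrained": "/path/to/blip2_pretrained.pth",
+#         "params": {"img_size": 224, "model_name": "bert-base-uncased"},
+#     })
+#     q_former, (vis_proc, txt_proc) = load_qformer_model(cfg)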
diff --git a/models/q_formers/blip2.py b/models/q_formers/blip2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6e22acd0f1792efbb952f289abab558402e85f8
--- /dev/null
+++ b/models/q_formers/blip2.py
@@ -0,0 +1,329 @@
+"""
+ Copyright (c) 2023, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+import contextlib
+import logging
+import os
+import time
+import datetime
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import torch.nn.functional as F
+
+import lavis.common.dist_utils as dist_utils
+from lavis.common.dist_utils import download_cached_file
+from lavis.common.utils import is_url
+from lavis.common.logger import MetricLogger
+from lavis.models.base_model import BaseModel
+from lavis.models.blip2_models.Qformer import BertConfig
+# from lavis.models.eva_vit import create_eva_vit_g
+# from lavis.models.clip_vit import create_clip_vit_L
+from transformers import BertTokenizer
+
+from models.q_formers.Qformer import BertLMHeadModel
+from models.q_formers.eva_vit import create_eva_vit_g
+from models.q_formers.clip_vit import create_clip_vit_L
+
+
+class Blip2Base(BaseModel):
+    @classmethod
+    def init_tokenizer(cls, model_name='bert/bert-base-uncased', truncation_side="right"):
+        tokenizer = BertTokenizer.from_pretrained(model_name, truncation_side=truncation_side)
+        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+        return tokenizer
+
+    def maybe_autocast(self, dtype=torch.float16):
+        # if on cpu, don't use autocast
+        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+        enable_autocast = self.device != torch.device("cpu")
+
+        if enable_autocast:
+            return torch.cuda.amp.autocast(dtype=dtype)
+        else:
+            return contextlib.nullcontext()
+
+    @classmethod
+    def init_Qformer(cls, num_query_token, vision_width, model_name="bert/bert-base-uncased", head_dropout=0, cross_attention_freq=2, query_token_init_type="normal"):
+        encoder_config = BertConfig.from_pretrained(model_name)
+        encoder_config.encoder_width = vision_width
+        # insert cross-attention layer every other block
+        encoder_config.add_cross_attention = True
+        encoder_config.cross_attention_freq = cross_attention_freq
+        encoder_config.query_length = num_query_token
+        encoder_config.attention_probs_head_dropout_prob = head_dropout
+        Qformer = BertLMHeadModel.from_pretrained(
+            model_name, config=encoder_config
+        )
+
+        if query_token_init_type == "uniform":
+            scale = encoder_config.hidden_size ** -0.5
+            query_tokens = nn.Parameter(
+                scale * torch.randn(1, num_query_token, encoder_config.hidden_size)
+            ) 
+            print("Initialize query tokens with uniform.")
+        else:
+            query_tokens = nn.Parameter(
+                torch.zeros(1, num_query_token, encoder_config.hidden_size)
+            )
+            query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+            print("Initialize query tokens with normal.")
+        return Qformer, query_tokens
+
+    def init_vision_encoder(self, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, n_levels=0):
+        assert model_name in [
+            "eva_clip_g",
+            "eva2_clip_L",
+            "clip_L",
+        ], "vit model must be eva_clip_g, eva2_clip_L or clip_L"
+        if model_name == "eva_clip_g":
+            visual_encoder = create_eva_vit_g(
+                img_size, drop_path_rate, use_grad_checkpoint, precision
+            )
+        elif model_name == "eva2_clip_L":
+            visual_encoder = create_eva2_vit_L(
+                img_size, drop_path_rate, use_grad_checkpoint, precision
+            )
+        elif model_name == "clip_L":
+            visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
+
+        ln_vision = nn.ModuleList([LayerNorm(visual_encoder.num_features) for i in range(n_levels)])
+
+        self.vit_name = model_name
+        return visual_encoder, ln_vision
+
+    def load_from_pretrained(self, url_or_filename):
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location="cpu")
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location="cpu")
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        state_dict = checkpoint["model"]
+
+        msg = self.load_state_dict(state_dict, strict=False)
+
+        # logging.info("Missing keys {}".format(msg.missing_keys))
+        logging.info("load checkpoint from %s" % url_or_filename)
+
+        return msg
+
+    def get_optimizer_params(self, weight_decay, lr_scale=1):
+
+        vit_num_layers = self.visual_encoder.get_num_layer()
+        lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))
+
+        parameter_group_names = {}
+        parameter_group_vars = {}
+
+        for name, param in self.named_parameters():
+            if not param.requires_grad:
+                continue  # frozen weights
+            if len(param.shape) == 1 or name.endswith(".bias"):
+                group_name = "no_decay"
+                this_weight_decay = 0.
+            else:
+                group_name = "decay"
+                this_weight_decay = weight_decay
+            if 'visual_encoder' in name:
+                layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.',''))
+                group_name = "vit_layer_%d_%s" % (layer_id, group_name)
+            else:
+                layer_id = None
+
+            if group_name not in parameter_group_names:
+                if layer_id is not None:
+                    scale = lr_scales[layer_id]
+                else:
+                    scale = 1
+                parameter_group_names[group_name] = {
+                    "weight_decay": this_weight_decay,
+                    "params": [],
+                    "lr_scale": scale
+                }
+                parameter_group_vars[group_name] = {
+                    "weight_decay": this_weight_decay,
+                    "params": [],
+                    "lr_scale": scale
+                }
+            parameter_group_vars[group_name]["params"].append(param)
+            parameter_group_names[group_name]["params"].append(name)
+        # import json
+        # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+        optim_params = list(parameter_group_vars.values())
+        return optim_params
+
+    def _lemmatize(self, answers):
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            except ImportError:
+                logging.error(
+                    """
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """
+                )
+                exit(1)
+
+        return self._lemmatizer
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+def compute_sim_matrix(model, data_loader, **kwargs):
+    k_test = kwargs.pop("k_test")
+
+    metric_logger = MetricLogger(delimiter="  ")
+    header = "Evaluation:"
+
+    logging.info("Computing features for evaluation...")
+    start_time = time.time()
+
+    texts = data_loader.dataset.text
+    num_text = len(texts)
+    text_bs = 256
+    text_ids = []
+    text_embeds = []
+    text_atts = []
+    for i in range(0, num_text, text_bs):
+        text = texts[i : min(num_text, i + text_bs)]
+        text_input = model.tokenizer(
+            text,
+            padding="max_length",
+            truncation=True,
+            max_length=35,
+            return_tensors="pt",
+        ).to(model.device)
+        text_feat = model.forward_text(text_input)
+        text_embed = F.normalize(model.text_proj(text_feat))
+        text_embeds.append(text_embed)
+        text_ids.append(text_input.input_ids)
+        text_atts.append(text_input.attention_mask)
+
+    text_embeds = torch.cat(text_embeds, dim=0)
+    text_ids = torch.cat(text_ids, dim=0)
+    text_atts = torch.cat(text_atts, dim=0)
+
+    vit_feats = []
+    image_embeds = []
+    for samples in data_loader:
+        image = samples["image"]
+
+        image = image.to(model.device)
+        image_feat, vit_feat = model.forward_image(image)
+        image_embed = model.vision_proj(image_feat)
+        image_embed = F.normalize(image_embed, dim=-1)
+
+        vit_feats.append(vit_feat.cpu())
+        image_embeds.append(image_embed)
+
+    vit_feats = torch.cat(vit_feats, dim=0)
+    image_embeds = torch.cat(image_embeds, dim=0)
+
+    sims_matrix = []
+    for image_embed in image_embeds:
+        sim_q2t = image_embed @ text_embeds.t()
+        sim_i2t, _ = sim_q2t.max(0)
+        sims_matrix.append(sim_i2t)
+    sims_matrix = torch.stack(sims_matrix, dim=0)
+
+    score_matrix_i2t = torch.full(
+        (len(data_loader.dataset.image), len(texts)), -100.0
+    ).to(model.device)
+
+    num_tasks = dist_utils.get_world_size()
+    rank = dist_utils.get_rank()
+    step = sims_matrix.size(0) // num_tasks + 1
+    start = rank * step
+    end = min(sims_matrix.size(0), start + step)
+
+    for i, sims in enumerate(
+        metric_logger.log_every(sims_matrix[start:end], 50, header)
+    ):
+        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
+        image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
+        score = model.compute_itm(
+            image_inputs=image_inputs,
+            text_ids=text_ids[topk_idx],
+            text_atts=text_atts[topk_idx],
+        ).float()
+        score_matrix_i2t[start + i, topk_idx] = score + topk_sim
+
+    sims_matrix = sims_matrix.t()
+    score_matrix_t2i = torch.full(
+        (len(texts), len(data_loader.dataset.image)), -100.0
+    ).to(model.device)
+
+    step = sims_matrix.size(0) // num_tasks + 1
+    start = rank * step
+    end = min(sims_matrix.size(0), start + step)
+
+    for i, sims in enumerate(
+        metric_logger.log_every(sims_matrix[start:end], 50, header)
+    ):
+        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
+        image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
+        score = model.compute_itm(
+            image_inputs=image_inputs,
+            text_ids=text_ids[start + i].repeat(k_test, 1),
+            text_atts=text_atts[start + i].repeat(k_test, 1),
+        ).float()
+        score_matrix_t2i[start + i, topk_idx] = score + topk_sim
+
+    if dist_utils.is_dist_avail_and_initialized():
+        dist.barrier()
+        torch.distributed.all_reduce(
+            score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
+        )
+        torch.distributed.all_reduce(
+            score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
+        )
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    logging.info("Evaluation time {}".format(total_time_str))
+
+    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
diff --git a/models/q_formers/blip2_qformer.py b/models/q_formers/blip2_qformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0127de2ef60bf8707e7e00810aa8115b5a9f8c5f
--- /dev/null
+++ b/models/q_formers/blip2_qformer.py
@@ -0,0 +1,651 @@
+"""
+ Copyright (c) 2023, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+import logging
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from dataclasses import dataclass
+from typing import Optional, Tuple, List
+from torch.cuda.amp import autocast as autocast
+from torch.nn import functional as F
+
+from lavis.common.registry import registry
+from lavis.models.base_model import all_gather_with_grad, concat_all_gather
+from lavis.models.blip2_models.blip2 import (
+    compute_sim_matrix,
+    disabled_train,
+)
+from lavis.models.blip_models.blip_outputs import BlipOutput
+from transformers.modeling_outputs import ModelOutput
+
+from models.q_formers.blip2 import Blip2Base
+from models.q_formers.position_encoding import PositionEmbeddings
+from ldm.modules.diffusionmodules.util import conv_nd
+
+import time
+
+
+@dataclass
+class BlipOutputFeatures(ModelOutput):
+    """
+    Data class of features from BlipFeatureExtractor.
+
+    Args:
+        image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
+        image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
+        text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
+        text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
+
+        The first embedding or feature is for the [CLS] token.
+
+        Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
+    """
+
+    image_embeds: Optional[torch.FloatTensor] = None
+    image_embeds_proj: Optional[torch.FloatTensor] = None
+
+    text_embeds: Optional[torch.FloatTensor] = None
+    text_embeds_proj: Optional[torch.FloatTensor] = None
+
+    multimodal_embeds: Optional[torch.FloatTensor] = None
+
+    hidden_states: Optional[List[torch.FloatTensor]] = None
+
+    attentions: Optional[List[torch.FloatTensor]] = None
+    cross_attentions: Optional[List[torch.FloatTensor]] = None
+
+
+class Blip2Qformer(Blip2Base):
+    """
+    BLIP2 first-stage model with Q-former and ViT.
+    Supported model types:
+        - pretrained: pretrained model with vit-g
+        - pretrain_vitL: pretrained model with vit-large
+        - coco: finetuned model on coco
+    Usage:
+        >>> from lavis.models import load_model
+        >>> model = load_model("blip2", "pretrain")
+    """
+
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        "pretrain": "configs/models/blip2/blip2_pretrain.yaml",
+        "pretrain_vitL": "configs/models/blip2/blip2_pretrain_vitL.yaml",
+        "coco": "configs/models/blip2/blip2_coco.yaml",
+    }
+
+    def __init__(
+        self,
+        model_name="bert-base-uncased",
+        vit_model="eva_clip_g",
+        img_size=224,
+        drop_path_rate=0,
+        head_dropout=0,
+        use_grad_checkpoint=False,
+        vit_precision="fp16",
+        freeze_vit=True,
+        num_query_token=32,
+        cross_attention_freq=2,
+        embed_dim=256,
+        max_txt_len=32,
+        query_token_init_type='normal',
+        max_position_embeddings=512,
+        multilevels=[],
+    ):
+        super().__init__()
+
+        self.num_query_token = num_query_token
+
+        self.tokenizer = self.init_tokenizer(model_name)
+
+        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, len(multilevels),
+        )
+        self.multilevels = multilevels
+
+        self.crossattn_embeddings = PositionEmbeddings(max_position_embeddings, self.visual_encoder.num_features) 
+
+        self.Qformer, self.query_tokens = self.init_Qformer(
+            num_query_token, self.visual_encoder.num_features, model_name, head_dropout, cross_attention_freq, query_token_init_type,
+        )
+        self.Qformer.resize_token_embeddings(len(self.tokenizer))
+        state_dict = self.Qformer.state_dict()
+        for name, param in self.Qformer.named_parameters():
+            if "_query" in name:
+                key_orig = name.replace("_query", "")
+                param.data.copy_(state_dict[key_orig])
+
+        self.vision_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
+        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
+        self.itm_head = nn.Linear(self.Qformer.config.hidden_size, 2)
+        self.temp = nn.Parameter(0.07 * torch.ones([]))
+        self.max_txt_len = max_txt_len
+        self.visual_encoder.requires_grad_(False)
+
+        for name, param in self.Qformer.named_parameters():
+            if 'crossattention' in name:
+                param.requires_grad = True
+            else:
+                param.requires_grad = False
+
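+        # Only the cross-attention parameters remain trainable; the BLIP-2
+        # pretraining heads below are removed, so the original contrastive /
+        # ITM / captioning paths (forward, generate, forward_image,
+        # compute_itm, extract_features) are not usable in this configuration.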
+        del self.Qformer.cls
+        del self.vision_proj
+        del self.text_proj
+        del self.itm_head
+        del self.temp
+        
+    def forward(self, samples):
+        image = samples["image"]
+        text = samples["text_input"]
+
+        image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
+            image.device
+        )
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+
+        query_output = self.Qformer.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            use_cache=True,
+            return_dict=True,
+        )
+
+        image_feats = F.normalize(
+            self.vision_proj(query_output.last_hidden_state), dim=-1
+        )
+
+        text_tokens = self.tokenizer(
+            text,
+            padding="max_length",
+            truncation=True,
+            max_length=self.max_txt_len,
+            return_tensors="pt",
+        ).to(image.device)
+        text_output = self.Qformer.bert(
+            text_tokens.input_ids,
+            attention_mask=text_tokens.attention_mask,
+            return_dict=True,
+        )
+        text_feat = F.normalize(
+            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
+        )
+
+        ###============== Image-text Contrastive ===================###
+        image_feats_all = concat_all_gather(
+            image_feats
+        )  # [batch_size*num_gpu, num_query_tokens, embed_dim]
+        text_feat_all = concat_all_gather(text_feat)  # [batch_size*num_gpu, embed_dim]
+
+        sim_q2t = torch.matmul(
+            image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
+        ).squeeze()
+        # [batch_size, batch_size*num_gpu, num_query_tokens]
+
+        # image-text similarity: aggregate across all query tokens
+        sim_i2t, _ = sim_q2t.max(-1)
+        sim_i2t = sim_i2t / self.temp
+
+        # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
+        sim_t2q = torch.matmul(
+            text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
+        ).squeeze()
+
+        # text-image similarity: aggregate across all query tokens
+        sim_t2i, _ = sim_t2q.max(-1)
+        sim_t2i = sim_t2i / self.temp  # [batch_size, batch_size*num_gpu]
+
+        rank = dist.get_rank()
+        bs = image.size(0)
+        targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
+            image.device
+        )
+
+        if "image_id" in samples.keys(): #coco retrieval finetuning
+            image_ids = samples["image_id"].view(-1,1)
+            image_ids_all = concat_all_gather(image_ids)
+            pos_idx = torch.eq(image_ids, image_ids_all.t()).float()       
+            sim_targets = pos_idx / pos_idx.sum(1,keepdim=True)   
+            sim_targets = 0.9 * sim_targets + 0.1 * torch.ones_like(sim_targets) / sim_targets.size(1)
+
+            loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_targets,dim=1).mean()
+            loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_targets,dim=1).mean()     
+            loss_itc = (loss_t2i+loss_i2t)/2  
+        else:                     
+            loss_itc = (
+                F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
+                + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
+            ) / 2
+
+        ###============== Image-text Matching ===================###
+        text_input_ids_world = concat_all_gather(text_tokens.input_ids)
+        text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)
+        image_embeds_world = all_gather_with_grad(image_embeds)
+        with torch.no_grad():
+            if "image_id" in samples.keys():
+                mask = torch.eq(image_ids, image_ids_all.t())
+                sim_t2i.masked_fill_(mask, -10000)
+                sim_i2t.masked_fill_(mask, -10000)
+            else:    
+                sim_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000)
+                sim_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000)            
+                
+            weights_t2i = F.softmax(sim_t2i, dim=1)
+            weights_i2t = F.softmax(sim_i2t, dim=1)
+
+        # select a negative image for each text
+        image_embeds_neg = []
+        for b in range(bs):
+            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
+            image_embeds_neg.append(image_embeds_world[neg_idx])
+        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)
+
+        # select a negative text for each image
+        text_ids_neg = []
+        text_atts_neg = []
+        for b in range(bs):
+            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
+            text_ids_neg.append(text_input_ids_world[neg_idx])
+            text_atts_neg.append(text_attention_mask_world[neg_idx])
+
+        text_ids_neg = torch.stack(text_ids_neg, dim=0)
+        text_atts_neg = torch.stack(text_atts_neg, dim=0)
+
+        text_ids_all = torch.cat(
+            [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
+        )  # pos, pos, neg
+        text_atts_all = torch.cat(
+            [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
+            dim=0,
+        )
+
+        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
+        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
+            image.device
+        )
+        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)
+
+        image_embeds_all = torch.cat(
+            [image_embeds, image_embeds_neg, image_embeds], dim=0
+        )  # pos, neg, pos
+        image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
+            image.device
+        )
+
+        output_itm = self.Qformer.bert(
+            text_ids_all,
+            query_embeds=query_tokens_itm,
+            attention_mask=attention_mask_all,
+            encoder_hidden_states=image_embeds_all,
+            encoder_attention_mask=image_atts_all,
+            return_dict=True,
+        )
+
+        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
+        vl_output = self.itm_head(vl_embeddings)
+        logits = vl_output.mean(dim=1)
+
+        itm_labels = torch.cat(
+            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
+            dim=0,
+        ).to(image.device)
+        loss_itm = F.cross_entropy(logits, itm_labels)
+
+        ##================= Image Captioning ========================##
+        decoder_input_ids = text_tokens.input_ids.clone()
+        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
+        labels = decoder_input_ids.masked_fill(
+            decoder_input_ids == self.tokenizer.pad_token_id, -100
+        )
+
+        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
+            image.device
+        )
+        attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
+        lm_output = self.Qformer(
+            decoder_input_ids,
+            attention_mask=attention_mask,
+            past_key_values=query_output.past_key_values,
+            return_dict=True,
+            labels=labels,
+        )
+
+        loss_lm = lm_output.loss
+
+        return BlipOutput(
+            loss=loss_itc + loss_itm + loss_lm,
+            loss_itc=loss_itc,
+            loss_itm=loss_itm,
+            loss_lm=loss_lm,
+        )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        samples,
+        use_nucleus_sampling=False,
+        num_beams=3,
+        max_length=30,
+        min_length=10,
+        top_p=0.9,
+        repetition_penalty=1.0,
+    ):
+        """
+        Args:
+            samples (dict): A dictionary containing the following keys:
+                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
+            use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling.
+            num_beams (int): Number of beams for beam search. 1 means no beam search.
+            max_length (int): The maximum length of the sequence to be generated.
+            min_length (int): The minimum length of the sequence to be generated.
+            top_p (float): The cumulative probability for nucleus sampling.
+            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
+            num_captions (int): Number of captions to be generated for each image.
+        Returns:
+            captions (list): A list of strings of length batch_size * num_captions.
+        """
+        image = samples["image"]
+        image_embeds = self.ln_vision(self.visual_encoder(image))
+
+        if not use_nucleus_sampling:
+            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
+        else:
+            num_beams = 1
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
+            image.device
+        )
+
+        model_kwargs = {
+            "encoder_hidden_states": image_embeds,
+            "encoder_attention_mask": image_atts,
+        }
+
+        input_ids = (
+            torch.LongTensor(image.size(0), 1)
+            .fill_(self.tokenizer.bos_token_id)
+            .to(image.device)
+        )
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+
+        outputs = self.Qformer.generate(
+            input_ids=input_ids,
+            query_embeds=query_tokens,
+            max_length=max_length,
+            min_length=min_length,
+            num_beams=num_beams,
+            do_sample=use_nucleus_sampling,
+            top_p=top_p,
+            eos_token_id=self.tokenizer.sep_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            **model_kwargs
+        )
+        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return captions
+
+    def forward_visual_encoder(self, image):
+        with torch.no_grad():
+            with self.maybe_autocast():
+                image_embeds_frozen = self.visual_encoder(image, output_hidden_states=True)
+        image_embeds_frozen = [ln(image_embeds_frozen[lvl]) for lvl, ln in zip(self.multilevels, self.ln_vision)]
+        image_embeds_frozen = [image_embed.float() for image_embed in image_embeds_frozen]
+        image_atts = [torch.ones(
+            image_embed.size()[:-1], dtype=torch.long
+        ).to(self.device) for image_embed in image_embeds_frozen]
+        return image_embeds_frozen, image_atts
+
+    def forward_qformer(self, caption, image_embeds_frozen, image_atts, output_hidden_states=False):
+        query_tokens = self.query_tokens.expand(
+            image_embeds_frozen.shape[0], -1, -1
+        )
+        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
+            self.device
+        )
+        text = self.tokenizer(caption, return_tensors="pt", padding=True, truncation=True).to(
+            self.device
+        )
+        attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
+        query_pos_embeds = self.query_tokens.repeat(image_embeds_frozen.shape[0], 1, 1)
+
+        output = self.Qformer.bert(
+            text.input_ids,
+            query_embeds=query_tokens,
+            attention_mask=attention_mask,
+            encoder_hidden_states=image_embeds_frozen,
+            encoder_attention_mask=image_atts,
+            query_pos_embeds=query_pos_embeds,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+
+        hidden_states = [feat[:, : query_tokens.size(1), :] for feat in output.hidden_states]
+
+        return hidden_states
+
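+    # NOTE: this second definition shadows the `forward_qformer` above; only
+    # this variant (list-valued multi-level image embeddings, empty captions)
+    # is bound to the class at runtime.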
+    def forward_qformer(self, caption, image_embeds_frozen, image_atts):
+        bs = image_embeds_frozen[0].shape[0]
+
+        query_tokens = self.query_tokens.expand(bs, -1, -1)
+        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(self.device)
+        text = self.tokenizer(['']*len(caption), return_tensors="pt", padding=True, truncation=True, max_length=512).to(
+            self.device
+        )
+
+        attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
+        query_pos_embeds = self.query_tokens.repeat(bs, 1, 1)
+
+        output = self.Qformer.bert(
+            text.input_ids,
+            query_embeds=query_tokens,
+            attention_mask=attention_mask,
+            encoder_hidden_states=image_embeds_frozen,
+            encoder_attention_mask=image_atts,
+            query_pos_embeds=query_pos_embeds,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+
+        hidden_states = [feat[:, : query_tokens.size(1), :] for feat in output.hidden_states]
+        return hidden_states
+
+    def forward_image(self, image):
+        image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
+            image.device
+        )
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+
+        query_output = self.Qformer.bert(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+        return query_output.last_hidden_state, image_embeds
+
+    def forward_text(self, text_tokens):
+        text_output = self.Qformer.bert(
+            text_tokens.input_ids,
+            attention_mask=text_tokens.attention_mask,
+            return_dict=True,
+        )
+        return text_output.last_hidden_state[:, 0, :]
+
+    def compute_itm(self, image_inputs, text_ids, text_atts):
+        image_atts = torch.ones(image_inputs.size()[:-1], dtype=torch.long).to(
+            image_inputs.device
+        )
+        query_tokens = self.query_tokens.expand(image_inputs.shape[0], -1, -1)
+        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
+            image_inputs.device
+        )
+        attention_mask = torch.cat([query_atts, text_atts], dim=1)
+        output_itm = self.Qformer.bert(
+            text_ids,
+            query_embeds=query_tokens,
+            attention_mask=attention_mask,
+            encoder_hidden_states=image_inputs,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :]
+        itm_logit = self.itm_head(vl_embeddings)
+        itm_logit = itm_logit[:, :, 1].mean(dim=1)
+        return itm_logit
+
+    @torch.no_grad()
+    def extract_features(self, samples, mode="multimodal"):
+        """
+        Extract features for multimodal or unimodal samples.
+        Args:
+            samples (dict): A dictionary of samples, containing the following keys:
+                - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image.
+                    Raw images should be preprocessed before being passed to feature extractor.
+                - text_input (list): A list of strings containing the text, length B.
+            mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image".
+                If "multimodal", return image features and multimodal features;
+                if "text", return text features;
+                if "image", return image features.
+                Default: "multimodal".
+        Returns:
+            BlipOutputFeatures: A BlipOutputFeatures object containing the features.
+                See lavis/models/blip_models/blip_outputs.py for more details.
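+
+        Example (illustrative sketch; assumes ``model`` is an instance of this class
+        and ``image`` is an already-preprocessed image tensor)::
+
+            samples = {"image": image, "text_input": ["a photo of a cat"]}
+            features = model.extract_features(samples, mode="multimodal")
+            # features.multimodal_embeds: (B, num_query_token, hidden_size)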
+        """
+        image = samples.get("image")
+        caption = samples.get("text_input")
+
+        # assert mode is one of "image", "text", "multimodal"
+        assert mode in [
+            "image",
+            "text",
+            "multimodal",
+        ], "mode must be one of 'image', 'text', 'multimodal'"
+
+        # initialize outputs
+        image_embeds, text_embeds, multimodal_embeds = None, None, None
+        image_features, text_features = None, None
+
+        if mode == "image":
+            assert (
+                image is not None
+            ), "Image is not provided for mode 'image' or 'multimodal'"
+            # return query features
+            with self.maybe_autocast():
+                image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
+            image_embeds_frozen = image_embeds_frozen.float()
+            image_atts = torch.ones(
+                image_embeds_frozen.size()[:-1], dtype=torch.long
+            ).to(self.device)
+            query_tokens = self.query_tokens.expand(
+                image_embeds_frozen.shape[0], -1, -1
+            )
+
+            query_output = self.Qformer.bert(
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds_frozen,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+            image_embeds = query_output.last_hidden_state
+            image_features = F.normalize(self.vision_proj(image_embeds), dim=-1)
+
+        elif mode == "text":
+            assert (
+                caption is not None
+            ), "text input is None for mode 'text' or 'multimodal'"
+
+            # return text features
+            text = self.tokenizer(caption, return_tensors="pt", padding=True).to(
+                self.device
+            )
+
+            text_output = self.Qformer.bert(
+                text.input_ids,
+                attention_mask=text.attention_mask,
+                return_dict=True,
+            )
+
+            text_embeds = text_output.last_hidden_state
+            text_features = self.text_proj(text_embeds)
+            text_features = F.normalize(text_features, dim=-1)
+
+        elif mode == "multimodal":
+            # return multimodal query features
+            with self.maybe_autocast():
+                image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
+            image_embeds_frozen = image_embeds_frozen.float()
+            image_atts = torch.ones(
+                image_embeds_frozen.size()[:-1], dtype=torch.long
+            ).to(self.device)
+            query_tokens = self.query_tokens.expand(
+                image_embeds_frozen.shape[0], -1, -1
+            )
+            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
+                self.device
+            )
+
+            text = self.tokenizer(caption, return_tensors="pt", padding=True).to(
+                self.device
+            )
+            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
+
+            output = self.Qformer.bert(
+                text.input_ids,
+                query_embeds=query_tokens,
+                attention_mask=attention_mask,
+                encoder_hidden_states=image_embeds_frozen,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+            multimodal_embeds = output.last_hidden_state[:, : query_tokens.size(1), :]
+
+        return BlipOutputFeatures(
+            image_embeds=image_embeds,
+            image_embeds_proj=image_features,
+            text_embeds=text_embeds,
+            text_embeds_proj=text_features,
+            multimodal_embeds=multimodal_embeds,
+        )
+
+    @classmethod
+    def from_config(cls, cfg):
+        vit_model = cfg.get("vit_model", "eva_clip_g")
+        img_size = cfg.get("image_size")
+        num_query_token = cfg.get("num_query_token")
+        cross_attention_freq = cfg.get("cross_attention_freq", 2)
+
+        drop_path_rate = cfg.get("drop_path_rate", 0)
+        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+        vit_precision = cfg.get("vit_precision", "fp16")
+        freeze_vit = cfg.get("freeze_vit", True)
+
+        max_txt_len = cfg.get("max_txt_len", 32)
+
+        model = cls(
+            vit_model=vit_model,
+            img_size=img_size,
+            drop_path_rate=drop_path_rate,
+            use_grad_checkpoint=use_grad_checkpoint,
+            vit_precision=vit_precision,
+            freeze_vit=freeze_vit,
+            num_query_token=num_query_token,
+            cross_attention_freq=cross_attention_freq,
+            max_txt_len=max_txt_len,
+        )
+        model.load_checkpoint_from_config(cfg)
+
+        return model
+
+    def compute_sim_matrix(self, data_loader, task_cfg):
+        """
+        Compute similarity i2t, t2i matrix for the given data loader.
+        """
+        k_test = task_cfg.k_test
+
+        return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test)
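+
+
+# Illustrative config sketch (comment only). The keys mirror what `from_config`
+# reads above; the values are examples rather than a config shipped with this
+# diff, and `load_checkpoint_from_config` may additionally expect a checkpoint
+# entry in the same config:
+#
+#     cfg = OmegaConf.create({
+#         "vit_model": "eva_clip_g", "image_size": 224, "num_query_token": 32,
+#         "cross_attention_freq": 2, "vit_precision": "fp16", "freeze_vit": True,
+#         "max_txt_len": 32,
+#     })
+#     model = <ModelClass>.from_config(cfg)  # substitute the class defined in this file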
diff --git a/models/q_formers/clip_vit.py b/models/q_formers/clip_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..90113a2d6f934963a173dea7dfa7ed3281329700
--- /dev/null
+++ b/models/q_formers/clip_vit.py
@@ -0,0 +1,271 @@
+from collections import OrderedDict
+from itertools import repeat
+import collections.abc
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+from lavis.models.eva_vit import convert_weights_to_fp16
+from lavis.common.dist_utils import download_cached_file
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu3 = nn.ReLU(inplace=True)
+
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(OrderedDict([
+                ("-1", nn.AvgPool2d(stride)),
+                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                ("1", nn.BatchNorm2d(planes * self.expansion))
+            ]))
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu1(self.bn1(self.conv1(x)))
+        out = self.relu2(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu3(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x, key=x, value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False
+        )
+
+        return x[0]
+    
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+        if use_grad_checkpointing:
+            self.attn = checkpoint_wrapper(self.attn)
+            self.mlp = checkpoint_wrapper(self.mlp)
+            
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)])
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.num_features = width
+        self.num_heads = heads
+        self.num_patches = (input_resolution // patch_size) ** 2
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
+        self.ln_pre = LayerNorm(width)
+        
+        self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
+           
+#         self.ln_final = LayerNorm(width)
+
+    def forward(self, x: torch.Tensor, output_hidden_states=False):
+
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        hiddens = []
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        # x = self.transformer(x)
+        for block in self.transformer.resblocks:
+            x = block(x)
+            hiddens.append(x.permute(1, 0, 2))
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+#         x = self.ln_final(x)
+
+        if output_hidden_states:
+            return hiddens
+        else:
+            return x
+    
+    def get_num_layer(self, var_name=""):
+        if var_name in ("class_embedding", "positional_embedding", "conv1", "ln_pre"):
+            return 0
+        elif var_name.startswith("transformer.resblocks"):
+            layer_id = int(var_name.split('.')[2])
+            return layer_id + 1
+        else:
+            return len(self.transformer.resblocks)    
+            
+            
+# From PyTorch internals
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+    return parse
+to_2tuple = _ntuple(2)    
+    
+def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1):
+    # Rescale the grid of position embeddings when loading from state_dict
+    old_pos_embed = state_dict.get('positional_embedding', None)
+    if old_pos_embed is None:
+        return
+
+    grid_size = to_2tuple(round((model.positional_embedding.shape[0] - 1) ** 0.5))
+    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
+    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+    if new_seq_len == old_pos_embed.shape[0]:
+        return
+
+    if extra_tokens:
+        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+    else:
+        pos_emb_tok, pos_emb_img = None, old_pos_embed
+        
+    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+    print('Resizing position embedding grid-size from %s to %s' % (old_grid_size, grid_size))
+    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+    pos_emb_img = F.interpolate(
+        pos_emb_img,
+        size=grid_size,
+        mode=interpolation,
+        align_corners=True,
+    )
+    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+    if pos_emb_tok is not None:
+        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+    else:
+        new_pos_embed = pos_emb_img
+    state_dict['positional_embedding'] = new_pos_embed
+    
+    
+def create_clip_vit_L(img_size=224, use_checkpoint=False, precision="fp16"):
+    model = VisionTransformer(
+            input_resolution=img_size,
+            patch_size=14,
+            width=1024,
+            layers=23,
+            heads=16,
+            use_grad_checkpointing=use_checkpoint,
+        )         
+    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth"
+    cached_file = download_cached_file(
+        url, check_hash=False, progress=True
+    )
+    state_dict = torch.load(cached_file, map_location="cpu")    
+    interpolate_pos_embed(model,state_dict)
+    
+    incompatible_keys = model.load_state_dict(state_dict, strict=False)
+    # print(incompatible_keys)
+    
+    if precision == "fp16":
+        convert_weights_to_fp16(model)
+    return model
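+
+
+# Illustrative usage sketch (kept as a comment so nothing runs on import); it only
+# assumes the `create_clip_vit_L` factory above and downloads the referenced
+# checkpoint on first call:
+#
+#     vit = create_clip_vit_L(img_size=224, use_checkpoint=False, precision="fp32").eval()
+#     x = torch.randn(1, 3, 224, 224)
+#     with torch.no_grad():
+#         feats = vit(x)                               # (1, 257, 1024)
+#         hiddens = vit(x, output_hidden_states=True)  # list of 23 per-block states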
diff --git a/models/q_formers/eva_vit.py b/models/q_formers/eva_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b25092a8a7ede4d2877929201c7382fb84d1691a
--- /dev/null
+++ b/models/q_formers/eva_vit.py
@@ -0,0 +1,461 @@
+# Based on EVA, BEIT, timm and DeiT code bases
+# https://github.com/baaivision/EVA
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/microsoft/unilm/tree/master/beit
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+
+from lavis.common.dist_utils import download_cached_file
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+    
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        # x = self.drop(x)
+        # commented out to follow the original BERT implementation
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+            proj_drop=0., window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        all_head_dim = head_dim * self.num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        if qkv_bias:
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+        else:
+            self.q_bias = None
+            self.v_bias = None
+
+        if window_size:
+            self.window_size = window_size
+            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+            self.relative_position_bias_table = nn.Parameter(
+                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+            # extra entries for cls-to-token, token-to-cls, and cls-to-cls bias
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(window_size[0])
+            coords_w = torch.arange(window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = self.num_relative_distance - 3
+            relative_position_index[0:, 0] = self.num_relative_distance - 2
+            relative_position_index[0, 0] = self.num_relative_distance - 1
+
+            self.register_buffer("relative_position_index", relative_position_index)
+        else:
+            self.window_size = None
+            self.relative_position_bias_table = None
+            self.relative_position_index = None
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, rel_pos_bias=None):
+        B, N, C = x.shape
+        qkv_bias = None
+        if self.q_bias is not None:
+            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        if self.relative_position_bias_table is not None:
+            relative_position_bias = \
+                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                    self.window_size[0] * self.window_size[1] + 1,
+                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+            attn = attn + relative_position_bias.unsqueeze(0)
+
+        if rel_pos_bias is not None:
+            attn = attn + rel_pos_bias
+        
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if init_values is not None and init_values > 0:
+            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
+            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x, rel_pos_bias=None):
+        if self.gamma_1 is None:
+            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x, **kwargs):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+
+class RelativePositionBias(nn.Module):
+
+    def __init__(self, window_size, num_heads):
+        super().__init__()
+        self.window_size = window_size
+        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+        # extra entries for cls-to-token, token-to-cls, and cls-to-cls bias
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(window_size[0])
+        coords_w = torch.arange(window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+        relative_position_index = \
+            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = self.num_relative_distance - 3
+        relative_position_index[0:, 0] = self.num_relative_distance - 2
+        relative_position_index[0, 0] = self.num_relative_distance - 1
+
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+    def forward(self):
+        relative_position_bias = \
+            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                self.window_size[0] * self.window_size[1] + 1,
+                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
+                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
+                 use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
+        super().__init__()
+        self.image_size = img_size
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+        self.patch_embed = PatchEmbed(
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        if use_abs_pos_emb:
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        else:
+            self.pos_embed = None
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        if use_shared_rel_pos_bias:
+            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+        else:
+            self.rel_pos_bias = None
+        self.use_checkpoint = use_checkpoint
+        
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.use_rel_pos_bias = use_rel_pos_bias
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+            for i in range(depth)])
+#         self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+#         self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+#         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        # trunc_normal_(self.mask_token, std=.02)
+#         if isinstance(self.head, nn.Linear):
+#             trunc_normal_(self.head.weight, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+#         if isinstance(self.head, nn.Linear):
+#             self.head.weight.data.mul_(init_scale)
+#             self.head.bias.data.mul_(init_scale)
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x, output_hidden_states=False):
+        x = self.patch_embed(x)
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        hiddens = []
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
+            else:
+                x = blk(x, rel_pos_bias)
+            hiddens.append(x)
+
+        if output_hidden_states:
+            return hiddens
+        else:
+            return x
+#         x = self.norm(x)
+
+#         if self.fc_norm is not None:
+#             t = x[:, 1:, :]
+#             return self.fc_norm(t.mean(1))
+#         else:
+#             return x[:, 0]
+
+    def forward(self, x, output_hidden_states=False):
+        x = self.forward_features(x, output_hidden_states)
+#         x = self.head(x)
+        return x
+
+    def get_intermediate_layers(self, x):
+        x = self.patch_embed(x)
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        features = []
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        for blk in self.blocks:
+            x = blk(x, rel_pos_bias)
+            features.append(x)
+
+        return features
+    
+    def get_num_layer(self, var_name=""):
+        if var_name in ("cls_token", "mask_token", "pos_embed"):
+            return 0
+        elif var_name.startswith("patch_embed"):
+            return 0
+        elif var_name.startswith("rel_pos_bias"):
+            return len(self.blocks) - 1
+        elif var_name.startswith("blocks"):
+            layer_id = int(var_name.split('.')[1])
+            return layer_id + 1
+        else:
+            return len(self.blocks)
+        
+            
+def interpolate_pos_embed(model, checkpoint_model):
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
+            
+            
+def convert_weights_to_fp16(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.half()
+            if l.bias is not None:
+                l.bias.data = l.bias.data.half()
+
+#         if isinstance(l, (nn.MultiheadAttention, Attention)):
+#             for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+#                 tensor = getattr(l, attr)
+#                 if tensor is not None:
+#                     tensor.data = tensor.data.half()
+
+    model.apply(_convert_weights_to_fp16)
+    
+    
+def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
+    model = VisionTransformer(
+        img_size=img_size,
+        patch_size=14,
+        use_mean_pooling=False,
+        embed_dim=1408,
+        depth=39,
+        num_heads=1408//88,
+        mlp_ratio=4.3637,
+        qkv_bias=True,
+        drop_path_rate=drop_path_rate,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        use_checkpoint=use_checkpoint,
+    )  
+    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
+    cached_file = download_cached_file(
+        url, check_hash=False, progress=True
+    )
+    state_dict = torch.load(cached_file, map_location="cpu")    
+    interpolate_pos_embed(model,state_dict)
+    
+    incompatible_keys = model.load_state_dict(state_dict, strict=False)
+#     print(incompatible_keys)
+    
+    if precision == "fp16":
+#         model.to("cuda") 
+        convert_weights_to_fp16(model)
+    return model
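+
+
+# Illustrative usage sketch (kept as a comment so nothing runs on import); it only
+# assumes the `create_eva_vit_g` factory above and downloads the referenced
+# checkpoint on first call:
+#
+#     vit = create_eva_vit_g(img_size=224, precision="fp32").eval()
+#     x = torch.randn(1, 3, 224, 224)
+#     with torch.no_grad():
+#         feats = vit(x)    # (1, 257, 1408): cls token + (224 // 14) ** 2 patches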
diff --git a/models/q_formers/position_encoding.py b/models/q_formers/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..88035899947c158dcd995e986d2a23d7c4989913
--- /dev/null
+++ b/models/q_formers/position_encoding.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+
+
+class PositionEmbeddings(nn.Module):
+    def __init__(self, max_position_embeddings, hidden_size, eps=1e-12, dropout=0.1, inplace=True):
+        super().__init__()
+        self.position_embeddings = nn.Embedding(
+            max_position_embeddings, hidden_size
+        )
+
+        self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps)
+        self.dropout = nn.Dropout(dropout, inplace=inplace)
+
+        self.register_buffer(
+            "position_ids", torch.arange(max_position_embeddings).expand((1, -1))
+        )
+
+    def forward(self, embeddings, position_ids=None, offset=0):
+        seq_length = embeddings.size()[1]
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, offset:offset+seq_length].clone()
+
+        position_embeddings = self.position_embeddings(position_ids)
+        embeddings = embeddings + position_embeddings
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
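+# Illustrative usage sketch (comment only): learned absolute positions added to a
+# batch of query embeddings, assuming `hidden_size` matches the input feature dim:
+#
+#     pos = PositionEmbeddings(max_position_embeddings=512, hidden_size=768)
+#     queries = torch.zeros(2, 32, 768)
+#     out = pos(queries)    # (2, 32, 768) after LayerNorm and dropout
+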
+
+class PositionScore(nn.Module):
+    def __init__(self, seq_len, shape=None, score_type="gaussian"):
+        super().__init__()  # required so the nn.Module machinery is initialized
+        assert seq_len is not None or shape is not None, "seq_len or shape must be provided"
+        self.score_type = score_type
+        self.cls_token = False
+        if seq_len is not None:
+            h = w = int(seq_len ** 0.5)
+        elif isinstance(shape, int):
+            h = w = shape
+        else:
+            h, w = shape
+        self.h = h
+        self.w = w
+
+    def forward(self, tensor):
+        bs, chn, m, n = tensor.shape
diff --git a/models/util.py b/models/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6d05db8c7d8d9c2db50b0872c1a33650f208397
--- /dev/null
+++ b/models/util.py
@@ -0,0 +1,45 @@
+import os
+import torch
+
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+
+
+def get_state_dict(d):
+    return d.get('state_dict', d)
+
+
+def load_state_dict(ckpt_path, location='cpu'):
+    _, extension = os.path.splitext(ckpt_path)
+    if extension.lower() == ".safetensors":
+        import safetensors.torch
+        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+    else:
+        state_dict = torch.load(ckpt_path, map_location=torch.device(location))
+    state_dict = get_state_dict(state_dict)
+    print(f'Loaded state_dict from [{ckpt_path}]')
+    return state_dict
+
+
+def load_ckpt(model, state_dict, strict=True, input_channel_copy_indices=None):
+    input_channel_key = "local_adapter.feature_extractor.pre_extractor.0.weight"
+    model_state_dict = model.state_dict()
+    if input_channel_key in model_state_dict and input_channel_key in state_dict:
+        model_shape = model_state_dict[input_channel_key].shape
+        shape = state_dict[input_channel_key].shape
+        if model_shape != shape:
+            if input_channel_copy_indices is None:
+                state_dict[input_channel_key] = state_dict[input_channel_key][:, :model_shape[1], :, :]
+            else:
+                cout, cin, h, w = model_shape
+                weight = state_dict[input_channel_key].view(cout, -1, cin//3, h, w)
+                weight = weight[:, input_channel_copy_indices].view(cout, cin, h, w)
+                state_dict[input_channel_key] = weight
+    model.load_state_dict(state_dict, strict=strict)
+
+
+def create_model(config_path):
+    config = OmegaConf.load(config_path)
+    model = instantiate_from_config(config.model).cpu()
+    print(f'Loaded model config from [{config_path}]')
+    return model
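+
+
+# Illustrative end-to-end sketch (comment only); the config and checkpoint paths
+# below are placeholders, not files shipped with this diff:
+#
+#     model = create_model('configs/model.yaml')
+#     state_dict = load_state_dict('checkpoints/model.ckpt', location='cpu')
+#     load_ckpt(model, state_dict, strict=False)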
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..401d1f605d594558e947b6b9b50bc75f5b6e5514
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+numpy==1.23.5
+pillow==9.5.0
+scipy==1.13.0
+scikit-image==0.21.0
+scikit-learn==1.3.1
+pycocotools==2.0.7
+nltk==3.8.1
+torch==2.0.1
+torchvision==0.15.2
+mmagic
+salesforce-lavis
+einops==0.4.1
+pytorch-lightning==1.9.4
+accelerate==0.21.0
+diffusers==0.22.3
+mmcv==2.0.0
+gradio==4.37.1
+spacy
+opencv-python==4.9.0.80
+transformers==4.30.2
+basicsr
+clip
+open_clip_torch
diff --git a/utils/config.py b/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e5bfdca6d861a87d5992afff04990f8dfd49598
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1 @@
+save_memory = True
diff --git a/utils/share.py b/utils/share.py
new file mode 100644
index 0000000000000000000000000000000000000000..c47264fdf34dad0db101380e9c8259b6491c70a6
--- /dev/null
+++ b/utils/share.py
@@ -0,0 +1,8 @@
+import utils.config as config
+from models.hack import disable_verbosity, enable_sliced_attention
+
+
+disable_verbosity()
+
+if config.save_memory:
+    enable_sliced_attention()