diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000000000000000000000000000000000000..0110818d6b1e70612d770dd55726953ec8002c6f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+img/
+logfile/
+__pycache__/
+*/__pycache__/
+models/
+plt/
+docs/
+exp/
+examples/mask/
+examples/mask_box/
\ No newline at end of file
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000000000000000000000000000000000000..224b59528850f7106228202b8d5901771b41c614
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,17 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Demo",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "/home/jyr/demo/DesignEdit/app.py",
+            "console": "integratedTerminal",
+            "python": "/home/jyr/.conda/envs/new_design/bin/python",
+
+        }
+    ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 9504551be6c0d250f32f617faec2350c5c855aa1..ddb40e73375e68172b0c30cf147112ae756624ec 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@
 ---
 title: DesignEdit
-emoji: 🏆
-colorFrom: pink
-colorTo: yellow
+emoji: 🌿
+colorFrom: yellow
+colorTo: green
 sdk: gradio
-sdk_version: 4.25.0
+sdk_version: 4.24.0
 app_file: app.py
 pinned: false
 ---
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..891e638d4689ebee66ef358de780a849ea5121fa
--- /dev/null
+++ b/app.py
@@ -0,0 +1,61 @@
+import gradio as gr
+import spaces
+import torch
+
+import os
+import subprocess
+import shlex
+from src.demo.model import DesignEdit
+
+os.makedirs('models', exist_ok=True)
+subprocess.run(shlex.split('wget https://huggingface.co/Adapter/DragonDiffusion/resolve/main/model/efficient_sam_vits.pt -O models/efficient_sam_vits.pt'))
+
+from src.demo.demo import *
+import shlex
+import cv2
+
+pretrained_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+model =  DesignEdit(pretrained_model_path=pretrained_model_path)
+DESCRIPTION_1 = """<div style="text-align: center; font-size: 80px;">
+        <strong class="title is-1">
+            <span style="color: green;">🌿D</span>
+            <span style="color: orange;">e</span>
+            <span style="color: rgb(63, 185, 63);">s</span>
+            <span style="color: green;">i</span>
+            <span style="color: rgb(200, 85, 23);">g</span>
+            <span style="color: green;">n</span>
+            <span style="color: orange;">E</span>
+            <span style="color: crimson;">d</span>
+            <span style="color: darkorange;">i</span>
+            <span style="color: green;">t🌿</span>
+          </strong> 
+    </div>
+    """
+DESCRIPTION_2 = """ <div style="text-align: center;font-size: 24px;"> <h1> Multi-Layered Latent Decomposition and Fusion for Unified & Accurate Image Editing</h1></div>"""
+DESCRIPTION_3 = """
+<div style="text-align: center; font-size: 24px;">
+    <p> Gradio demo for <a href="https://design-edit.github.io/">DesignEdit</a></p>
+</div>
+"""
+
+
+with gr.Blocks(css='style.css') as demo:
+    gr.HTML(DESCRIPTION_1)
+    gr.HTML(DESCRIPTION_2)
+    gr.HTML(DESCRIPTION_3)
+    with gr.Tabs():
+        with gr.TabItem('1️⃣ Object Removal'):
+            create_demo_remove(model.run_remove)
+        with gr.TabItem('2️⃣ Zooming Out'):
+            create_demo_zooming(model.run_zooming)
+        with gr.TabItem('3️⃣ Camera Panning'):
+            create_demo_panning(model.run_panning)
+        with gr.TabItem('4️⃣ Object Moving, Resizing and Flipping'):
+            create_demo_moving(model.run_moving)
+        with gr.TabItem('5️⃣ 🚩 Multi-Layered Editing 🚩'):
+            create_demo_layer(model.run_layer)
+        with gr.TabItem('🔧 Mask Preparation: Draw or Sketch'):
+            create_demo_mask_box(model.run_mask)
+demo.queue(max_size=20)
+demo.launch(max_threads=3, server_name="0.0.0.0")
+
diff --git a/examples/layer/01_horse/00.jpg b/examples/layer/01_horse/00.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..502be3f0bbada559c91718606b6b03c86021572b
Binary files /dev/null and b/examples/layer/01_horse/00.jpg differ
diff --git a/examples/layer/01_horse/mask0.jpg b/examples/layer/01_horse/mask0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..39377e481c21d9980df45c94c233f562809efbfa
Binary files /dev/null and b/examples/layer/01_horse/mask0.jpg differ
diff --git a/examples/layer/02_baby/00.jpg b/examples/layer/02_baby/00.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..f4fbb272ec041a9fd0793f10cc6830c42f740047
Binary files /dev/null and b/examples/layer/02_baby/00.jpg differ
diff --git a/examples/layer/02_baby/mask0.jpg b/examples/layer/02_baby/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f94b32a63688d6b2852e7f743dc4dffb742bdd53
Binary files /dev/null and b/examples/layer/02_baby/mask0.jpg differ
diff --git a/examples/layer/02_baby/mask1.jpg b/examples/layer/02_baby/mask1.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..bbed5f61ca38f384ab3f2f87c206efd74174e1da
Binary files /dev/null and b/examples/layer/02_baby/mask1.jpg differ
diff --git a/examples/layer/02_baby/mask2.jpg b/examples/layer/02_baby/mask2.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..c77e82d7eeea470bd3a787d018815eda5ee7f588
Binary files /dev/null and b/examples/layer/02_baby/mask2.jpg differ
diff --git a/examples/layer/03_text/00.jpg b/examples/layer/03_text/00.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2ae11e62b15bbabcba72e46fd373ad0c718a8fb3
Binary files /dev/null and b/examples/layer/03_text/00.jpg differ
diff --git a/examples/layer/03_text/01.jpg b/examples/layer/03_text/01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..654004437236e476431cb81d38434dec1d65e1d1
Binary files /dev/null and b/examples/layer/03_text/01.jpg differ
diff --git a/examples/layer/03_text/mask0.jpg b/examples/layer/03_text/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..bf0ce7ffc1799bd21967250f02ecd5f5159ae384
Binary files /dev/null and b/examples/layer/03_text/mask0.jpg differ
diff --git a/examples/layer/03_text/mask1.jpg b/examples/layer/03_text/mask1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a1d805e541613f2c53f767b0e775ab053707cb34
Binary files /dev/null and b/examples/layer/03_text/mask1.jpg differ
diff --git a/examples/layer/04_cross/0.jpg b/examples/layer/04_cross/0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4edde7252b283d38f8e123da726c00e8ecf418fb
Binary files /dev/null and b/examples/layer/04_cross/0.jpg differ
diff --git a/examples/layer/04_cross/1.jpg b/examples/layer/04_cross/1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a71e2c9477b3ae1911be802d6020df06c55554da
Binary files /dev/null and b/examples/layer/04_cross/1.jpg differ
diff --git a/examples/layer/04_cross/2.jpg b/examples/layer/04_cross/2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f40b0440819237e7e18eb8e8479cd370da837e2c
Binary files /dev/null and b/examples/layer/04_cross/2.jpg differ
diff --git a/examples/layer/04_cross/3.jpg b/examples/layer/04_cross/3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a2ded809ac249e22dd2808dd99ca386c56632adc
Binary files /dev/null and b/examples/layer/04_cross/3.jpg differ
diff --git a/examples/layer/04_cross/mask0.jpg b/examples/layer/04_cross/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e07c313622a380fe141c9b799ab362ab54703345
Binary files /dev/null and b/examples/layer/04_cross/mask0.jpg differ
diff --git a/examples/layer/04_cross/mask1.jpg b/examples/layer/04_cross/mask1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3c90e633a455b762be1fc39a0ece2c62ea55bb5f
Binary files /dev/null and b/examples/layer/04_cross/mask1.jpg differ
diff --git a/examples/layer/04_cross/mask2.jpg b/examples/layer/04_cross/mask2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1d09eecc9bf6bf60321def094bc3ecb726963fee
Binary files /dev/null and b/examples/layer/04_cross/mask2.jpg differ
diff --git a/examples/layer/04_cross/mask3.jpg b/examples/layer/04_cross/mask3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..19b58983cf1463edfd7263fa40296ba723f4d03a
Binary files /dev/null and b/examples/layer/04_cross/mask3.jpg differ
diff --git a/examples/moving/01_ball/0.jpg b/examples/moving/01_ball/0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f11738232d02b9ffcb60fe4878764a1081d7fb7a
Binary files /dev/null and b/examples/moving/01_ball/0.jpg differ
diff --git a/examples/moving/01_ball/mask0.jpg b/examples/moving/01_ball/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..49e57f3cdb98cd1ee80e70318bcb081fb0fbd930
Binary files /dev/null and b/examples/moving/01_ball/mask0.jpg differ
diff --git a/examples/moving/02_bell/0.jpg b/examples/moving/02_bell/0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1a3e1fba25f66f21fe42265f867fa8daaa3f9032
Binary files /dev/null and b/examples/moving/02_bell/0.jpg differ
diff --git a/examples/moving/02_bell/mask0.jpg b/examples/moving/02_bell/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e4f89e4293acea880cd8a04bb0f7e2d69f7cbdb2
Binary files /dev/null and b/examples/moving/02_bell/mask0.jpg differ
diff --git a/examples/pan/01.jpg b/examples/pan/01.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..502be3f0bbada559c91718606b6b03c86021572b
Binary files /dev/null and b/examples/pan/01.jpg differ
diff --git a/examples/pan/02.jpg b/examples/pan/02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dd1d5d4450d43f0e2812a4740f92b2eabd5d7d11
Binary files /dev/null and b/examples/pan/02.jpg differ
diff --git a/examples/pan/03.jpg b/examples/pan/03.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d3a84ffebd40b0748deb1b2f1f8734f88b2d1ebc
Binary files /dev/null and b/examples/pan/03.jpg differ
diff --git a/examples/pan/04.jpg b/examples/pan/04.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..63c305e6fd77128f5b1664d89bf70bc0fad21a40
Binary files /dev/null and b/examples/pan/04.jpg differ
diff --git a/examples/pan/05.jpg b/examples/pan/05.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5617fd79f1a039673e5de8e3ee865351f96ff1cf
Binary files /dev/null and b/examples/pan/05.jpg differ
diff --git a/examples/pan/06.jpg b/examples/pan/06.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d5e6670f5491df5388f4d5181c60afb36cdcdfed
Binary files /dev/null and b/examples/pan/06.jpg differ
diff --git a/examples/remove/01_moto/0.jpg b/examples/remove/01_moto/0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..2164e166b5fb3e420a0c892b891ba3e3e68c6712
Binary files /dev/null and b/examples/remove/01_moto/0.jpg differ
diff --git a/examples/remove/01_moto/mask0.jpg b/examples/remove/01_moto/mask0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..3c602058c95f0165eb3e4b4b84dc6205de06cf66
Binary files /dev/null and b/examples/remove/01_moto/mask0.jpg differ
diff --git a/examples/remove/01_moto/mask1.jpg b/examples/remove/01_moto/mask1.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..d773289fee31961900990eb745f7cbd8ac85735e
Binary files /dev/null and b/examples/remove/01_moto/mask1.jpg differ
diff --git a/examples/remove/02_ring/0.jpg b/examples/remove/02_ring/0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..897e98afa4ebe74083b154fee58f77332c072973
Binary files /dev/null and b/examples/remove/02_ring/0.jpg differ
diff --git a/examples/remove/02_ring/mask0.jpg b/examples/remove/02_ring/mask0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..437703be8b840ac9f7d84587f7871872d60b9b1c
Binary files /dev/null and b/examples/remove/02_ring/mask0.jpg differ
diff --git a/examples/remove/02_ring/mask1.jpg b/examples/remove/02_ring/mask1.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..61001eff2dd2cb528a110d63897e7d48945f2cf5
Binary files /dev/null and b/examples/remove/02_ring/mask1.jpg differ
diff --git a/examples/remove/02_ring/mask2.jpg b/examples/remove/02_ring/mask2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..faabcd04ce296a29eecbf23a9fbb5fcd93a82dd9
Binary files /dev/null and b/examples/remove/02_ring/mask2.jpg differ
diff --git a/examples/remove/03_ball/0.jpg b/examples/remove/03_ball/0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..a32933213dd03dedcd1483fa5bee7582c0c51236
Binary files /dev/null and b/examples/remove/03_ball/0.jpg differ
diff --git a/examples/remove/03_ball/mask0.jpg b/examples/remove/03_ball/mask0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..41d2fb7ffc5b8605d8f6622bb646ca3afb8093a1
Binary files /dev/null and b/examples/remove/03_ball/mask0.jpg differ
diff --git a/examples/remove/03_ball/mask1.jpg b/examples/remove/03_ball/mask1.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..dbe79573f0b773655b22ea70003eb38e747d5394
Binary files /dev/null and b/examples/remove/03_ball/mask1.jpg differ
diff --git a/examples/remove/04_pikachu/0.jpg b/examples/remove/04_pikachu/0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..1204812f145e29311d15fd20433e7649954f8af2
Binary files /dev/null and b/examples/remove/04_pikachu/0.jpg differ
diff --git a/examples/remove/04_pikachu/mask0.jpg b/examples/remove/04_pikachu/mask0.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..a197882e1b3f9c728096885ee6265d6875521905
Binary files /dev/null and b/examples/remove/04_pikachu/mask0.jpg differ
diff --git a/examples/remove/04_pikachu/mask1.jpg b/examples/remove/04_pikachu/mask1.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..80777884fa0fb4ee59deedc1b99477e2f95594fc
Binary files /dev/null and b/examples/remove/04_pikachu/mask1.jpg differ
diff --git a/examples/remove/04_pikachu/mask2.jpg b/examples/remove/04_pikachu/mask2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2eb6c21723b01a5c74b50563a096cda0e21b87c0
Binary files /dev/null and b/examples/remove/04_pikachu/mask2.jpg differ
diff --git a/examples/remove/05_betty/0.jpg b/examples/remove/05_betty/0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..39eceadc2330703743090990bdbff8aca97afb13
Binary files /dev/null and b/examples/remove/05_betty/0.jpg differ
diff --git a/examples/remove/05_betty/mask0.jpg b/examples/remove/05_betty/mask0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87c0784b1c402471436c89166f8114c36535c672
Binary files /dev/null and b/examples/remove/05_betty/mask0.jpg differ
diff --git a/examples/zoom/01.jpg b/examples/zoom/01.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..502be3f0bbada559c91718606b6b03c86021572b
Binary files /dev/null and b/examples/zoom/01.jpg differ
diff --git a/examples/zoom/02.jpg b/examples/zoom/02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f5730a436e7a08e47a7f2e29e28617271d9d1c44
Binary files /dev/null and b/examples/zoom/02.jpg differ
diff --git a/examples/zoom/03.jpg b/examples/zoom/03.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..f7a0cc7d115a9e8a59959c910d66774ab0d44af8
Binary files /dev/null and b/examples/zoom/03.jpg differ
diff --git a/examples/zoom/04.jpg b/examples/zoom/04.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..e4341114794f8be5fe80c9d442693630db1c8ff6
Binary files /dev/null and b/examples/zoom/04.jpg differ
diff --git a/examples/zoom/05.jpg b/examples/zoom/05.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f4a2f82d540c7f6731b8e4dff5dcdb74e18c4ab6
Binary files /dev/null and b/examples/zoom/05.jpg differ
diff --git a/examples/zoom/06.jpg b/examples/zoom/06.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..63c305e6fd77128f5b1664d89bf70bc0fad21a40
Binary files /dev/null and b/examples/zoom/06.jpg differ
diff --git a/examples/zoom/07.jpg b/examples/zoom/07.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..aaffb0d35dff979c342e5bc383e0c8d999607134
Binary files /dev/null and b/examples/zoom/07.jpg differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d434bd9eaabf87af87056786d9479729abc8a266
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+diffusers==0.18.2
+torch==2.0.1
+torchvision==0.15.2
+matplotlib==3.7.2
+numpy==1.25.1
+opencv_python==4.8.0.74
+opencv_python_headless==4.8.0.74
+Pillow==10.1.0
+Pillow==10.1.0
+transformers==4.35.0
+gradio==4.0.0
+basicsr==1.4.2
+accelerate==0.21.0
+invisible-watermark
\ No newline at end of file
diff --git a/sam/efficient_sam/__init__.py b/sam/efficient_sam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22a2d29cee5d2a2df01944c90b6e01f879301f3f
--- /dev/null
+++ b/sam/efficient_sam/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+from .build_efficient_sam import (
+    build_efficient_sam_vitt,
+    build_efficient_sam_vits,
+)
diff --git a/sam/efficient_sam/build_efficient_sam.py b/sam/efficient_sam/build_efficient_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d138e7335d10c8cbf43aa9ceafef12eda92a66e
--- /dev/null
+++ b/sam/efficient_sam/build_efficient_sam.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .efficient_sam import build_efficient_sam
+
+def build_efficient_sam_vitt():
+    return build_efficient_sam(
+        encoder_patch_embed_dim=192,
+        encoder_num_heads=3,
+        checkpoint="models/efficient_sam_vitt.pt",
+    ).eval()
+
+
+def build_efficient_sam_vits():
+    return build_efficient_sam(
+        encoder_patch_embed_dim=384,
+        encoder_num_heads=6,
+        checkpoint="models/efficient_sam_vits.pt",
+    ).eval()
diff --git a/sam/efficient_sam/efficient_sam.py b/sam/efficient_sam/efficient_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a3ba4c328e0716a8e3166b7d267353757aa76d7
--- /dev/null
+++ b/sam/efficient_sam/efficient_sam.py
@@ -0,0 +1,310 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Any, List, Tuple, Type
+
+import torch
+import torch.nn.functional as F
+
+from torch import nn, Tensor
+
+from .efficient_sam_decoder import MaskDecoder, PromptEncoder
+from .efficient_sam_encoder import ImageEncoderViT
+from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer
+
+class EfficientSam(nn.Module):
+    mask_threshold: float = 0.0
+    image_format: str = "RGB"
+
+    def __init__(
+        self,
+        image_encoder: ImageEncoderViT,
+        prompt_encoder: PromptEncoder,
+        decoder_max_num_input_points: int,
+        mask_decoder: MaskDecoder,
+        pixel_mean: List[float] = [0.485, 0.456, 0.406],
+        pixel_std: List[float] = [0.229, 0.224, 0.225],
+    ) -> None:
+        """
+        SAM predicts object masks from an image and input prompts.
+
+        Arguments:
+          image_encoder (ImageEncoderViT): The backbone used to encode the
+            image into image embeddings that allow for efficient mask prediction.
+          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+            and encoded prompts.
+          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+          pixel_std (list(float)): Std values for normalizing pixels in the input image.
+        """
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.prompt_encoder = prompt_encoder
+        self.decoder_max_num_input_points = decoder_max_num_input_points
+        self.mask_decoder = mask_decoder
+        self.register_buffer(
+            "pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False
+        )
+        self.register_buffer(
+            "pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False
+        )
+
+    @torch.jit.export
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        batched_points: torch.Tensor,
+        batched_point_labels: torch.Tensor,
+        multimask_output: bool,
+        input_h: int,
+        input_w: int,
+        output_h: int = -1,
+        output_w: int = -1,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predicts masks given image embeddings and prompts. This only runs the decoder.
+
+        Arguments:
+          image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
+          batched_points: A tensor of shape [B, max_num_queries, num_pts, 2]
+          batched_point_labels: A tensor of shape [B, max_num_queries, num_pts]
+        Returns:
+          A tuple of two tensors:
+            low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks
+            iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
+        """
+
+        batch_size, max_num_queries, num_pts, _ = batched_points.shape
+        num_pts = batched_points.shape[2]
+        rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w)
+
+        if num_pts > self.decoder_max_num_input_points:
+            rescaled_batched_points = rescaled_batched_points[
+                :, :, : self.decoder_max_num_input_points, :
+            ]
+            batched_point_labels = batched_point_labels[
+                :, :, : self.decoder_max_num_input_points
+            ]
+        elif num_pts < self.decoder_max_num_input_points:
+            rescaled_batched_points = F.pad(
+                rescaled_batched_points,
+                (0, 0, 0, self.decoder_max_num_input_points - num_pts),
+                value=-1.0,
+            )
+            batched_point_labels = F.pad(
+                batched_point_labels,
+                (0, self.decoder_max_num_input_points - num_pts),
+                value=-1.0,
+            )
+
+        sparse_embeddings = self.prompt_encoder(
+            rescaled_batched_points.reshape(
+                batch_size * max_num_queries, self.decoder_max_num_input_points, 2
+            ),
+            batched_point_labels.reshape(
+                batch_size * max_num_queries, self.decoder_max_num_input_points
+            ),
+        )
+        sparse_embeddings = sparse_embeddings.view(
+            batch_size,
+            max_num_queries,
+            sparse_embeddings.shape[1],
+            sparse_embeddings.shape[2],
+        )
+        low_res_masks, iou_predictions = self.mask_decoder(
+            image_embeddings,
+            self.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            multimask_output=multimask_output,
+        )
+        _, num_predictions, low_res_size, _ = low_res_masks.shape
+
+        if output_w > 0 and output_h > 0:
+            output_masks = F.interpolate(
+                low_res_masks, (output_h, output_w), mode="bicubic"
+            )
+            output_masks = torch.reshape(
+                output_masks,
+                (batch_size, max_num_queries, num_predictions, output_h, output_w),
+            )
+        else:
+            output_masks = torch.reshape(
+                low_res_masks,
+                (
+                    batch_size,
+                    max_num_queries,
+                    num_predictions,
+                    low_res_size,
+                    low_res_size,
+                ),
+            )
+        iou_predictions = torch.reshape(
+            iou_predictions, (batch_size, max_num_queries, num_predictions)
+        )
+        sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
+        iou_predictions = torch.take_along_dim(iou_predictions, sorted_ids, dim=2)
+        output_masks = torch.take_along_dim(
+            output_masks, sorted_ids[..., None, None], dim=2
+        )
+        return output_masks, iou_predictions
+
+    def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int):
+        return torch.stack(
+            [
+                torch.where(
+                    batched_points[..., 0] >= 0,
+                    batched_points[..., 0] * self.image_encoder.img_size / input_w,
+                    -1.0,
+                ),
+                torch.where(
+                    batched_points[..., 1] >= 0,
+                    batched_points[..., 1] * self.image_encoder.img_size / input_h,
+                    -1.0,
+                ),
+            ],
+            dim=-1,
+        )
+
+    @torch.jit.export
+    def get_image_embeddings(self, batched_images) -> torch.Tensor:
+        """
+        Predicts masks end-to-end from provided images and prompts.
+        If prompts are not known in advance, using SamPredictor is
+        recommended over calling the model directly.
+
+        Arguments:
+          batched_images: A tensor of shape [B, 3, H, W]
+        Returns:
+          List of image embeddings each of of shape [B, C(i), H(i), W(i)].
+          The last embedding corresponds to the final layer.
+        """
+        batched_images = self.preprocess(batched_images)
+        return self.image_encoder(batched_images)
+
+    def forward(
+        self,
+        batched_images: torch.Tensor,
+        batched_points: torch.Tensor,
+        batched_point_labels: torch.Tensor,
+        scale_to_original_image_size: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predicts masks end-to-end from provided images and prompts.
+        If prompts are not known in advance, using SamPredictor is
+        recommended over calling the model directly.
+
+        Arguments:
+          batched_images: A tensor of shape [B, 3, H, W]
+          batched_points: A tensor of shape [B, num_queries, max_num_pts, 2]
+          batched_point_labels: A tensor of shape [B, num_queries, max_num_pts]
+
+        Returns:
+          A list tuples of two tensors where the ith element is by considering the first i+1 points.
+            low_res_mask: A tensor of shape [B, 256, 256] of predicted masks
+            iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
+        """
+        batch_size, _, input_h, input_w = batched_images.shape
+        image_embeddings = self.get_image_embeddings(batched_images)
+        return self.predict_masks(
+            image_embeddings,
+            batched_points,
+            batched_point_labels,
+            multimask_output=True,
+            input_h=input_h,
+            input_w=input_w,
+            output_h=input_h if scale_to_original_image_size else -1,
+            output_w=input_w if scale_to_original_image_size else -1,
+        )
+
+    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+        """Normalize pixel values and pad to a square input."""
+        if (
+            x.shape[2] != self.image_encoder.img_size
+            or x.shape[3] != self.image_encoder.img_size
+        ):
+            x = F.interpolate(
+                x,
+                (self.image_encoder.img_size, self.image_encoder.img_size),
+                mode="bilinear",
+            )
+        return (x - self.pixel_mean) / self.pixel_std
+
+
+def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None):
+    img_size = 1024
+    encoder_patch_size = 16
+    encoder_depth = 12
+    encoder_mlp_ratio = 4.0
+    encoder_neck_dims = [256, 256]
+    decoder_max_num_input_points = 6
+    decoder_transformer_depth = 2
+    decoder_transformer_mlp_dim = 2048
+    decoder_num_heads = 8
+    decoder_upscaling_layer_dims = [64, 32]
+    num_multimask_outputs = 3
+    iou_head_depth = 3
+    iou_head_hidden_dim = 256
+    activation = "gelu"
+    normalization_type = "layer_norm"
+    normalize_before_activation = False
+
+    assert activation == "relu" or activation == "gelu"
+    if activation == "relu":
+        activation_fn = nn.ReLU
+    else:
+        activation_fn = nn.GELU
+
+    image_encoder = ImageEncoderViT(
+        img_size=img_size,
+        patch_size=encoder_patch_size,
+        in_chans=3,
+        patch_embed_dim=encoder_patch_embed_dim,
+        normalization_type=normalization_type,
+        depth=encoder_depth,
+        num_heads=encoder_num_heads,
+        mlp_ratio=encoder_mlp_ratio,
+        neck_dims=encoder_neck_dims,
+        act_layer=activation_fn,
+    )
+
+    image_embedding_size = image_encoder.image_embedding_size
+    encoder_transformer_output_dim = image_encoder.transformer_output_dim
+
+    sam = EfficientSam(
+        image_encoder=image_encoder,
+        prompt_encoder=PromptEncoder(
+            embed_dim=encoder_transformer_output_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(img_size, img_size),
+        ),
+        decoder_max_num_input_points=decoder_max_num_input_points,
+        mask_decoder=MaskDecoder(
+            transformer_dim=encoder_transformer_output_dim,
+            transformer=TwoWayTransformer(
+                depth=decoder_transformer_depth,
+                embedding_dim=encoder_transformer_output_dim,
+                num_heads=decoder_num_heads,
+                mlp_dim=decoder_transformer_mlp_dim,
+                activation=activation_fn,
+                normalize_before_activation=normalize_before_activation,
+            ),
+            num_multimask_outputs=num_multimask_outputs,
+            activation=activation_fn,
+            normalization_type=normalization_type,
+            normalize_before_activation=normalize_before_activation,
+            iou_head_depth=iou_head_depth - 1,
+            iou_head_hidden_dim=iou_head_hidden_dim,
+            upscaling_layer_dims=decoder_upscaling_layer_dims,
+        ),
+        pixel_mean=[0.485, 0.456, 0.406],
+        pixel_std=[0.229, 0.224, 0.225],
+    )
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu")
+        sam.load_state_dict(state_dict["model"])
+    return sam
diff --git a/sam/efficient_sam/efficient_sam_decoder.py b/sam/efficient_sam/efficient_sam_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..380f41c1650f1ffb824d8d911f810eabedc66ddd
--- /dev/null
+++ b/sam/efficient_sam/efficient_sam_decoder.py
@@ -0,0 +1,315 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Tuple, Type
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .mlp import MLPBlock
+
+
+class PromptEncoder(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        image_embedding_size: Tuple[int, int],
+        input_image_size: Tuple[int, int],
+    ) -> None:
+        """
+        Encodes prompts for input to SAM's mask decoder.
+
+        Arguments:
+          embed_dim (int): The prompts' embedding dimension
+          image_embedding_size (tuple(int, int)): The spatial size of the
+            image embedding, as (H, W).
+          input_image_size (int): The padded size of the image as input
+            to the image encoder, as (H, W).
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.input_image_size = input_image_size
+        self.image_embedding_size = image_embedding_size
+        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+        self.invalid_points = nn.Embedding(1, embed_dim)
+        self.point_embeddings = nn.Embedding(1, embed_dim)
+        self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim)
+        self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim)
+
+    def get_dense_pe(self) -> torch.Tensor:
+        """
+        Returns the positional encoding used to encode point prompts,
+        applied to a dense set of points the shape of the image encoding.
+
+        Returns:
+          torch.Tensor: Positional encoding with shape
+            1x(embed_dim)x(embedding_h)x(embedding_w)
+        """
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+    def _embed_points(
+        self,
+        points: torch.Tensor,
+        labels: torch.Tensor,
+    ) -> torch.Tensor:
+        """Embeds point prompts."""
+        points = points + 0.5  # Shift to center of pixel
+        point_embedding = self.pe_layer.forward_with_coords(
+            points, self.input_image_size
+        )
+        invalid_label_ids = torch.eq(labels, -1)[:,:,None]
+        point_label_ids = torch.eq(labels, 1)[:,:,None]
+        topleft_label_ids = torch.eq(labels, 2)[:,:,None]
+        bottomright_label_ids = torch.eq(labels, 3)[:,:,None]
+        point_embedding = point_embedding + self.invalid_points.weight[:,None,:] * invalid_label_ids
+        point_embedding = point_embedding + self.point_embeddings.weight[:,None,:] * point_label_ids
+        point_embedding = point_embedding + self.bbox_top_left_embeddings.weight[:,None,:] * topleft_label_ids
+        point_embedding = point_embedding + self.bbox_bottom_right_embeddings.weight[:,None,:] * bottomright_label_ids
+        return point_embedding
+
+    def forward(
+        self,
+        coords,
+        labels,
+    ) -> torch.Tensor:
+        """
+        Embeds different types of prompts, returning both sparse and dense
+        embeddings.
+
+        Arguments:
+          points: A tensor of shape [B, 2]
+          labels: An integer tensor of shape [B] where each element is 1,2 or 3.
+
+        Returns:
+          torch.Tensor: sparse embeddings for the points and boxes, with shape
+            BxNx(embed_dim), where N is determined by the number of input points
+            and boxes.
+        """
+        return self._embed_points(coords, labels)
+
+
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+    """
+
+    def __init__(self, num_pos_feats: int) -> None:
+        super().__init__()
+        self.register_buffer(
+            "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats))
+        )
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Positionally encode points that are normalized to [0,1]."""
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coords = 2 * coords - 1
+        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = 2 * np.pi * coords
+        # outputs d_1 x ... x d_n x C shape
+        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+        """Generate positional encoding for a grid of the specified size."""
+        h, w = size
+        device = self.positional_encoding_gaussian_matrix.device
+        grid = torch.ones([h, w], device=device, dtype=torch.float32)
+        y_embed = grid.cumsum(dim=0) - 0.5
+        x_embed = grid.cumsum(dim=1) - 0.5
+        y_embed = y_embed / h
+        x_embed = x_embed / w
+
+        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+        return pe.permute(2, 0, 1)  # C x H x W
+
+    def forward_with_coords(
+        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+    ) -> torch.Tensor:
+        """Positionally encode points that are not normalized to [0,1]."""
+        coords = coords_input.clone()
+        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+        return self._pe_encoding(coords.to(torch.float))  # B x N x C
+
+
+class MaskDecoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int,
+        activation: Type[nn.Module],
+        normalization_type: str,
+        normalize_before_activation: bool,
+        iou_head_depth: int,
+        iou_head_hidden_dim: int,
+        upscaling_layer_dims: List[int],
+    ) -> None:
+        """
+        Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+        Arguments:
+          transformer_dim (int): the channel dimension of the transformer
+          transformer (nn.Module): the transformer used to predict masks
+          num_multimask_outputs (int): the number of masks to predict
+            when disambiguating masks
+          activation (nn.Module): the type of activation to use when
+            upscaling masks
+          iou_head_depth (int): the depth of the MLP used to predict
+            mask quality
+          iou_head_hidden_dim (int): the hidden dimension of the MLP
+            used to predict mask quality
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        if num_multimask_outputs > 1:
+            self.num_mask_tokens = num_multimask_outputs + 1
+        else:
+            self.num_mask_tokens = 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+        output_dim_after_upscaling = transformer_dim
+
+        self.final_output_upscaling_layers = nn.ModuleList([])
+        for idx, layer_dims in enumerate(upscaling_layer_dims):
+            self.final_output_upscaling_layers.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(
+                        output_dim_after_upscaling,
+                        layer_dims,
+                        kernel_size=2,
+                        stride=2,
+                    ),
+                    nn.GroupNorm(1, layer_dims)
+                    if idx < len(upscaling_layer_dims) - 1
+                    else nn.Identity(),
+                    activation(),
+                )
+            )
+            output_dim_after_upscaling = layer_dims
+
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [
+                MLPBlock(
+                    input_dim=transformer_dim,
+                    hidden_dim=transformer_dim,
+                    output_dim=output_dim_after_upscaling,
+                    num_layers=2,
+                    act=activation,
+                )
+                for i in range(self.num_mask_tokens)
+            ]
+        )
+
+        self.iou_prediction_head = MLPBlock(
+            input_dim=transformer_dim,
+            hidden_dim=iou_head_hidden_dim,
+            output_dim=self.num_mask_tokens,
+            num_layers=iou_head_depth,
+            act=activation,
+        )
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable).
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+
+        Returns:
+          torch.Tensor: batched predicted masks
+          torch.Tensor: batched predictions of mask quality
+        """
+
+        (
+            batch_size,
+            max_num_queries,
+            sparse_embed_dim_1,
+            sparse_embed_dim_2,
+        ) = sparse_prompt_embeddings.shape
+
+        (
+            _,
+            image_embed_dim_c,
+            image_embed_dim_h,
+            image_embed_dim_w,
+        ) = image_embeddings.shape
+
+        # Tile the image embedding for all queries.
+        image_embeddings_tiled = torch.tile(
+            image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1]
+        ).view(
+            batch_size * max_num_queries,
+            image_embed_dim_c,
+            image_embed_dim_h,
+            image_embed_dim_w,
+        )
+        sparse_prompt_embeddings = sparse_prompt_embeddings.reshape(
+            batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2
+        )
+        masks, iou_pred = self.predict_masks(
+            image_embeddings=image_embeddings_tiled,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+        )
+        if multimask_output and self.num_multimask_outputs > 1:
+            return masks[:, 1:, :], iou_pred[:, 1:]
+        else:
+            return masks[:, :1, :], iou_pred[:, :1]
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks. See 'forward' for more details."""
+        # Concatenate output tokens
+        output_tokens = torch.cat(
+            [self.iou_token.weight, self.mask_tokens.weight], dim=0
+        )
+        output_tokens = output_tokens.unsqueeze(0).expand(
+            sparse_prompt_embeddings.size(0), -1, -1
+        )
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+        # Expand per-image data in batch direction to be per-mask
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = image_embeddings.shape
+        hs, src = self.transformer(image_embeddings, pos_src, tokens)
+        iou_token_out = hs[:, 0, :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        upscaled_embedding = src.transpose(1, 2).view(b, c, h, w)
+
+        for upscaling_layer in self.final_output_upscaling_layers:
+            upscaled_embedding = upscaling_layer(upscaled_embedding)
+        hyper_in_list: List[torch.Tensor] = []
+        for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps):
+            hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :]))
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding.shape
+        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+        return masks, iou_pred
diff --git a/sam/efficient_sam/efficient_sam_encoder.py b/sam/efficient_sam/efficient_sam_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..73fd7ac470b42df738e5e6bcbbcb60b4f30fb46e
--- /dev/null
+++ b/sam/efficient_sam/efficient_sam_encoder.py
@@ -0,0 +1,257 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import List, Optional, Tuple, Type
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """2D Image to Patch Embedding"""
+
+    def __init__(
+        self,
+        img_size,
+        patch_size,
+        in_chans,
+        embed_dim,
+    ):
+        super().__init__()
+        self.proj = nn.Conv2d(
+            in_chans,
+            embed_dim,
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            bias=True,
+        )
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        x = self.proj(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        qkv_bias,
+        qk_scale=None,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        return x
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        act_layer=nn.GELU,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.fc2(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        act_layer=nn.GELU,
+    ):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+        )
+        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+        )
+
+    def forward(self, x):
+        x = x + self.attn(self.norm1(x))
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+
+@torch.jit.export
+def get_abs_pos(
+    abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int]
+) -> torch.Tensor:
+    """
+    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
+        dimension for the original embeddings.
+    Args:
+        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+        hw (Tuple): size of input image tokens.
+
+    Returns:
+        Absolute positional embeddings after processing with shape (1, H, W, C)
+    """
+    h = hw[0]
+    w = hw[1]
+    if has_cls_token:
+        abs_pos = abs_pos[:, 1:]
+    xy_num = abs_pos.shape[1]
+    size = int(math.sqrt(xy_num))
+    assert size * size == xy_num
+
+    if size != h or size != w:
+        new_abs_pos = F.interpolate(
+            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
+            size=(h, w),
+            mode="bicubic",
+            align_corners=False,
+        )
+        return new_abs_pos.permute(0, 2, 3, 1)
+    else:
+        return abs_pos.reshape(1, h, w, -1)
+
+
+# Image encoder for efficient SAM.
+class ImageEncoderViT(nn.Module):
+    def __init__(
+        self,
+        img_size: int,
+        patch_size: int,
+        in_chans: int,
+        patch_embed_dim: int,
+        normalization_type: str,
+        depth: int,
+        num_heads: int,
+        mlp_ratio: float,
+        neck_dims: List[int],
+        act_layer: Type[nn.Module],
+    ) -> None:
+        """
+        Args:
+            img_size (int): Input image size.
+            patch_size (int): Patch size.
+            in_chans (int): Number of input image channels.
+            patch_embed_dim (int): Patch embedding dimension.
+            depth (int): Depth of ViT.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            act_layer (nn.Module): Activation layer.
+        """
+        super().__init__()
+
+        self.img_size = img_size
+        self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
+        self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1]
+        self.pretrain_use_cls_token = True
+        pretrain_img_size = 224
+        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim)
+        # Initialize absolute positional embedding with pretrain image size.
+        num_patches = (pretrain_img_size // patch_size) * (
+            pretrain_img_size // patch_size
+        )
+        num_positions = num_patches + 1
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim))
+        self.blocks = nn.ModuleList()
+        for i in range(depth):
+            vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True)
+            self.blocks.append(vit_block)
+        self.neck = nn.Sequential(
+            nn.Conv2d(
+                patch_embed_dim,
+                neck_dims[0],
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(neck_dims[0]),
+            nn.Conv2d(
+                neck_dims[0],
+                neck_dims[0],
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(neck_dims[0]),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        assert (
+            x.shape[2] == self.img_size and x.shape[3] == self.img_size
+        ), "input image size must match self.img_size"
+        x = self.patch_embed(x)
+        # B C H W -> B H W C
+        x = x.permute(0, 2, 3, 1)
+        x = x + get_abs_pos(
+            self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]]
+        )
+        num_patches = x.shape[1]
+        assert x.shape[2] == num_patches
+        x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3])
+        for blk in self.blocks:
+            x = blk(x)
+        x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2])
+        x = self.neck(x.permute(0, 3, 1, 2))
+        return x
diff --git a/sam/efficient_sam/mlp.py b/sam/efficient_sam/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3be8db49cbf6990ce55a467e7e62f60daf62c9d
--- /dev/null
+++ b/sam/efficient_sam/mlp.py
@@ -0,0 +1,29 @@
+from typing import Type
+
+from torch import nn
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLPBlock(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        act: Type[nn.Module],
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Sequential(nn.Linear(n, k), act())
+            for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
+        )
+        self.fc = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return self.fc(x)
diff --git a/sam/efficient_sam/two_way_transformer.py b/sam/efficient_sam/two_way_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..881e76fd7efd07eeeef12999931fc2b74db406a9
--- /dev/null
+++ b/sam/efficient_sam/two_way_transformer.py
@@ -0,0 +1,264 @@
+import math
+from typing import Tuple, Type
+import torch
+from torch import nn, Tensor
+from .mlp import MLPBlock
+
+
+
+
+class TwoWayTransformer(nn.Module):
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module],
+        normalize_before_activation: bool,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        A transformer decoder that attends to an input image using
+        queries whose positional embedding is supplied.
+
+        Args:
+          depth (int): number of layers in the transformer
+          embedding_dim (int): the channel dimension for the input embeddings
+          num_heads (int): the number of heads for multihead attention. Must
+            divide embedding_dim
+          mlp_dim (int): the channel dimension internal to the MLP block
+          activation (nn.Module): the activation to use in the MLP block
+        """
+        super().__init__()
+        self.depth = depth
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.mlp_dim = mlp_dim
+        self.layers = nn.ModuleList()
+
+        for i in range(depth):
+            curr_layer = TwoWayAttentionBlock(
+                embedding_dim=embedding_dim,
+                num_heads=num_heads,
+                mlp_dim=mlp_dim,
+                activation=activation,
+                normalize_before_activation=normalize_before_activation,
+                attention_downsample_rate=attention_downsample_rate,
+                skip_first_layer_pe=(i == 0),
+            )
+            self.layers.append(curr_layer)
+
+        self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock(
+            embedding_dim,
+            num_heads,
+            downsample_rate=attention_downsample_rate,
+        )
+        self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+    def forward(
+        self,
+        image_embedding: Tensor,
+        image_pe: Tensor,
+        point_embedding: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+          image_embedding (torch.Tensor): image to attend to. Should be shape
+            B x embedding_dim x h x w for any h and w.
+          image_pe (torch.Tensor): the positional encoding to add to the image. Must
+            have the same shape as image_embedding.
+          point_embedding (torch.Tensor): the embedding to add to the query points.
+            Must have shape B x N_points x embedding_dim for any N_points.
+
+        Returns:
+          torch.Tensor: the processed point_embedding
+          torch.Tensor: the processed image_embedding
+        """
+
+        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+        bs, c, h, w = image_embedding.shape
+        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+        image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+        # Prepare queries
+        queries = point_embedding
+        keys = image_embedding
+
+        # Apply transformer blocks and final layernorm
+        for idx, layer in enumerate(self.layers):
+            queries, keys = layer(
+                queries=queries,
+                keys=keys,
+                query_pe=point_embedding,
+                key_pe=image_pe,
+            )
+
+        # Apply the final attention layer from the points to the image
+        q = queries + point_embedding
+        k = keys + image_pe
+        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm_final_attn(queries)
+        return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module],
+        normalize_before_activation: bool,
+        attention_downsample_rate: int = 2,
+        skip_first_layer_pe: bool = False,
+    ) -> None:
+        """
+        A transformer block with four layers: (1) self-attention of sparse
+        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+        block on sparse inputs, and (4) cross attention of dense inputs to sparse
+        inputs.
+
+        Arguments:
+          embedding_dim (int): the channel dimension of the embeddings
+          num_heads (int): the number of heads in the attention layers
+          mlp_dim (int): the hidden dimension of the mlp block
+          activation (nn.Module): the activation of the mlp block
+          skip_first_layer_pe (bool): skip the PE on the first layer
+        """
+        super().__init__()
+        self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+
+        self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock(
+            embedding_dim,
+            num_heads,
+            downsample_rate=attention_downsample_rate,
+        )
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+        self.mlp = MLPBlock(
+            embedding_dim,
+            mlp_dim,
+            embedding_dim,
+            1,
+            activation,
+        )
+
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+        self.norm4 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock(
+            embedding_dim,
+            num_heads,
+            downsample_rate=attention_downsample_rate,
+        )
+
+        self.skip_first_layer_pe = skip_first_layer_pe
+
+    def forward(
+        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        # Self attention block
+        if not self.skip_first_layer_pe:
+            queries = queries + query_pe
+        attn_out = self.self_attn(q=queries, k=queries, v=queries)
+        queries = queries + attn_out
+        queries = self.norm1(queries)
+
+        # Cross attention block, tokens attending to image embedding
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+        keys = keys + attn_out
+        keys = self.norm4(keys)
+
+        return queries, keys
+
+
+class AttentionForTwoWayAttentionBlock(nn.Module):
+    """
+    An attention layer that allows for downscaling the size of the embedding
+    after projection to queries, keys, and values.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        downsample_rate: int = 1,
+    ) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.internal_dim = embedding_dim // downsample_rate
+        self.num_heads = num_heads
+        assert (
+            self.internal_dim % num_heads == 0
+        ), "num_heads must divide embedding_dim."
+
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+        self._reset_parameters()
+
+    def _reset_parameters(self) -> None:
+        # The fan_out is incorrect, but matches pytorch's initialization
+        # for which qkv is a single 3*embedding_dim x embedding_dim matrix
+        fan_in = self.embedding_dim
+        fan_out = 3 * self.internal_dim
+        # Xavier uniform with our custom fan_out
+        bnd = math.sqrt(6 / (fan_in + fan_out))
+        nn.init.uniform_(self.q_proj.weight, -bnd, bnd)
+        nn.init.uniform_(self.k_proj.weight, -bnd, bnd)
+        nn.init.uniform_(self.v_proj.weight, -bnd, bnd)
+        # out_proj.weight is left with default initialization, like pytorch attention
+        nn.init.zeros_(self.q_proj.bias)
+        nn.init.zeros_(self.k_proj.bias)
+        nn.init.zeros_(self.v_proj.bias)
+        nn.init.zeros_(self.out_proj.bias)
+
+    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+        b, n, c = x.shape
+        x = x.reshape(b, n, num_heads, c // num_heads)
+        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+    def _recombine_heads(self, x: Tensor) -> Tensor:
+        b, n_heads, n_tokens, c_per_head = x.shape
+        x = x.transpose(1, 2)
+        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        # Attention
+        _, _, _, c_per_head = q.shape
+        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
+        attn = attn / math.sqrt(c_per_head)
+        attn = torch.softmax(attn, dim=-1)
+        # Get output
+        out = attn @ v
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+        return out
diff --git a/src/demo/demo.py b/src/demo/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..166b2b03f28b71b0b3ea8e81e808ffaa9bc6f5da
--- /dev/null
+++ b/src/demo/demo.py
@@ -0,0 +1,738 @@
+import gradio as gr
+import numpy as np
+from src.demo.utils import get_point, store_img, get_point_move, store_img_move, clear_points, upload_image_move, segment_with_points, segment_with_points_paste, fun_clear, paste_with_mask_and_offset
+import spaces
+
+examples_remove = [
+    [
+        "examples/remove/02_ring/0.jpg", # original image 1
+        "examples/remove/02_ring/mask0.jpg", # mask 1
+        "examples/remove/02_ring/0.jpg", # original image 2
+        "examples/remove/02_ring/mask1.jpg", #mask 2
+        "examples/remove/02_ring/0.jpg", #Original image 3
+        "examples/remove/02_ring/mask2.jpg", #mask 3
+        None, #Original image 4
+        None, # refine mask
+    ], # 02
+    [
+        "examples/remove/01_moto/0.jpg", # original image 1
+        "examples/remove/01_moto/mask0.jpg", # mask 1
+        "examples/remove/01_moto/0.jpg", # original image 2
+        None, #mask 2
+        "examples/remove/01_moto/0.jpg", #Original image 3
+        None, #mask 3
+        "examples/remove/01_moto/0.jpg", #Original image 4
+        "examples/remove/01_moto/mask1.jpg", # refine mask
+    ], # 01
+    [
+        "examples/remove/03_ball/0.jpg", # original image 1
+        "examples/remove/03_ball/mask0.jpg", # mask 1
+        "examples/remove/03_ball/0.jpg", # original image 2
+        "examples/remove/03_ball/mask1.jpg", #mask 2
+        "examples/remove/03_ball/0.jpg", #Original image 3
+        None, #mask 3
+        None, #Original image 4
+        None, # refine mask
+    ], # 03
+    [
+        "examples/remove/04_pikachu/0.jpg", # original image 1
+        "examples/remove/04_pikachu/mask0.jpg", # mask 1
+        "examples/remove/04_pikachu/0.jpg", # original image 2
+        "examples/remove/04_pikachu/mask1.jpg", #mask 2
+        "examples/remove/04_pikachu/0.jpg", #Original image 3
+        "examples/remove/04_pikachu/mask2.jpg", #mask 3
+        None, #Original image 4
+        None, # refine mask
+    ], # 04
+    [
+        "examples/remove/05_betty/0.jpg", # original image 1
+        "examples/remove/05_betty/mask0.jpg", # mask 1
+        None, # original image 2
+        None, #mask 2
+        None, #Original image 3
+        None, #mask 3
+        None, #Original image 4
+        None, # refine mask
+    ], # 05
+]
+examples_zoom = [
+    ["examples/zoom/01.jpg"],
+    ["examples/zoom/02.jpg"],
+    ["examples/zoom/03.jpg"],
+    ["examples/zoom/04.jpg"],
+    ["examples/zoom/05.jpg"],
+    ["examples/zoom/06.jpg"],
+    ["examples/zoom/07.jpg"],
+]
+examples_pan = [
+    ["examples/pan/01.jpg"],
+    ["examples/pan/02.jpg"],
+    ["examples/pan/03.jpg"],
+    ["examples/pan/04.jpg"],
+    ["examples/pan/05.jpg"],
+    ["examples/pan/06.jpg"],
+]
+
+examples_moving = [
+    [
+    "examples/layer/01_horse/00.jpg", #bg
+    "examples/layer/01_horse/mask0.jpg", #bg_mask
+    0, 0, 1.2, "None", "left/right",  #l1_dx, l1_dy, l1_resize
+    ],
+    [
+    "examples/moving/01_ball/0.jpg", #bg
+    "examples/moving/01_ball/mask0.jpg", #bg_mask
+    -0.2, -0.1, 0.8, "None", "None",  #l1_dx, l1_dy, l1_resize
+    ],
+    [
+    "examples/moving/02_bell/0.jpg", #bg
+    "examples/moving/02_bell/mask0.jpg", #bg_mask
+    0, 0, 0.75, "None", "None",  #l1_dx, l1_dy, l1_resize
+    ],
+]
+examples_layer = [
+    [
+    "examples/layer/01_horse/00.jpg", #bg
+    "examples/layer/01_horse/mask0.jpg", #bg_mask
+
+    "examples/layer/01_horse/00.jpg", #l1
+    "examples/layer/01_horse/mask0.jpg", #l1_mask
+    -0.2, 0, 1, "None", "None", #l1_dx, l1_dy, l1_resize
+
+    "examples/layer/01_horse/00.jpg", #l2
+    "examples/layer/01_horse/mask0.jpg", #l2_mask
+    0.2, 0, 1, "None", "None", #l2_dx, l2_dy, l2_resize
+
+    None, #l3
+    None, #l3_mask
+    0, 0, 1, "None", "None", #l3_dx, l3_dy, l3_resize
+
+    "examples/layer/01_horse/00.jpg", #bg_ori
+    "examples/layer/01_horse/00.jpg", #l1_ori
+    "examples/layer/01_horse/00.jpg", #l2_ori
+    None, "None", "None", #l3_ori
+    ],
+
+    [
+    "examples/layer/02_baby/00.jpg", #bg
+    "examples/layer/02_baby/mask0.jpg", #bg_mask
+
+    "examples/layer/02_baby/00.jpg", #l1
+    "examples/layer/02_baby/mask1.jpg", #l1_mask
+    -0.35, 0, 1,"left/right", "None", #l1_dx, l1_dy, l1_resize
+
+    "examples/layer/02_baby/00.jpg", #l2
+    "examples/layer/02_baby/mask2.jpg", #l2_mask
+    0.35, 0, 1, "left/right", "None", #l2_dx, l2_dy, l2_resize
+
+    None, #l3
+    None, #l3_mask
+    0, 0, 1,"None", "None", #l3_dx, l3_dy, l3_resize
+    ],
+
+    [
+    "examples/layer/03_text/00.jpg", #bg
+    "examples/layer/03_text/mask0.jpg", #bg_mask
+
+    "examples/layer/03_text/01.jpg", #l1
+    "examples/layer/03_text/mask1.jpg", #l1_mask
+    0.1, -0.1, 0.5, "None", "None",#l1_dx, l1_dy, l1_resize
+
+    None, #l2
+    None, #l2_mask
+    0, 0, 1, "None", "None",#l2_dx, l2_dy, l2_resize
+
+    None, #l3
+    None, #l3_mask
+    0, 0, 1,"None", "None", #l3_dx, l3_dy, l3_resize
+    ],
+    [
+    "examples/layer/04_cross/0.jpg", #bg
+    "examples/layer/04_cross/mask0.jpg", #bg_mask
+
+    "examples/layer/04_cross/2.jpg", #l1
+    "examples/layer/04_cross/mask2.jpg", #l1_mask
+    -0.1, -0.25, 0.5, "None", "None",#l1_dx, l1_dy, l1_resize
+
+    "examples/layer/04_cross/1.jpg", #l2
+    "examples/layer/04_cross/mask1.jpg", #l2_mask
+    -0.1, -0.15, 0.7, "None", "None",#l2_dx, l2_dy, l2_resize
+
+    "examples/layer/04_cross/3.jpg", #l3
+    "examples/layer/04_cross/mask3.jpg", #l3_mask
+    -0.1, -0.55, 0.5, "None", "None",#l3_dx, l3_dy, l3_resize
+    ],
+]
+examples_mask_box = [
+    [
+        "examples/mask_box/image1.jpg", # original image 1
+        "examples/mask_box/image2.jpg", # original image 1
+        "examples/mask_box/mask01.jpg", # original image 1
+        "examples/mask_box/mask02.jpg", # original image 1
+        "examples/mask_box/mask00.jpg", # original image 1
+    ]
+]
+
+# 01
+def create_demo_remove(runner=None):
+    DESCRIPTION = """
+    # Object Removal
+
+    ## Usage:
+
+    - Upload a sources image, and then draw a box to generate the mask corresponding to the selecting object.
+    - You can choose to mask more than one object by using Mask2 and Mask3.
+    - If you encounter artifacts, try to sketch the regions that caused the artifacts.
+    - You can refer to the first motorcycle example to understand the usage of the <span style="color:red;">Refined Mask</span>.
+    - Please <span style="color:blue;">clear<span> the output before running a new example!
+    - For more irregular composition masks, refer to the last page: Mask Preparation.
+"""
+    
+    with gr.Blocks() as demo:
+        original_image = gr.State(value=None) 
+        img_with_mask = gr.State(value=None) 
+
+        selected_points = gr.State([])
+        global_points = gr.State([])
+        global_point_label = gr.State([])
+
+        gr.Markdown(DESCRIPTION)
+
+        with gr.Row():
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# INPUT")
+                    # mask 0 
+                    gr.Markdown("## Select two points for Mask 1:")
+                    gr.Markdown("the top left and the bottom right")
+                    original_image_1 = gr.Image(sources='upload', label="Original image (Mask 1)", interactive=True, type="numpy")
+                    # mask 1
+                    gr.Markdown("## Option: Select two points for Mask 2")
+                    gr.Markdown("the top left and the bottom right")
+                    original_image_2 = gr.Image(sources='upload', label="Original (Mask 2)", interactive=True, type="numpy")
+                    # mask 2
+                    gr.Markdown("## Option: Select two points for Mask 3")
+                    gr.Markdown("the top left and the bottom right")
+                    original_image_3 = gr.Image(label="Original image (Mask 3)", interactive=True, type="numpy")
+
+                    gr.Markdown("## Option: Mask regions caused artifacts")
+                    gr.Markdown("the top left and the bottom right")
+                    original_image_4 = gr.Image(label="Original image (Refine Mask)", interactive=True, type="numpy") 
+                    with gr.Row():
+                        run_button = gr.Button("Edit")
+                        clear_button = gr.Button("Clear")
+
+       
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# Mask")
+
+                    gr.Markdown("## Removal Mask 1")
+                    mask_1 = gr.Image(sources='upload', label="Removal Mask 1", interactive=True, type="numpy")
+                    gr.Markdown("## Option: Removal Mask 2")
+                    mask_2 = gr.Image(sources='upload', label="Removal Mask 2", interactive=True, type="numpy")
+                    gr.Markdown("## Option: Removal Mask 3")
+                    mask_3 = gr.Image(sources='upload', label="Removal Mask 3", interactive=True, type="numpy")
+
+                    gr.Markdown("## Option: Refine Mask to avoid artifacts")
+                    refine_mask = gr.Image(sources='upload', label="Refine Mask", interactive=True, type="numpy")                    
+            
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# OUTPUT")
+                    gr.Markdown("## Results")
+                    output = gr.Gallery(columns=1, height='auto')
+
+
+            original_image_1.select(
+                segment_with_points, 
+                inputs=[original_image_1, original_image, global_points, global_point_label], 
+                outputs=[original_image_1, original_image, mask_1, global_points, global_point_label]
+            )
+            original_image_2.select(
+                segment_with_points, 
+                inputs=[original_image_2, original_image, global_points, global_point_label], 
+                outputs=[original_image_2, original_image, mask_2, global_points, global_point_label]
+            )
+            original_image_3.select(
+                segment_with_points, 
+                inputs=[original_image_3, original_image, global_points, global_point_label], 
+                outputs=[original_image_3, original_image, mask_3, global_points, global_point_label]
+            )
+            original_image_4.select(
+                segment_with_points, 
+                inputs=[original_image_4, original_image, global_points, global_point_label], 
+                outputs=[original_image_4, original_image, refine_mask, global_points, global_point_label]
+            )
+
+        with gr.Column():
+            gr.Markdown("Try some of the examples below ⬇️")
+            gr.Examples(
+                examples=examples_remove,
+                inputs=[
+                original_image_1, mask_1, 
+                original_image_2, mask_2,
+                original_image_3, mask_3, 
+                original_image_4, refine_mask]
+            )
+        run_button.click(fn=runner, inputs=[original_image, mask_1, mask_2, mask_3, refine_mask,
+        original_image_1, original_image_2, original_image_3], outputs=[output])
+        clear_button.click(
+            fn=fun_clear, 
+            inputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, original_image_1, original_image_2, original_image_3, original_image_4, mask_1, mask_2, mask_3, refine_mask], 
+            outputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, original_image_1, original_image_2, original_image_3, original_image_4, mask_1, mask_2, mask_3, refine_mask]
+        )
+    return demo
+
+
+# 02:
+def create_demo_zooming(runner=None):
+    DESCRIPTION = """
+    # Zooming Out
+
+    ## Usage:
+
+    - Upload a sources image and choose the width and height zooming scale to zoom out.
+    - The illustration of image adjustment and mask preparation is shown in the second column.
+    - We recommend setting the zooming scale between <span style="color:red;"> 0.75 <span> and <span style="color:red;"> 1 <span> for optimal results.
+    - Please <span style="color:blue;">clear<span> the output before running a new example!
+    """
+    
+    with gr.Blocks() as demo:
+        
+        gr.Markdown(DESCRIPTION)
+
+        with gr.Row():
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# INPUT")
+                    # mask 0
+                    gr.Markdown("## Original Image")
+                    original_image = gr.Image(sources='upload', interactive=True, type="numpy")
+
+
+                    gr.Markdown("## Scale:") 
+                    width_scale= gr.Slider(
+                                label="Width scale",
+                                minimum=0,
+                                maximum=1,
+                                step=0.05,
+                                value=0.9,
+                                interactive=True)
+                    height_scale= gr.Slider( 
+                                label="Height scale",
+                                minimum=0,
+                                maximum=1,
+                                step=0.05,
+                                value=0.9,
+                                interactive=True)              
+                    with gr.Row():
+                        run_button = gr.Button("Edit")
+                        clear_button = gr.Button("Clear")
+
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# Preprocess")
+                    gr.Markdown("## Image Adjustment:")
+                    new_image = gr.Gallery(columns=1, height='auto')
+                    gr.Markdown("## Mask Adjustment:")
+                    new_mask = gr.Gallery(columns=1, height='auto')
+
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# OUTPUT")
+                    gr.Markdown("## Results")
+                    output = gr.Gallery(columns=1, height='auto')    
+
+        with gr.Column():
+            gr.Markdown("Try some of the examples below ⬇️")
+            gr.Examples(
+                examples=examples_zoom,
+                inputs=[original_image]
+            )
+        run_button.click(fn=runner, inputs=[original_image, width_scale, height_scale], outputs=[output, new_image, new_mask])
+        clear_button.click(fn=fun_clear, inputs=[original_image, width_scale, height_scale, output, new_image, new_mask], 
+        outputs=[original_image, width_scale, height_scale, output, new_image, new_mask])
+    return demo
+# 03
+
+def create_demo_panning(runner=None):
+    DESCRIPTION = """
+    # Camera Panning
+
+    ## Usage:
+
+    - Upload a sources image and choose the width and height panning scale.
+    - The illustration of image adjustment and mask preparation is shown in the second column.
+    - We recommend setting the panning scale between<span style="color:red;"> 0 <span> and <span style="color:red;">0.25 <span> for optimal results.
+    - Please <span style="color:blue;">clear<span> the output before running a new example!
+    """
+
+    with gr.Blocks() as demo:
+        gr.Markdown(DESCRIPTION)
+
+        with gr.Row():
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# INPUT")
+                    # mask 0
+                    gr.Markdown("## Original Image")
+                    original_image = gr.Image(sources='upload', interactive=True, type="numpy")
+                    w_direction = gr.Radio(["left", "right"], value="left", label="Width Direction")
+                    w_scale = gr.Slider(
+                                label="Width scale",
+                                minimum=0,
+                                maximum=1,
+                                step=0.05,
+                                value=0,
+                                interactive=True)
+                    
+                    h_direction = gr.Radio(["up", "down"], value="up", label="Height Direction")
+                    h_scale = gr.Slider(
+                                label="Height scale",
+                                minimum=0,
+                                maximum=1,
+                                step=0.05,
+                                value=0,
+                                interactive=True)
+                    with gr.Row():
+                        run_button = gr.Button("Edit")
+                        clear_button = gr.Button("Clear")
+
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# Preprocess")
+                    gr.Markdown("## Image Adjustment:")
+                    new_image = gr.Gallery(columns=1, height='auto')
+                    gr.Markdown("## Mask Adjustment:")
+                    new_mask = gr.Gallery(columns=1, height='auto')
+
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# OUTPUT")
+                    gr.Markdown("## Results")
+                    output = gr.Gallery(columns=1, height='auto')     
+
+        with gr.Column():
+            gr.Markdown("Try some of the examples below ⬇️")
+            gr.Examples(
+                examples=examples_pan,
+                inputs=[original_image]
+            )
+        run_button.click(fn=runner, inputs=[original_image, w_direction, w_scale, h_direction, h_scale], outputs=[output, new_image, new_mask])
+        clear_button.click(fn=fun_clear, inputs=[original_image, w_direction, w_scale, h_direction, h_scale, new_image, new_mask, output], 
+        outputs=[original_image, w_direction, w_scale, h_direction, h_scale, new_image, new_mask, output])
+    return demo
+# 04:
+def create_position_size(label=None):
+    image = gr.Image(sources='upload', label=label, interactive=True, type="numpy")
+    with gr.Row():
+        dx = gr.Slider(
+                            label="Left-Right",
+                            minimum=-1,
+                            maximum=1,
+                            step=0.05,
+                            value=0,
+                            interactive=True
+                        )
+        dy = gr.Slider(
+                            label="Down-Up",
+                            minimum=-1,
+                            maximum=1,
+                            step=0.05,
+                            value=0,
+                            interactive=True
+                        )
+    resize_scale = gr.Slider(
+                        label="Resize",
+                        minimum=0,
+                        maximum=2,
+                        step=0.05,
+                        value=1,
+                        interactive=True
+                    )
+    with gr.Row():
+        w_flip = gr.Radio(["left/right","None"], value="None", label="Horizontal Flip")
+        h_flip = gr.Radio(["down/up", "None"], value="None", label="Vertical Flip")
+    return image, dx, dy, resize_scale, w_flip, h_flip
+# 05:
+def create_demo_layer(runner=None):
+    DESCRIPTION = """
+    # 🚩 Multi-Layered selecting 🚩
+
+    ## Usage:
+
+    - Notice that all operations can be achieved using the multi-layered selecting mode.
+    - In particular, you can accomplish multi-object selecting such as adding objects and cross-image composition on this page.
+    - Try some interesting examples given below to understand the usage.
+    - Please <span style="color:blue;">clear<span> the output before running a new example!
+    - We strongly recommend you to read the [original paper](https://arxiv.org/abs/2403.14487) to further explore more uses of multi-layered selecting.
+    """
+    global_points = gr.State([])
+    global_point_label = gr.State([])
+    bg_ori = gr.State(value=None)
+    l1_ori = gr.State(value=None)
+    l2_ori = gr.State(value=None)
+    l3_ori = gr.State(value=None)
+    with gr.Blocks() as demo:
+        gr.Markdown(DESCRIPTION)
+        with gr.Row():
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# INPUT")
+                    gr.Markdown("## Background Image")
+                    bg_img = gr.Image(sources='upload', label="Background", interactive=True, type="numpy")
+                    gr.Markdown("## Layer-1")
+                    l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip = create_position_size(label="Layer-1")
+                    gr.Markdown("## Layer-2")
+                    l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip = create_position_size(label="Layer-2")
+                    gr.Markdown("## Layer-3")
+                    l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip = create_position_size(label="Layer-3")
+                    with gr.Row():
+                        run_button = gr.Button("Edit")
+                        clear_button = gr.Button("Clear")
+
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# Mask")
+                    gr.Markdown("## Background Mask for Removal:")
+                    bg_mask =  gr.Image(sources='upload', label="BG Mask", interactive=True, type="numpy")
+                    gr.Markdown("## Layer-1 Mask:")
+                    l1_mask = gr.Image(sources='upload', label="L1 Mask", interactive=True, type="numpy")
+                    gr.Markdown("## Layer-2 Mask:")
+                    l2_mask = gr.Image(sources='upload', label="L2 Mask", interactive=True, type="numpy")
+                    gr.Markdown("## Layer-3 Mask:")
+                    l3_mask = gr.Image(sources='upload', label="L3 Mask", interactive=True, type="numpy")
+
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# OUTPUT")
+                    gr.Markdown("## Results")
+                    output = gr.Gallery(columns=1, height='auto')    
+
+        with gr.Column():
+            gr.Markdown("Try some of the examples below ⬇️")            
+            gr.Examples(
+                examples=examples_layer,
+                inputs=[
+                bg_img, bg_mask,
+                l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+                l2_img, l2_mask, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
+                l3_img, l3_mask, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
+                ]
+            )
+        bg_img.select(
+                segment_with_points, 
+                inputs=[bg_img, bg_ori, global_points, global_point_label], 
+                outputs=[bg_img, bg_ori, bg_mask, global_points, global_point_label]
+        )
+        l1_img.select(
+                segment_with_points, 
+                inputs=[l1_img, l1_ori, global_points, global_point_label], 
+                outputs=[l1_img, l1_ori, l1_mask, global_points, global_point_label]
+        )
+        l2_img.select(
+                segment_with_points, 
+                inputs=[l2_img, l2_ori, global_points, global_point_label], 
+                outputs=[l2_img, l2_ori, l2_mask, global_points, global_point_label]
+        )
+        l3_img.select(
+                segment_with_points, 
+                inputs=[l3_img, l3_ori, global_points, global_point_label], 
+                outputs=[l3_img, l3_ori, l3_mask, global_points, global_point_label]
+        )
+
+        run_button.click(fn=runner, inputs=[
+        bg_img, 
+        l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, 
+        l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip, 
+        l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
+        bg_mask, l1_mask, l2_mask, l3_mask,
+        bg_ori, l1_ori, l2_ori, l3_ori
+        ], outputs=[output])
+
+        clear_button.click(fn=fun_clear, 
+        inputs=[bg_img, bg_ori, 
+        l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+        l2_img, l2_ori, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
+        l3_img, l3_ori, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
+        bg_mask, l1_mask, l2_mask, l3_mask,
+        global_points, global_point_label, output],
+        outputs=[bg_img, bg_ori, 
+        l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+        l2_img, l2_ori, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
+        l3_img, l3_ori, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
+        bg_mask, l1_mask, l2_mask, l3_mask,
+        global_points, global_point_label, output],            
+        )
+    return demo
+
+# 06:
+def create_demo_mask_box(runner=None):
+    DESCRIPTION = """
+    # 🔧 Mask Preparation 
+    ## Usage:
+    - This page is a tool for you to combine more than one mask.
+    - You can draw a box to mask an object to obtain Masks 1-4.
+    - The merged mask is the union of Masks 1-4.
+    - Please <span style="color:blue;">clear<span> the output before running a new example!
+    """
+    
+    with gr.Blocks() as demo:
+        original_image = gr.State(value=None) 
+        img_with_mask = gr.State(value=None)
+        selected_points = gr.State([])
+        global_points = gr.State([])
+        global_point_label = gr.State([])
+        gr.Markdown(DESCRIPTION)
+        with gr.Row():
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# INPUT")
+                    gr.Markdown("## 1. Select two points for Mask 1")
+                    gr.Markdown("the top left and the bottom right")
+                    img_draw_box_1 = gr.Image(sources='upload', label="Original Image", interactive=True, type="numpy")
+
+                    gr.Markdown("## 2. Select two points for Mask 2")
+                    gr.Markdown("the top left and the bottom right")
+                    img_draw_box_2 = gr.Image(sources='upload', label="Original Image", interactive=True, type="numpy")
+
+                    gr.Markdown("## 3. Select two points for Mask 3")
+                    gr.Markdown("the top left and the bottom right")
+                    img_draw_box_3 = gr.Image(sources='upload', label="Original Image", interactive=True, type="numpy")
+
+                    gr.Markdown("## 4. Select two points for Mask 4")
+                    gr.Markdown("the top left and the bottom right")
+                    img_draw_box_4 = gr.Image(label="Original Image", interactive=True, type="numpy")
+
+                    with gr.Row():
+                        run_button = gr.Button("Edit")
+                        clear_button = gr.Button("Clear")
+
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# Mask")
+                    gr.Markdown("## Mask 1")
+                    mask_1 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy")
+                    gr.Markdown("## Mask 2")
+                    mask_2 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy")
+                    gr.Markdown("## Mask 3")
+                    mask_3 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy")
+                    gr.Markdown("## Mask 4")
+                    mask_4 = gr.Image(sources='upload', label="Mask", interactive=True, type="numpy")
+
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# Merged Mask")
+                    merged_mask = gr.Image(sources='upload', label="Mask of object", interactive=True, type="numpy")  
+
+        with gr.Column():
+            gr.Markdown("Please see the example below. ⬇️")
+            gr.Examples(
+                examples=examples_mask_box,
+                inputs=[
+                    img_draw_box_1, img_draw_box_2, mask_1, mask_2, merged_mask
+                ]
+            )
+        img_draw_box_1.select(
+            segment_with_points, 
+            inputs=[img_draw_box_1, original_image, global_points, global_point_label], 
+            outputs=[img_draw_box_1, original_image, mask_1, global_points, global_point_label]
+        )
+        img_draw_box_2.select(
+            segment_with_points, 
+            inputs=[img_draw_box_2, original_image, global_points, global_point_label], 
+            outputs=[img_draw_box_2, original_image, mask_2, global_points, global_point_label]
+        )
+        img_draw_box_3.select(
+            segment_with_points, 
+            inputs=[img_draw_box_3, original_image, global_points, global_point_label], 
+            outputs=[img_draw_box_3, original_image, mask_3, global_points, global_point_label]
+        )
+        img_draw_box_4.select(
+            segment_with_points, 
+            inputs=[img_draw_box_4, original_image, global_points, global_point_label], 
+            outputs=[img_draw_box_4, original_image, mask_4, global_points, global_point_label]
+        )
+
+        run_button.click(fn=runner, inputs=[mask_1, mask_2, mask_3, mask_4], outputs=[merged_mask])
+        clear_button.click(
+        fn=fun_clear, 
+        inputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, img_draw_box_1, img_draw_box_2, img_draw_box_3, img_draw_box_4, mask_1, mask_2, mask_3, mask_4], 
+        outputs=[original_image, img_with_mask, selected_points, global_points, global_point_label, img_draw_box_1, img_draw_box_2, img_draw_box_3, img_draw_box_4, mask_1, mask_2, mask_3, mask_4, merged_mask]
+    )
+    return demo
+
+def create_demo_moving(runner=None):
+    DESCRIPTION = """
+    # Object Moving, Resizing, and Flipping
+
+    ## Usage:
+    - Upload an image and draw a box around the object to manipulate.
+    - Move the object vertically or horizontally using sliders or by drawing an arrow.
+    - You can select options for moving and flipping the object from a menu.
+    - Please <span style="color:blue;">clear<span> the output before running a new example!
+    """
+
+    selected_points = gr.State([])
+    global_points = gr.State([])
+    global_point_label = gr.State([])
+    bg_ori = gr.State(value=None)
+    l1_ori = gr.State(value=None)
+    with gr.Blocks() as demo:
+        gr.Markdown(DESCRIPTION)
+        with gr.Row():
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# INPUT")
+                    gr.Markdown("## Draw box to mask target object")
+                    bg_img = gr.Image(sources='upload', label="Background", interactive=True, type="numpy")
+                    gr.Markdown("## Draw arrow to describe the movement")
+                    l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip = create_position_size(label="Layer-1")
+                    with gr.Row():
+                        run_button = gr.Button("Edit")
+                        clear_button = gr.Button("Clear")
+
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# Mask")
+                    gr.Markdown("## Background Mask for Removal:")
+                    bg_mask =  gr.Image(sources='upload', label="Mask", interactive=True, type="numpy")
+
+            with gr.Column():
+                with gr.Group():
+                    gr.Markdown("# OUTPUT")
+                    gr.Markdown("## Results")
+                    output = gr.Gallery(columns=1, height='auto')  
+
+        with gr.Column():
+            gr.Markdown("Try some of the examples below ⬇️")
+            gr.Examples(
+                examples=examples_moving,
+                inputs=[
+                bg_img, bg_mask, l1_dx, l1_dy, l1_resize, l1_h_flip, l1_w_flip
+                ]
+            )
+        bg_img.select(
+                segment_with_points, 
+                inputs=[bg_img, bg_ori, global_points, global_point_label], 
+                outputs=[bg_img, bg_ori, bg_mask, global_points, global_point_label]
+        )
+        l1_img.select(
+                get_point_move,
+                [bg_ori, l1_img, selected_points],
+                [l1_img, bg_ori, selected_points, l1_dx, l1_dy],
+        )
+
+        run_button.click(fn=runner, inputs=[
+        bg_img, bg_ori,bg_mask, 
+        l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, selected_points
+        ], outputs=[output])
+
+        clear_button.click(fn=fun_clear, 
+        inputs=[bg_img, bg_ori, bg_mask, l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+        global_points, global_point_label, selected_points, output],
+        outputs=[bg_img, bg_ori, bg_mask, l1_img, l1_ori, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip,
+        global_points, global_point_label, selected_points, output],         
+        )
+    return demo
diff --git a/src/demo/model.py b/src/demo/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92e3c92296a283680eb055c030b0f241efe4a2d
--- /dev/null
+++ b/src/demo/model.py
@@ -0,0 +1,517 @@
+import numpy as np
+import torch
+from diffusers import  DDIMScheduler
+import cv2
+from utils.sdxl import sdxl
+from utils.inversion import Inversion
+import math
+import torch.nn.functional as F
+import utils.utils as utils
+import os 
+import matplotlib.pyplot as plt
+from PIL import Image, ImageDraw, ImageFont
+import spaces
+
+MAX_NUM_WORDS = 77
+
+
+class LayerFusion:   
+    def get_mask(self, maps, alpha, use_pool,x_t):
+        k = 1
+        maps = (maps * alpha).sum(-1).mean(1)
+        if use_pool:
+            maps = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+        mask = F.interpolate(maps, size=(x_t.shape[2:])) #[2, 1, 128, 128]
+        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+        mask=(mask - mask.min ()) / (mask.max () - mask.min ())
+        mask = mask.gt(self.mask_threshold)
+        self.mask=mask
+        mask = mask[:1] + mask
+        return mask 
+
+    def get_one_mask(self, maps, use_pool, x_t, idx_lst, i=None, sav_img=False):
+        k=1
+        if sav_img is False:
+            mask_tot = 0
+            for obj in idx_lst:
+                mask = maps[0, :, :, :, obj].mean(0).reshape(1, 1, 32, 32)
+                if use_pool:
+                    mask = F.max_pool2d(mask, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+                mask = F.interpolate(mask, size=(x_t.shape[2:]))
+                mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+                mask=(mask - mask.min ()) / (mask.max () - mask.min ())
+                mask = mask.gt(self.mask_threshold[int(self.counter/10)])
+                mask_tot |= mask
+            mask = mask_tot  
+            return mask
+        else: 
+            for obj in idx_lst:
+                mask = maps[0, :, :, :, obj].mean(0).reshape(1, 1, 32, 32)
+                if use_pool:
+                    mask = F.max_pool2d(mask, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+                mask = F.interpolate(mask, size=(1024, 1024))#[1, 1, 1024, 1024]
+                mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+                mask=(mask - mask.min ()) / (mask.max () - mask.min ())
+                mask = mask.gt(0.6)  
+                mask = np.array(mask[0][0].clone().cpu()).astype(np.uint8)*255
+                cv2.imwrite(f'./img/sam_mask/{self.blend_list[i][0]}_{self.counter}.jpg', mask)
+        return mask
+
+    def mv_op(self, mp, op, scale=0.2, ones=False, flip=None):
+        _, b, H, W = mp.shape
+        if ones == False:
+            new_mp = torch.zeros_like(mp)
+        else:
+            new_mp = torch.ones_like(mp)
+        K = int(scale*W)
+        if op == 'right':
+            new_mp[:, :, :, K:] = mp[:, :, :, 0:W-K]
+        elif op == 'left':
+            new_mp[:, :, :, 0:W-K] = mp[:, :, :, K:]
+        elif op == 'down':
+            new_mp[:, :, K:, :] = mp[:, :, 0:W-K, :]
+        elif op == 'up':
+            new_mp[:, :, 0:W-K, :] = mp[:, :, K:, :]
+        if flip is not None:
+            new_mp = torch.flip(new_mp, dims=flip)
+               
+        return new_mp
+
+    def mv_layer(self, x_t, bg_id, fg_id, op_id):
+        bg_img = x_t[bg_id:(bg_id+1)].clone()
+        fg_img = x_t[fg_id:(fg_id+1)].clone()
+        fg_mask = self.fg_mask_list[fg_id-3]
+        op_list = self.op_list[fg_id-3]
+
+        for item in op_list:
+            op, scale = item[0], item[1]
+            if scale != 0:
+                fg_img = self.mv_op(fg_img, op=op, scale=scale)
+                fg_mask = self.mv_op(fg_mask, op=op, scale=scale)
+        x_t[op_id:(op_id+1)] = bg_img*(1-fg_mask) + fg_img*fg_mask
+
+    def __call__(self, x_t):
+        self.counter += 1
+        # inpainting
+        if self.blend_time[0] <= self.counter <= self.blend_time[1]:
+            x_t[1:2] = x_t[1:2]*self.remove_mask + x_t[0:1]*(1-self.remove_mask) 
+
+        if self.counter == self.blend_time[1] + 1 and self.mode != "removal":
+            b = x_t.shape[0]
+            bg_id = 1 #bg_layer
+            op_id = 2 #canvas
+            for fg_id in range(3, b): #fg_layer
+                self.mv_layer(x_t, bg_id=bg_id, fg_id=fg_id, op_id=op_id)
+                bg_id = op_id
+    
+        return x_t
+
+    def __init__(self, remove_mask, fg_mask_list, refine_mask=None, 
+                blend_time=[0, 40],
+                 mode="removal", op_list=None):
+        self.counter = 0
+        self.mode = mode
+        self.op_list = op_list
+        self.blend_time = blend_time
+
+        self.remove_mask = remove_mask
+        self.refine_mask = refine_mask
+        if self.refine_mask is not None:
+            self.new_mask = self.remove_mask + self.refine_mask
+            self.new_mask[self.new_mask>0] = 1
+        else:
+            self.new_mask = None
+        self.fg_mask_list = fg_mask_list
+
+
+class Control():
+    def step_callback(self, x_t):
+        if self.layer_fusion is not None:
+             x_t = self.layer_fusion(x_t)
+        return x_t
+    def __init__(self, layer_fusion):
+        self.layer_fusion = layer_fusion
+
+def register_attention_control(model, controller, mask_time=[0, 40], refine_time=[0, 25]):
+    def ca_forward(self, place_in_unet):
+        to_out = self.to_out
+        if type(to_out) is torch.nn.modules.container.ModuleList:
+            to_out = self.to_out[0]
+        else:
+            to_out = self.to_out
+        self.counter = 0 #time
+        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None): #self_attention
+            x = hidden_states.clone() 
+            context = encoder_hidden_states
+            is_cross = context is not None
+            if is_cross is False:
+                if controller.layer_fusion is not None and (mask_time[0] < self.counter < mask_time[1]):
+                    b, i, j = x.shape
+                    H = W = int(math.sqrt(i))
+                    x_old = x.clone()
+                    x = x.reshape(b, H, W, j)
+                    new_mask = controller.layer_fusion.remove_mask
+                    if new_mask is not None:
+                        new_mask[new_mask>0] = 1
+                        new_mask = F.interpolate(new_mask.to(dtype=torch.float32).clone(), size=(H, W), mode='bilinear').cuda()
+                        new_mask =  (1 - new_mask).reshape(1, H, W).unsqueeze(-1)
+                        if (refine_time[0] < self.counter <= refine_time[1]) and controller.layer_fusion.refine_mask is not None:
+                            new_mask = controller.layer_fusion.new_mask
+                            new_mask = F.interpolate(new_mask.to(dtype=torch.float32).clone(), size=(H, W), mode='bilinear').cuda()
+                            new_mask =  (1 - new_mask).reshape(1, H, W).unsqueeze(-1)                
+                        idx = 1 #inpaiint_idx:bg
+                        x[int(b/2)+idx, :, :] = (x[int(b/2)+idx, :, :]*new_mask[0])
+                    x = x.reshape(b, i, j)
+            if is_cross: 
+                q = self.to_q(x) 
+                k = self.to_k(context)
+                v = self.to_v(context)
+            else:
+                context = x
+                q = self.to_q(hidden_states) 
+                k = self.to_k(x) 
+                v = self.to_v(hidden_states)
+            q = self.head_to_batch_dim(q)
+            k = self.head_to_batch_dim(k)
+            v = self.head_to_batch_dim(v)
+
+            if hasattr(controller, 'count_layers'):
+                controller.count_layers(place_in_unet,is_cross)
+            sim = torch.einsum("b i d, b j d -> b i j", q.clone(), k.clone()) * self.scale 
+
+            attn = sim.softmax(dim=-1)
+            out = torch.einsum("b i j, b j d -> b i d", attn, v)
+            out = self.batch_to_head_dim(out)
+            global global_cnt
+            self.counter += 1
+            return to_out(out)
+        
+        return forward
+
+    def register_recr(net_, count, place_in_unet):
+        if net_.__class__.__name__ == 'Attention':
+            net_.forward = ca_forward(net_, place_in_unet)
+            return count + 1
+        elif hasattr(net_, 'children'):
+            for net__ in net_.children():
+                count = register_recr(net__, count, place_in_unet)
+        return count
+
+    cross_att_count = 0
+    sub_nets = model.unet.named_children()
+    for net in sub_nets:
+        if "down" in net[0]:
+            cross_att_count += register_recr(net[1], 0, "down")
+        elif "up" in net[0]:
+            cross_att_count += register_recr(net[1], 0, "up")
+        elif "mid" in net[0]:
+            cross_att_count += register_recr(net[1], 0, "mid")
+
+    controller.num_att_layers = cross_att_count
+
+class DesignEdit():
+    def __init__(self, pretrained_model_path="/home/jyr/model/stable-diffusion-xl-base-1.0"):
+        self.model_dtype = "fp16"
+        self.pretrained_model_path=pretrained_model_path
+        self.num_ddim_steps = 50
+        self.mask_time = [0, 40]
+        self.op_list = {}
+        self.attend_scale = {}
+        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+        if self.model_dtype == "fp16":
+            torch_dtype = torch.float16
+        elif self.model_dtype == "fp32":
+            torch_dtype = torch.float32
+        self.pipe = sdxl.from_pretrained(self.pretrained_model_path, torch_dtype=torch_dtype, use_safetensors=True, variant=self.model_dtype,scheduler=scheduler)
+   
+    @spaces.GPU
+    def init_model(self, num_ddim_steps=50):
+        device = torch.device('cuda:0')
+        self.pipe.to(device)
+        inversion = Inversion(self.pipe,num_ddim_steps)
+        return self.pipe, inversion
+    
+    @spaces.GPU(duration=120, enable_queue=True)
+    def run_remove(self, original_image=None, mask_1=None, mask_2=None, mask_3=None, refine_mask=None, 
+        ori_1=None, ori_2=None, ori_3=None,
+        prompt="", save_dir="./tmp", mode='removal',):
+        # 01-1: 
+        self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+        if original_image is None:
+            original_image = ori_1 if ori_1 is not None else ori_2 if ori_2 is not None else ori_3
+        op_list = None
+        attend_scale = 20
+        sample_ref_match={0 : 0, 1 : 0}
+        ori_shape = original_image.shape
+
+        # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
+        image_gt = Image.fromarray(original_image).resize((1024, 1024))
+        image_gt = np.stack([np.array(image_gt)])
+        mask_list = [mask_1, mask_2, mask_3]
+        remove_mask = utils.attend_mask(utils.add_masks_resized(mask_list), attend_scale=attend_scale) # numpy to tensor
+        fg_mask_list = None
+        refine_mask = utils.attend_mask(utils.convert_and_resize_mask(refine_mask)) if refine_mask is not None else None
+
+        # 01-3: prepare: prompts, blend_time, refine_time
+        prompts = len(sample_ref_match)*[prompt] # 2
+        blend_time = [0, 41]
+        refine_time = [0, 25]
+        
+        # 02: invert
+        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1)
+        
+        # 03: init layer_fusion and controller
+        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, refine_mask=refine_mask,
+                    blend_time=blend_time, mode=mode, op_list=op_list)
+        controller = Control(layer_fusion=lb)
+        register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
+        
+        # 04: generate images
+        images = self.ldm_model(controller=controller, prompt=prompts,
+                        latents=x_t, x_stars=x_stars,  
+                        negative_prompt_embeds=prompt_embeds, 
+                        negative_pooled_prompt_embeds=pooled_prompt_embeds,
+                        sample_ref_match=sample_ref_match)
+        folder = None
+        utils.view_images(images, folder=folder)
+        return [cv2.resize(images[1], (ori_shape[1], ori_shape[0]))]
+    
+    @spaces.GPU(duration=120, enable_queue=True)
+    def run_zooming(self, original_image, width_scale=1, height_scale=1, prompt="", save_dir="./tmp", mode='removal'):
+        self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+        # 01-1: 
+        op_list = {0: ['zooming', [height_scale, width_scale]]}
+        ori_shape = original_image.shape
+        attend_scale = 30
+        sample_ref_match = {0 : 0, 1 : 0}
+
+        # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
+        img_new, mask = utils.zooming(original_image, [height_scale, width_scale])
+        img_new_copy = img_new.copy()
+        mask_copy = mask.copy()
+        
+        image_gt = Image.fromarray(img_new).resize((1024, 1024))
+        image_gt = np.stack([np.array(image_gt)])
+
+        remove_mask = utils.attend_mask(utils.convert_and_resize_mask(mask), attend_scale=attend_scale) # numpy to tensor
+        fg_mask_list = None
+        refine_mask = None
+
+        # 01-3: prepare: prompts, blend_time, refine_time
+        prompts = len(sample_ref_match)*[prompt] # 2
+        blend_time = [0, 41]
+        refine_time = [0, 25]
+
+        # 02: invert
+        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1)
+        
+        # 03: init layer_fusion and controller
+        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time,
+                    mode=mode, op_list=op_list)
+        controller = Control(layer_fusion=lb)
+        register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
+        
+        # 04: generate images
+        images = self.ldm_model(controller=controller, prompt=prompts,
+                        latents=x_t, x_stars=x_stars,  
+                        negative_prompt_embeds=prompt_embeds, 
+                        negative_pooled_prompt_embeds=pooled_prompt_embeds,
+                        sample_ref_match=sample_ref_match)
+        folder = None
+        utils.view_images(images, folder=folder)
+        resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0]))
+        return [resized_img], [img_new_copy], [mask_copy]
+    
+    @spaces.GPU(duration=120, enable_queue=True)
+    def run_panning(self, original_image, w_direction, w_scale, h_direction, h_scale, prompt="", save_dir="./tmp", mode='removal'):
+        # 01-1: prepare: op_list, attend_scale, sample_ref_match
+        self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+        ori_shape = original_image.shape
+        attend_scale = 30
+        sample_ref_match = {0 : 0, 1 : 0}
+
+        # 01-2: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
+        op_list = [[w_direction, w_scale], [h_direction, h_scale]]
+        img_new, mask = utils.panning(original_image, op_list=op_list)
+        img_new_copy = img_new.copy()
+        mask_copy = mask.copy()
+        
+        image_gt = Image.fromarray(img_new).resize((1024, 1024))
+        image_gt = np.stack([np.array(image_gt)])
+        remove_mask = utils.attend_mask(utils.convert_and_resize_mask(mask), attend_scale=attend_scale) # numpy to tensor
+
+        fg_mask_list = None
+        refine_mask = None
+
+        # 01-3: prepare: prompts, blend_time, refine_time
+        prompts = len(sample_ref_match)*[prompt] # 2
+        blend_time = [0, 41]
+        refine_time = [0, 25]
+
+        # 02: invert
+        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=1)
+        # 03: init layer_fusion and controller
+        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time,
+                    mode=mode, op_list=op_list)
+        controller = Control(layer_fusion=lb)
+        register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
+        
+        # 04: generate images
+
+        images = self.ldm_model(controller=controller, prompt=prompts,
+                        latents=x_t, x_stars=x_stars,  
+                        negative_prompt_embeds=prompt_embeds, 
+                        negative_pooled_prompt_embeds=pooled_prompt_embeds,
+                        sample_ref_match=sample_ref_match)
+        folder = None
+        utils.view_images(images, folder=folder)
+        resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0]))
+        return [resized_img], [img_new_copy], [mask_copy]
+
+    # layer-wise multi-object editing
+    def process_layer_states(self, layer_states):
+        self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+        image_paths = []
+        mask_paths = []
+        op_list = []
+        
+        for state in layer_states:
+            img, mask, dx, dy, resize, w_flip, h_flip = state
+            if img is not None:  
+                img = cv2.resize(img, (1024, 1024))
+                mask = utils.convert_and_resize_mask(mask)
+                dx_command = ['right', dx] if dx > 0 else ['left', -dx]
+                dy_command = ['up', dy] if dy > 0 else ['down', -dy]
+                flip_code = None
+                if w_flip == "left/right" and h_flip == "down/up":
+                    flip_code = -1
+                elif w_flip == "left/right":
+                    flip_code = 1  # 或者其他默认值,根据您的需要设置
+                elif h_flip == "down/up":
+                    flip_code = 0
+                op_list.append([dx_command, dy_command])
+                img, mask, _ = utils.resize_image_with_mask(img, mask, resize)
+                img, mask, _ = utils.flip_image_with_mask(img, mask, flip_code=flip_code)
+                image_paths.append(img)
+                mask_paths.append(utils.attend_mask(mask))
+        sample_ref_match = {0: 0, 1: 0, 2: 0, 3: 1, 4: 2, 5: 3}
+        required_length = len(image_paths) + 3
+        truncated_sample_ref_match = {k: sample_ref_match[k] for k in sorted(sample_ref_match.keys())[:required_length]}
+        return image_paths, mask_paths, op_list, truncated_sample_ref_match
+
+    @spaces.GPU(duration=200)
+    def run_layer(self, bg_img, l1_img, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip, 
+        l2_img, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip,
+        l3_img, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip,
+        bg_mask, l1_mask, l2_mask, l3_mask,
+        bg_ori=None, l1_ori=None, l2_ori=None, l3_ori=None,
+        prompt="", save_dir="./tmp", mode='layerwise'):
+        self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+        # 00: prepare: layer-wise states
+        bg_img = bg_ori if bg_ori is not None else bg_img
+        l1_img = l1_ori if l1_ori is not None else l1_img
+        l2_img = l2_ori if l2_ori is not None else l2_img
+        l3_img = l3_ori if l3_ori is not None else l3_img
+        for mask in [bg_mask, l1_mask, l2_mask, l3_mask]:
+            if mask is None:
+                mask = np.zeros((1024, 1024), dtype=np.uint8)
+            else:
+                mask = utils.convert_and_resize_mask(mask)
+        l1_state = [l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip]
+        l2_state = [l2_img, l2_mask, l2_dx, l2_dy, l2_resize, l2_w_flip, l2_h_flip]
+        l3_state = [l3_img, l3_mask, l3_dx, l3_dy, l3_resize, l3_w_flip, l3_h_flip]
+        ori_shape = bg_img.shape
+
+        image_paths, fg_mask_list, op_list, sample_ref_match = self.process_layer_states([l1_state, l2_state, l3_state])
+        if image_paths == []:
+            mode = "removal"
+        # 01-1: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
+        attend_scale = 20
+        image_gt = [bg_img] + image_paths
+        image_gt = [Image.fromarray(img).resize((1024, 1024)) for img in image_gt]
+        image_gt = np.stack(image_gt)      
+        remove_mask = utils.attend_mask(bg_mask, attend_scale=attend_scale)
+        refine_mask = None
+
+        # 01-2: prepare: promptrun_masks, blend_time, refine_time
+        prompts = len(sample_ref_match)*[prompt] # 2
+        blend_time = [0, 41]
+        refine_time = [0, 25]
+        attend_scale = []
+
+        # 02: invert
+        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=len(image_gt))
+        # 03: init layer_fusion and controller
+        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time, refine_mask=refine_mask,
+                    mode=mode, op_list=op_list)
+        controller = Control(layer_fusion=lb)
+        register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
+        # 04: generate images
+        images = self.ldm_model(controller=controller, prompt=prompts,
+                        latents=x_t, x_stars=x_stars,  
+                        negative_prompt_embeds=prompt_embeds, 
+                        negative_pooled_prompt_embeds=pooled_prompt_embeds,
+                        sample_ref_match=sample_ref_match)
+        folder = None
+        utils.view_images(images, folder=folder) 
+        if mode == 'removal':
+            resized_img = cv2.resize(images[1], (ori_shape[1], ori_shape[0]))       
+        else:
+            resized_img = cv2.resize(images[2], (ori_shape[1], ori_shape[0]))       
+        return [resized_img]
+
+    @spaces.GPU(duration=120, enable_queue=True)
+    def run_moving(self, bg_img, bg_ori, bg_mask, l1_dx, l1_dy, l1_resize, 
+        l1_w_flip=None, l1_h_flip=None, selected_points=None,
+        prompt="", save_dir="./tmp", mode='layerwise'):
+        self.ldm_model, self.inversion= self.init_model(num_ddim_steps=self.num_ddim_steps)
+        # 00: prepare: layer-wise states
+        bg_img = bg_ori if bg_ori is not None else bg_img
+        l1_img = bg_img
+        if bg_mask is None:
+            bg_mask = np.zeros((1024, 1024), dtype=np.uint8)
+        else:
+            bg_mask = utils.convert_and_resize_mask(bg_mask)
+        l1_mask = bg_mask
+        l1_state = [l1_img, l1_mask, l1_dx, l1_dy, l1_resize, l1_w_flip, l1_h_flip]
+        ori_shape = bg_img.shape
+
+        image_paths, fg_mask_list, op_list, sample_ref_match = self.process_layer_states([l1_state])
+
+        # 01-1: prepare: image_gt, remove_mask, fg_mask_list, refine_mask
+        attend_scale = 20
+        image_gt = [bg_img] + image_paths
+        image_gt = [Image.fromarray(img).resize((1024, 1024)) for img in image_gt]
+        image_gt = np.stack(image_gt)      
+        remove_mask = utils.attend_mask(bg_mask, attend_scale=attend_scale)
+        refine_mask = None
+
+        # 01-2: prepare: promptrun_masks, blend_time, refine_time
+        prompts = len(sample_ref_match)*[prompt] # 2
+        blend_time = [0, 41]
+        refine_time = [0, 25]
+        attend_scale = []
+
+        # 02: invert
+        _, x_t, x_stars, prompt_embeds, pooled_prompt_embeds = self.inversion.invert(image_gt, prompts, inv_batch_size=len(image_gt))
+        # 03: init layer_fusion and controller
+        lb = LayerFusion(remove_mask=remove_mask, fg_mask_list=fg_mask_list, blend_time=blend_time, refine_mask=refine_mask,
+                    mode=mode, op_list=op_list)
+        controller = Control(layer_fusion=lb)
+        register_attention_control(model=self.ldm_model, controller=controller, mask_time=self.mask_time, refine_time=refine_time)
+        # 04: generate images
+        images = self.ldm_model(controller=controller, prompt=prompts,
+                        latents=x_t, x_stars=x_stars,  
+                        negative_prompt_embeds=prompt_embeds, 
+                        negative_pooled_prompt_embeds=pooled_prompt_embeds,
+                        sample_ref_match=sample_ref_match)
+        folder = None
+        utils.view_images(images, folder=folder) 
+        resized_img = cv2.resize(images[2], (ori_shape[1], ori_shape[0]))       
+        return [resized_img]
+
+    # turn mask to 1024x1024 unit-8
+    def run_mask(self, mask_1, mask_2, mask_3, mask_4):
+        mask_list = [mask_1, mask_2, mask_3, mask_4]
+        final_mask = utils.add_masks_resized(mask_list)
+        return final_mask
\ No newline at end of file
diff --git a/src/demo/utils.py b/src/demo/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d0c789365ca94a89a9e71874999b73398f682f9
--- /dev/null
+++ b/src/demo/utils.py
@@ -0,0 +1,319 @@
+import numpy as np
+import gradio as gr
+import cv2
+from copy import deepcopy
+import torch
+from torchvision import transforms
+from PIL import Image, ImageDraw, ImageFont
+
+from sam.efficient_sam.build_efficient_sam import build_efficient_sam_vits
+from src.utils.utils import resize_numpy_image
+
+sam = build_efficient_sam_vits()
+
+def show_point_or_box(image, global_points):
+    # for point
+    if len(global_points) == 1:
+        image = cv2.circle(image, global_points[0], 10, (0, 0, 255), -1)
+    # for box
+    if len(global_points) == 2:
+        p1 = global_points[0]
+        p2 = global_points[1]
+        image = cv2.rectangle(image,(int(p1[0]),int(p1[1])),(int(p2[0]),int(p2[1])),(0,0,255),2)
+
+    return image
+    
+def segment_with_points(
+    image,
+    original_image,
+    global_points,
+    global_point_label,
+    evt: gr.SelectData,
+    img_direction,
+    save_dir = "./tmp"
+):
+    if original_image is None:
+        original_image = image
+    else:
+        image = original_image
+    if img_direction is None:
+        img_direction = original_image
+    x, y = evt.index[0], evt.index[1]
+    image_path = None
+    mask_path = None
+    if len(global_points) == 0:
+        global_points.append([x, y])
+        global_point_label.append(2)
+        image_with_point= show_point_or_box(image.copy(), global_points)
+        return image_with_point, original_image, None, global_points, global_point_label
+    elif len(global_points) == 1:
+        global_points.append([x, y])
+        global_point_label.append(3)
+        x1, y1 = global_points[0]
+        x2, y2 = global_points[1]
+        if x1 < x2 and y1 >= y2:
+            global_points[0][0] = x1
+            global_points[0][1] = y2
+            global_points[1][0] = x2
+            global_points[1][1] = y1
+        elif x1 >= x2 and y1 < y2:
+            global_points[0][0] = x2
+            global_points[0][1] = y1
+            global_points[1][0] = x1
+            global_points[1][1] = y2
+        elif x1 >= x2 and y1 >= y2:
+            global_points[0][0] = x2
+            global_points[0][1] = y2
+            global_points[1][0] = x1
+            global_points[1][1] = y1
+        image_with_point = show_point_or_box(image.copy(), global_points)
+        # data process
+        input_point = np.array(global_points)
+        input_label = np.array(global_point_label)
+        pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2])
+        pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1])
+        img_tensor = transforms.ToTensor()(image)
+        # sam
+        predicted_logits, predicted_iou = sam(
+            img_tensor[None, ...],
+            pts_sampled,
+            pts_labels,
+        )
+        mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().detach().numpy()
+        mask_image = (mask*255.).astype(np.uint8)
+        return image_with_point, original_image, mask_image, global_points, global_point_label
+    else:
+        global_points=[[x, y]]
+        global_point_label=[2]
+        image_with_point= show_point_or_box(image.copy(), global_points)
+        return image_with_point, original_image, None, global_points, global_point_label
+
+
+def segment_with_points_paste(
+    image,
+    original_image,
+    global_points,
+    global_point_label,
+    image_b,
+    evt: gr.SelectData,
+    dx, 
+    dy, 
+    resize_scale
+
+):
+    if original_image is None:
+        original_image = image
+    else:
+        image = original_image
+    x, y = evt.index[0], evt.index[1]
+    if len(global_points) == 0:
+        global_points.append([x, y])
+        global_point_label.append(2)
+        image_with_point= show_point_or_box(image.copy(), global_points)
+        return image_with_point, original_image, None, global_points, global_point_label, None
+    elif len(global_points) == 1:
+        global_points.append([x, y])
+        global_point_label.append(3)
+        x1, y1 = global_points[0]
+        x2, y2 = global_points[1]
+        if x1 < x2 and y1 >= y2:
+            global_points[0][0] = x1
+            global_points[0][1] = y2
+            global_points[1][0] = x2
+            global_points[1][1] = y1
+        elif x1 >= x2 and y1 < y2:
+            global_points[0][0] = x2
+            global_points[0][1] = y1
+            global_points[1][0] = x1
+            global_points[1][1] = y2
+        elif x1 >= x2 and y1 >= y2:
+            global_points[0][0] = x2
+            global_points[0][1] = y2
+            global_points[1][0] = x1
+            global_points[1][1] = y1
+        image_with_point = show_point_or_box(image.copy(), global_points)
+        # data process
+        input_point = np.array(global_points)
+        input_label = np.array(global_point_label)
+        pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2])
+        pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1])
+        img_tensor = transforms.ToTensor()(image)
+        # sam
+        predicted_logits, predicted_iou = sam(
+            img_tensor[None, ...],
+            pts_sampled,
+            pts_labels,
+        )
+        mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().detach().numpy()
+        mask_uint8 = (mask*255.).astype(np.uint8)
+
+        return image_with_point, original_image, paste_with_mask_and_offset(image, image_b, mask_uint8, dx, dy, resize_scale), global_points, global_point_label, mask_uint8
+    else:
+        global_points=[[x, y]]
+        global_point_label=[2]
+        image_with_point= show_point_or_box(image.copy(), global_points)
+        return image_with_point, original_image, None, global_points, global_point_label, None
+
+def paste_with_mask_and_offset(image_a, image_b, mask, x_offset=0, y_offset=0, delta=1):
+    try:
+        numpy_mask = np.array(mask)
+        y_coords, x_coords = np.nonzero(numpy_mask)  
+        x_min = x_coords.min()  
+        x_max = x_coords.max()  
+        y_min = y_coords.min()  
+        y_max = y_coords.max()
+        target_center_x = int((x_min + x_max) / 2)
+        target_center_y = int((y_min + y_max) / 2)
+
+        image_a = Image.fromarray(image_a)
+        image_b = Image.fromarray(image_b)
+        mask = Image.fromarray(mask)
+
+        if image_a.size != mask.size:
+            mask = mask.resize(image_a.size)
+
+        cropped_image = Image.composite(image_a, Image.new('RGBA', image_a.size, (0, 0, 0, 0)), mask)
+        x_b = int(target_center_x * (image_b.width / cropped_image.width))
+        y_b = int(target_center_y * (image_b.height / cropped_image.height))
+        x_offset = x_offset - int((delta - 1) * x_b)
+        y_offset = y_offset - int((delta - 1) * y_b)
+        cropped_image = cropped_image.resize(image_b.size)
+        new_size = (int(cropped_image.width * delta), int(cropped_image.height * delta))
+        cropped_image = cropped_image.resize(new_size)
+        image_b.putalpha(128) 
+        result_image = Image.new('RGBA', image_b.size, (0, 0, 0, 0))
+        result_image.paste(image_b, (0, 0))
+        result_image.paste(cropped_image, (x_offset, y_offset), mask=cropped_image)
+
+        return result_image
+    except:
+        return None
+
+def upload_image_move(img, original_image):
+    if original_image is not None:
+        return original_image
+    else:
+        return img
+
+def fun_clear(*args):
+    result = []
+    for arg in args:
+        if isinstance(arg, list):
+            result.append([])
+        else:
+            result.append(None)
+    return tuple(result)
+
+def clear_points(img):
+    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
+    if mask.sum() > 0:
+        mask = np.uint8(mask > 0)
+        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
+    else:
+        masked_img = image.copy()
+
+    return [], masked_img
+
+def get_point(img, sel_pix, evt: gr.SelectData):
+    sel_pix.append(evt.index)
+    points = []
+    for idx, point in enumerate(sel_pix):
+        if idx % 2 == 0:
+            cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
+        else:
+            cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
+        points.append(tuple(point))
+        if len(points) == 2:
+            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
+            points = []
+    return img if isinstance(img, np.ndarray) else np.array(img)
+
+def calculate_translation_percentage(ori_shape, selected_points):
+    dx = selected_points[1][0] - selected_points[0][0]
+    dy = selected_points[1][1] - selected_points[0][1]
+    dx_percentage = dx / ori_shape[1]
+    dy_percentage = dy / ori_shape[0]
+    
+    return dx_percentage, dy_percentage
+
+def get_point_move(original_image, img, sel_pix, evt: gr.SelectData):
+    if original_image is not None:
+        img = original_image.copy()
+    else:
+        original_image = img.copy()
+    if len(sel_pix)<2:
+        sel_pix.append(evt.index)
+    else:
+        sel_pix = [evt.index]
+    points = []
+    dx, dy = 0, 0
+    for idx, point in enumerate(sel_pix):
+        if idx % 2 == 0:
+            cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
+        else:
+            cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
+        points.append(tuple(point))
+        if len(points) == 2:
+            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
+            ori_shape = original_image.shape
+            dx, dy = calculate_translation_percentage(original_image.shape, sel_pix)
+            points = []
+    img = np.array(img)
+
+    return img, original_image, sel_pix, dx, dy
+
+def store_img(img):
+    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
+    if mask.sum() > 0:
+        mask = np.uint8(mask > 0)
+        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
+    else:
+        masked_img = image.copy()
+
+    return image, masked_img, mask
+# im["background"], im["layers"][0]
+def store_img_move(img, mask=None):
+    if mask is not None:
+        image = img["background"]
+        return image, None, mask
+    image, mask = img["background"], np.float32(["layers"][0][:, :, 0]) / 255.
+    if mask.sum() > 0:
+        mask = np.uint8(mask > 0)
+        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
+    else:
+        masked_img = image.copy()
+
+    return image, masked_img, (mask*255.).astype(np.uint8)
+
+def store_img_move_old(img, mask=None):
+    if mask is not None:
+        image = img["image"]
+        return image, None, mask
+    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
+    if mask.sum() > 0:
+        mask = np.uint8(mask > 0)
+        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
+    else:
+        masked_img = image.copy()
+
+    return image, masked_img, (mask*255.).astype(np.uint8)
+
+def mask_image(image, mask, color=[255,0,0], alpha=0.5, max_resolution=None):
+    """ Overlay mask on image for visualization purpose. 
+    Args:
+        image (H, W, 3) or (H, W): input image
+        mask (H, W): mask to be overlaid
+        color: the color of overlaid mask
+        alpha: the transparency of the mask
+    """
+    if max_resolution is not None:
+        image, _ = resize_numpy_image(image, max_resolution*max_resolution)
+        mask = cv2.resize(mask, (image.shape[1], image.shape[0]),interpolation=cv2.INTER_NEAREST)
+
+    out = deepcopy(image)
+    img = deepcopy(image)
+    img[mask == 1] = color
+    out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out)
+    contours = cv2.findContours(np.uint8(deepcopy(mask)), cv2.RETR_TREE, 
+                        cv2.CHAIN_APPROX_SIMPLE)[-2:]
+    return out
\ No newline at end of file
diff --git a/src/utils/utils.py b/src/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b127a81c0f3f87a58bbab638cd032da25cdb5d5c
--- /dev/null
+++ b/src/utils/utils.py
@@ -0,0 +1,240 @@
+import numpy as np
+import cv2
+from basicsr.utils import img2tensor
+import torch
+import torch.nn.functional as F
+
+def resize_numpy_image(image, max_resolution=768 * 768, resize_short_edge=None):
+    h, w = image.shape[:2]
+    w_org = image.shape[1]
+    if resize_short_edge is not None:
+        k = resize_short_edge / min(h, w)
+    else:
+        k = max_resolution / (h * w)
+        k = k**0.5
+    h = int(np.round(h * k / 64)) * 64
+    w = int(np.round(w * k / 64)) * 64
+    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
+    scale = w/w_org
+    return image, scale
+
+def split_ldm(ldm):
+    x = []
+    y = []
+    for p in ldm:
+        x.append(p[0])
+        y.append(p[1])
+    return x,y
+
+def process_move(path_mask, h, w, dx, dy, scale, input_scale, resize_scale, up_scale, up_ft_index, w_edit, w_content, w_contrast, w_inpaint,  precision, path_mask_ref=None):
+    dx, dy = dx*input_scale, dy*input_scale
+    if isinstance(path_mask, str):
+        mask_x0 = cv2.imread(path_mask)
+    else:
+        mask_x0 = path_mask
+    mask_x0 = cv2.resize(mask_x0, (h, w))
+    if path_mask_ref is not None:
+        if isinstance(path_mask_ref, str):
+            mask_x0_ref = cv2.imread(path_mask_ref)
+        else:
+            mask_x0_ref = path_mask_ref
+        mask_x0_ref = cv2.resize(mask_x0_ref, (h, w))
+    else:
+        mask_x0_ref=None
+
+    mask_x0 = img2tensor(mask_x0)[0]
+    mask_x0 = (mask_x0>0.5).float().to('cuda', dtype=precision)
+    if mask_x0_ref is not None:
+        mask_x0_ref = img2tensor(mask_x0_ref)[0]
+        mask_x0_ref = (mask_x0_ref>0.5).float().to('cuda', dtype=precision)
+    mask_org = F.interpolate(mask_x0[None,None], (int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)))>0.5
+
+    mask_tar = F.interpolate(mask_x0[None,None], (int(mask_x0.shape[-2]//scale*resize_scale), int(mask_x0.shape[-1]//scale*resize_scale)))>0.5
+    mask_cur = torch.roll(mask_tar, (int(dy//scale*resize_scale), int(dx//scale*resize_scale)), (-2,-1))
+    
+    pad_size_x = abs(mask_tar.shape[-1]-mask_org.shape[-1])//2
+    pad_size_y = abs(mask_tar.shape[-2]-mask_org.shape[-2])//2
+    if resize_scale>1:
+        sum_before = torch.sum(mask_cur)
+        mask_cur = mask_cur[:,:,pad_size_y:pad_size_y+mask_org.shape[-2],pad_size_x:pad_size_x+mask_org.shape[-1]]
+        sum_after = torch.sum(mask_cur)
+        if sum_after != sum_before:
+            raise ValueError('Resize out of bounds, exiting.')
+    else:
+        temp = torch.zeros(1,1,mask_org.shape[-2], mask_org.shape[-1]).to(mask_org.device)
+        temp[:,:,pad_size_y:pad_size_y+mask_cur.shape[-2],pad_size_x:pad_size_x+mask_cur.shape[-1]]=mask_cur
+        mask_cur =temp>0.5
+
+    mask_other = (1-((mask_cur+mask_org)>0.5).float())>0.5
+    mask_overlap = ((mask_cur.float()+mask_org.float())>1.5).float()
+    mask_non_overlap = (mask_org.float()-mask_overlap)>0.5
+
+    return {
+        "mask_x0":mask_x0, 
+        "mask_x0_ref":mask_x0_ref, 
+        "mask_tar":mask_tar, 
+        "mask_cur":mask_cur, 
+        "mask_other":mask_other, 
+        "mask_overlap":mask_overlap, 
+        "mask_non_overlap":mask_non_overlap, 
+        "up_scale":up_scale,
+        "up_ft_index":up_ft_index,
+        "resize_scale":resize_scale,
+        "w_edit":w_edit,
+        "w_content":w_content,
+        "w_contrast":w_contrast,
+        "w_inpaint":w_inpaint, 
+    }
+
+def process_drag_face(h, w, x, y, x_cur, y_cur, scale, input_scale, up_scale, up_ft_index, w_edit, w_inpaint, precision):
+    for i in range(len(x)):
+        x[i] = int(x[i]*input_scale)
+        y[i] = int(y[i]*input_scale)
+        x_cur[i] = int(x_cur[i]*input_scale)
+        y_cur[i] = int(y_cur[i]*input_scale)
+
+    mask_tar = []
+    for p_idx in range(len(x)):
+        mask_i = torch.zeros(int(h//scale), int(w//scale)).cuda()
+        y_clip = int(np.clip(y[p_idx]//scale, 1, mask_i.shape[0]-2))
+        x_clip = int(np.clip(x[p_idx]//scale, 1, mask_i.shape[1]-2))
+        mask_i[y_clip-1:y_clip+2,x_clip-1:x_clip+2]=1
+        mask_i = mask_i>0.5
+        mask_tar.append(mask_i)
+    mask_cur = []
+    for p_idx in range(len(x_cur)):
+        mask_i = torch.zeros(int(h//scale), int(w//scale)).cuda()
+        y_clip = int(np.clip(y_cur[p_idx]//scale, 1, mask_i.shape[0]-2))
+        x_clip = int(np.clip(x_cur[p_idx]//scale, 1, mask_i.shape[1]-2))
+        mask_i[y_clip-1:y_clip+2,x_clip-1:x_clip+2]=1
+        mask_i=mask_i>0.5
+        mask_cur.append(mask_i)
+
+    return {
+        "mask_tar":mask_tar,
+        "mask_cur":mask_cur,
+        "up_scale":up_scale,
+        "up_ft_index":up_ft_index,
+        "w_edit": w_edit,
+        "w_inpaint": w_inpaint,
+    }
+
+def process_drag(path_mask, h, w, x, y, x_cur, y_cur, scale, input_scale, up_scale, up_ft_index, w_edit, w_inpaint, w_content, precision, latent_in):
+    if isinstance(path_mask, str):
+        mask_x0 = cv2.imread(path_mask)
+    else:
+        mask_x0 = path_mask
+    mask_x0 = cv2.resize(mask_x0, (h, w))
+    mask_x0 = img2tensor(mask_x0)[0]
+    dict_mask = {}
+    dict_mask['base'] = mask_x0
+    mask_x0 = (mask_x0>0.5).float().to('cuda', dtype=precision)
+
+    mask_other = F.interpolate(mask_x0[None,None], (int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)))<0.5
+    mask_tar = []
+    mask_cur = []
+    for p_idx in range(len(x)):
+        mask_tar_i = torch.zeros(int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)).to('cuda', dtype=precision)
+        mask_cur_i = torch.zeros(int(mask_x0.shape[-2]//scale), int(mask_x0.shape[-1]//scale)).to('cuda', dtype=precision)
+        y_tar_clip = int(np.clip(y[p_idx]//scale, 1, mask_tar_i.shape[0]-2))
+        x_tar_clip = int(np.clip(x[p_idx]//scale, 1, mask_tar_i.shape[0]-2))
+        y_cur_clip = int(np.clip(y_cur[p_idx]//scale, 1, mask_cur_i.shape[0]-2))
+        x_cur_clip = int(np.clip(x_cur[p_idx]//scale, 1, mask_cur_i.shape[0]-2))
+        mask_tar_i[y_tar_clip-1:y_tar_clip+2,x_tar_clip-1:x_tar_clip+2]=1
+        mask_cur_i[y_cur_clip-1:y_cur_clip+2,x_cur_clip-1:x_cur_clip+2]=1
+        mask_tar_i = mask_tar_i>0.5
+        mask_cur_i=mask_cur_i>0.5
+        mask_tar.append(mask_tar_i)
+        mask_cur.append(mask_cur_i)
+        latent_in[:,:,y_cur_clip//up_scale-1:y_cur_clip//up_scale+2, x_cur_clip//up_scale-1:x_cur_clip//up_scale+2] = latent_in[:,:, y_tar_clip//up_scale-1:y_tar_clip//up_scale+2, x_tar_clip//up_scale-1:x_tar_clip//up_scale+2] 
+        
+
+    return {
+        "dict_mask":dict_mask,
+        "mask_x0":mask_x0,
+        "mask_tar":mask_tar,
+        "mask_cur":mask_cur,
+        "mask_other":mask_other,
+        "up_scale":up_scale,
+        "up_ft_index":up_ft_index,
+        "w_edit": w_edit,
+        "w_inpaint": w_inpaint,
+        "w_content": w_content,
+        "latent_in":latent_in,
+    }
+
+def process_appearance(path_mask, path_mask_replace, h, w, scale, input_scale, up_scale, up_ft_index, w_edit, w_content, precision):
+    if isinstance(path_mask, str):
+        mask_base = cv2.imread(path_mask)
+    else:
+        mask_base = path_mask
+    mask_base = cv2.resize(mask_base, (h, w))
+    if isinstance(path_mask_replace, str):
+        mask_replace = cv2.imread(path_mask_replace)
+    else:
+        mask_replace = path_mask_replace
+    mask_replace = cv2.resize(mask_replace, (h, w))
+
+    dict_mask = {}
+    mask_base = img2tensor(mask_base)[0]
+    dict_mask['base'] = mask_base
+    mask_base = (mask_base>0.5).to('cuda', dtype=precision)
+    mask_replace = img2tensor(mask_replace)[0]
+    dict_mask['replace'] = mask_replace
+    mask_replace = (mask_replace>0.5).to('cuda', dtype=precision)
+
+    mask_base_cur = F.interpolate(mask_base[None,None], (int(mask_base.shape[-2]//scale), int(mask_base.shape[-1]//scale)))>0.5
+    mask_replace_cur = F.interpolate(mask_replace[None,None], (int(mask_replace.shape[-2]//scale), int(mask_replace.shape[-1]//scale)))>0.5
+
+    return {
+        "dict_mask":dict_mask,
+        "mask_base_cur":mask_base_cur,
+        "mask_replace_cur":mask_replace_cur,
+        "up_scale":up_scale,
+        "up_ft_index":up_ft_index,
+        "w_edit":w_edit,
+        "w_content":w_content,
+    }
+
+def process_paste(path_mask, h, w, dx, dy, scale, input_scale, up_scale, up_ft_index, w_edit, w_content, precision, resize_scale=None):
+    dx, dy = dx*input_scale, dy*input_scale
+    if isinstance(path_mask, str):
+        mask_base = cv2.imread(path_mask)
+    else:
+        mask_base = path_mask
+    mask_base = cv2.resize(mask_base, (h, w))
+
+    dict_mask = {}
+    mask_base = img2tensor(mask_base)[0][None, None]
+    mask_base = (mask_base>0.5).to('cuda', dtype=precision)
+    if resize_scale is not None and resize_scale!=1:
+        hi, wi = mask_base.shape[-2], mask_base.shape[-1]
+        mask_base = F.interpolate(mask_base, (int(hi*resize_scale), int(wi*resize_scale)))
+        pad_size_x = np.abs(mask_base.shape[-1]-wi)//2
+        pad_size_y = np.abs(mask_base.shape[-2]-hi)//2
+        if resize_scale>1:
+            mask_base = mask_base[:,:,pad_size_y:pad_size_y+hi,pad_size_x:pad_size_x+wi]
+        else:
+            temp = torch.zeros(1,1,hi, wi).to(mask_base.device)
+            temp[:,:,pad_size_y:pad_size_y+mask_base.shape[-2],pad_size_x:pad_size_x+mask_base.shape[-1]]=mask_base
+            mask_base = temp
+    mask_replace = mask_base.clone()
+    mask_base = torch.roll(mask_base, (int(dy), int(dx)), (-2,-1))
+    dict_mask['base'] = mask_base[0,0]
+    dict_mask['replace'] = mask_replace[0,0]
+    mask_replace = (mask_replace>0.5).to('cuda', dtype=precision)
+
+    mask_base_cur = F.interpolate(mask_base, (int(mask_base.shape[-2]//scale), int(mask_base.shape[-1]//scale)))>0.5
+    mask_replace_cur = torch.roll(mask_base_cur, (-int(dy/scale), -int(dx/scale)), (-2,-1))
+
+    return {
+        "dict_mask":dict_mask,
+        "mask_base_cur":mask_base_cur,
+        "mask_replace_cur":mask_replace_cur,
+        "up_scale":up_scale,
+        "up_ft_index":up_ft_index,
+        "w_edit":w_edit,
+        "w_content":w_content,
+        "w_edit":w_edit,
+        "w_content":w_content,
+    }
\ No newline at end of file
diff --git a/utils/inversion.py b/utils/inversion.py
new file mode 100755
index 0000000000000000000000000000000000000000..16ce1796d3a905a1596300b39cdfa940d49d0e15
--- /dev/null
+++ b/utils/inversion.py
@@ -0,0 +1,265 @@
+import torch
+import numpy as np
+from PIL import Image
+from typing import Optional, Union, Tuple, List
+from tqdm import tqdm
+import os
+from diffusers import DDIMInverseScheduler,DPMSolverMultistepInverseScheduler
+import spaces
+
+class Inversion:
+
+    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
+                  sample: Union[torch.FloatTensor, np.ndarray]):
+        timestep, next_timestep = min(
+            timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
+        beta_prod_t = 1 - alpha_prod_t
+        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+        return next_sample
+    
+    @torch.no_grad()
+    def get_noise_pred_single(self, latents, t, context,cond=True,both=False):
+        added_cond_id=1 if cond else 0
+        do_classifier_free_guidance=False
+        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+        if both is False:
+            added_cond_kwargs = {"text_embeds": self.add_text_embeds[added_cond_id].unsqueeze(0).repeat(self.inv_batch_size,1), "time_ids": self.add_time_ids[added_cond_id].unsqueeze(0).repeat(self.inv_batch_size,1)}
+        else:
+            added_cond_kwargs = {"text_embeds": self.add_text_embeds, "time_ids": self.add_time_ids}
+        noise_pred = self.model.unet(
+            latent_model_input,
+            t,
+            encoder_hidden_states=context,
+            cross_attention_kwargs=None,
+            added_cond_kwargs=added_cond_kwargs,
+            return_dict=False,
+        )[0]
+        return noise_pred
+
+    @torch.no_grad()
+    def latent2image(self, latents, return_type='np'):
+        latents = 1 / self.model.vae.config.scaling_factor * latents.detach()
+        self.model.vae.to(dtype=torch.float32)
+        image = self.model.vae.decode(latents)['sample']
+        if return_type == 'np':
+            image = (image / 2 + 0.5).clamp(0, 1)
+            image = image.cpu().permute(0, 2, 3, 1).numpy()
+            image = (image * 255).astype(np.uint8)
+        return image
+
+    @torch.no_grad()
+    @spaces.GPU
+    def image2latent(self, image):
+        with torch.no_grad():
+            if type(image) is Image:
+                image = np.array(image)
+            else:
+                if image.ndim==3:
+                    image=np.expand_dims(image,0)
+                image = torch.from_numpy(image).float() / 127.5 - 1
+                image = image.permute(0, 3, 1, 2).to(self.device)
+                print(f"Running on device: {self.device}")
+                latents=[]
+                for i,_ in enumerate(image):
+                    latent=self.model.vae.encode(image[i:i+1])['latent_dist'].mean
+                    latents.append(latent)
+                latents=torch.stack(latents).squeeze(1)
+                latents = latents * self.model.vae.config.scaling_factor
+        return latents
+
+    @torch.no_grad()
+    def init_prompt(
+        self,
+        prompt:  Union[str, List[str]],
+        original_size: Optional[Tuple[int, int]] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Optional[Tuple[int, int]] = None,
+    ):
+        original_size = original_size or (1024, 1024)
+        target_size = target_size or (1024, 1024)
+        # 3. Encode input prompt
+        do_classifier_free_guidance=True
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.model.encode_prompt_not_zero_uncond(
+            prompt,
+            self.model.device,
+            1,
+            do_classifier_free_guidance,
+            negative_prompt=None,
+            prompt_embeds=None,
+            negative_prompt_embeds=None,
+            pooled_prompt_embeds=None,
+            negative_pooled_prompt_embeds=None,
+            lora_scale=None,
+        )
+        prompt_embeds=prompt_embeds[:self.inv_batch_size]
+        negative_prompt_embeds=negative_prompt_embeds[:self.inv_batch_size]
+        pooled_prompt_embeds=pooled_prompt_embeds[:self.inv_batch_size]
+        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[:self.inv_batch_size]
+        # 7. Prepare added time ids & embeddings
+        add_text_embeds = pooled_prompt_embeds
+        add_time_ids = self.model._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+        )
+
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(self.device)
+        self.add_text_embeds = add_text_embeds.to(self.device)
+        self.add_time_ids = add_time_ids.to(self.device).repeat(self.inv_batch_size * 1, 1)
+
+        self.prompt_embeds=prompt_embeds
+        self.negative_prompt_embeds=negative_prompt_embeds
+        self.pooled_prompt_embeds=pooled_prompt_embeds
+        self.negative_pooled_prompt_embeds=negative_pooled_prompt_embeds
+        self.prompt = prompt
+        self.context=prompt_embeds
+
+    @torch.no_grad()
+    @spaces.GPU
+    def ddim_loop(self, latent):
+        uncond_embeddings, cond_embeddings = self.context.chunk(2)
+        all_latent = [latent]
+        latent = latent.clone().detach()
+        extra_step_kwargs = self.model.prepare_extra_step_kwargs(self.generator, self.eta)
+        if isinstance(self.inverse_scheduler,DDIMInverseScheduler):
+            extra_step_kwargs.pop("generator")
+        for i in tqdm(range(self.num_ddim_steps)):
+            use_inv_sc=False 
+            if use_inv_sc:
+                t = self.inverse_scheduler.timesteps[i]
+                noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings,cond=True)
+                latent = self.inverse_scheduler.step(noise_pred, t, latent, **extra_step_kwargs, return_dict=False)[0]
+            else:
+                t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
+                noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings,cond=True)
+                latent = self.next_step(noise_pred, t, latent)
+            all_latent.append(latent)
+        return all_latent
+
+    @property
+    def scheduler(self):
+        return self.model.scheduler
+
+    @torch.no_grad()
+    @spaces.GPU
+    def ddim_inversion(self, image):
+        latent = self.image2latent(image) 
+        image_rec = self.latent2image(latent) 
+        ddim_latents = self.ddim_loop(latent.to(self.model.unet.dtype)) 
+        return image_rec, ddim_latents
+
+    from typing import Union, List, Dict
+    import numpy as np
+
+    @spaces.GPU
+    def invert(self, image_gt, prompt: Union[str, List[str]], 
+            verbose=True, inv_output_pos=None, inv_batch_size=1):
+
+        self.inv_batch_size = inv_batch_size
+        self.init_prompt(prompt)
+        out_put_pos = 0 if inv_output_pos is None else inv_output_pos
+        self.out_put_pos = out_put_pos
+        if verbose:
+            print("DDIM inversion...")
+        image_rec, ddim_latents = self.ddim_inversion(image_gt)
+        if verbose:
+            print("Done.")
+        return (image_gt, image_rec), ddim_latents[-1], ddim_latents, self.prompt_embeds[self.prompt_embeds.shape[0]//2:], self.pooled_prompt_embeds
+
+    def __init__(self, model,num_ddim_steps,generator=None,scheduler_type="DDIM"):
+        self.model = model
+        self.tokenizer = self.model.tokenizer
+        self.num_ddim_steps=num_ddim_steps
+        if scheduler_type == "DDIM":
+            self.inverse_scheduler=DDIMInverseScheduler.from_config(self.model.scheduler.config)
+            self.inverse_scheduler.set_timesteps(num_ddim_steps)
+        elif scheduler_type=="DPMSolver":
+            self.inverse_scheduler=DPMSolverMultistepInverseScheduler.from_config(self.model.scheduler.config)
+            self.inverse_scheduler.set_timesteps(num_ddim_steps)
+        self.model.scheduler.set_timesteps(num_ddim_steps)
+        self.model.vae.to(dtype=torch.float32)
+        self.prompt = None
+        self.context = None
+        # self.device=self.model.unet.device
+        self.device = torch.device("cuda:0")
+        self.generator=generator
+        self.eta=0.0
+
+def load_512(image_path, left=0, right=0, top=0, bottom=0):
+    if type(image_path) is str:
+        image = np.array(Image.open(image_path))[:, :, :3]
+    else:
+        image = image_path
+    h, w, c = image.shape
+    left = min(left, w - 1)
+    right = min(right, w - left - 1)
+    top = min(top, h - left - 1)
+    bottom = min(bottom, h - top - 1)
+    image = image[top:h - bottom, left:w - right]
+    h, w, c = image.shape
+    if h < w:
+        offset = (w - h) // 2
+        image = image[:, offset:offset + h]
+    elif w < h:
+        offset = (h - w) // 2
+        image = image[offset:offset + w]
+    image = np.array(Image.fromarray(image).resize((512, 512)))
+    return image
+
+def load_1024_mask(image_path, left=0, right=0, top=0, bottom=0,target_H=128,target_W=128):
+    if type(image_path) is str:
+        image = np.array(Image.open(image_path))[:, :, np.newaxis]
+    else:
+        image = image_path
+    if len(image.shape) == 4:
+        image = image[:, :, :, 0]
+    h, w, c = image.shape
+    left = min(left, w - 1)
+    right = min(right, w - left - 1)
+    top = min(top, h - left - 1)
+    bottom = min(bottom, h - top - 1)
+    image = image[top:h - bottom, left:w - right]
+    h, w, c = image.shape
+    if h < w:
+        offset = (w - h) // 2
+        image = image[:, offset:offset + h]
+    elif w < h:
+        offset = (h - w) // 2
+        image = image[offset:offset + w]
+    image=image.squeeze()
+    image = np.array(Image.fromarray(image).resize((target_H, target_W)))
+    return image
+
+def load_1024(image_path, left=0, right=0, top=0, bottom=0):
+    if type(image_path) is str:
+        image = np.array(Image.open(image_path).resize((1024, 1024)))[:, :, :3]
+    else:
+        image = image_path
+    h, w, c = image.shape
+    left = min(left, w - 1)
+    right = min(right, w - left - 1)
+    top = min(top, h - left - 1)
+    bottom = min(bottom, h - top - 1)
+    image = image[top:h - bottom, left:w - right]
+    h, w, c = image.shape
+    if h < w:
+        offset = (w - h) // 2
+        image = image[:, offset:offset + h]
+    elif w < h:
+        offset = (h - w) // 2
+        image = image[offset:offset + w]
+    image = np.array(Image.fromarray(image).resize((1024, 1024)))
+    return image
\ No newline at end of file
diff --git a/utils/sdxl.py b/utils/sdxl.py
new file mode 100755
index 0000000000000000000000000000000000000000..30e3cee65b303fa704785adb6394a39eebb523a4
--- /dev/null
+++ b/utils/sdxl.py
@@ -0,0 +1,986 @@
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+# import seaborn as sns
+import matplotlib.pyplot as plt
+import torch
+from diffusers import StableDiffusionXLPipeline
+from typing import Optional, Union, Tuple, List, Callable, Dict
+import numpy as np
+import copy
+import torch.nn.functional as F
+from diffusers.loaders import  LoraLoaderMixin, TextualInversionLoaderMixin 
+from diffusers.models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) 
+from diffusers.utils import (  logging, randn_tensor, replace_example_docstring, ) 
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput 
+from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
+import os
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionXLPipeline
+
+        >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
+        ... )
+        >>> pipe = pipe.to("cuda")
+
+        >>> prompt = "a photo of an astronaut riding a horse on mars"
+        >>> image = pipe(prompt).images[0]
+        ```
+"""
+
+
+class sdxl(StableDiffusionXLPipeline): 
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    @torch.no_grad()
+    def __call__(
+        self,
+        controller=None,
+        prompt: Union[str, List[str]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        original_size: Optional[Tuple[int, int]] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Optional[Tuple[int, int]] = None,
+        same_init=False,
+        x_stars=None,
+        prox_guidance=True,
+        masa_control=False,
+        masa_mask=False,
+        masa_start_step=40,
+        masa_start_layer=55,
+        mask_file=None,
+        query_mask_time=[0, 10],
+        **kwargs
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.7):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+                Guidance rescale factor should fix overexposure when using zero terminal SNR.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple. When returning a tuple, the first element is a list with the generated images, and the second
+            element is a list of `bool`s denoting whether the corresponding generated image likely represents
+            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+        """
+        # 0. Default height and width to unet
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        inv_batch_size = len(latents) if latents is not None else 1
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )
+
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            sample_ref_match=kwargs['sample_ref_match'] if 'sample_ref_match' in kwargs else None,
+        )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+            same_init=same_init, #ADD
+            sample_ref_match=kwargs['sample_ref_match'] if 'sample_ref_match' in kwargs else None,
+        )
+        # 6. Prepare extra step kwargs.
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Prepare added time ids & embeddings
+        add_text_embeds = pooled_prompt_embeds
+        add_time_ids = self._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+        )
+
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    # CHANGE START
+                    score_delta,mask_edit=self.prox_regularization(
+                        noise_pred_uncond,
+                        noise_pred_text,
+                        i,
+                        t,
+                        prox_guidance=prox_guidance,
+                    )
+                    if mask_edit is not None:
+                        a = 1
+                    noise_pred = noise_pred_uncond + guidance_scale * score_delta
+                    # CHANGE END
+
+                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+              
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                # ADD START
+                latents = self.proximal_guidance(
+                    i,
+                    t,
+                    latents,
+                    mask_edit,
+                    prox_guidance=prox_guidance,
+                    dtype=self.unet.dtype,
+                    x_stars=x_stars,
+                    controller=controller,
+                    sample_ref_match=kwargs['sample_ref_match'] if 'sample_ref_match' in kwargs else None,
+                    inv_batch_size=inv_batch_size,
+                    only_inversion_align=kwargs['only_inversion_align'] if 'only_inversion_align' in kwargs else False,
+                )
+                # ADD END
+                if controller is not None:
+                    latents = controller.step_callback(latents)
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        # make sure the VAE is in float32 mode, as it overflows in float16
+        self.vae.to(dtype=torch.float32)
+
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()
+
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents
+            return StableDiffusionXLPipelineOutput(images=image)
+
+        image = self.watermark.apply_watermark(image)
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        return image
+    
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None,same_init=False,sample_ref_match=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+        if sample_ref_match is not None:
+            new_latents=randn_tensor((batch_size,*shape[1:]), generator=generator, device=device, dtype=dtype)
+            for key,value in sample_ref_match.items():
+                new_latents[key]=latents[value].clone()
+            latents=new_latents
+        else:
+            if same_init is True:
+                if latents is None:
+                    latents = randn_tensor((1,*shape[1:]), generator=generator, device=device, dtype=dtype).expand(shape).to(device)
+                else:
+                    if batch_size>1 and latents.shape[0]==1:
+                        latents=latents.expand(shape).to(device)
+                    else:
+                        latents = latents.to(device)
+            else: 
+                if latents is None:
+                    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+                else:
+                    latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+    
+    def encode_prompt(
+        self,
+        prompt,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        sample_ref_match=None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+             prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        if prompt_embeds is None:
+            # textual inversion: procecss multi-vector tokens if necessary
+            prompt_embeds_list = []
+            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                text_inputs = tokenizer(
+                    prompt,
+                    padding="max_length",
+                    max_length=tokenizer.model_max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                text_input_ids = text_inputs.input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                    logger.warning(
+                        "The following part of your input was truncated because CLIP can only handle sequences up to"
+                        f" {tokenizer.model_max_length} tokens: {removed_text}"
+                    )
+
+                prompt_embeds = text_encoder(
+                    text_input_ids.to(device),
+                    output_hidden_states=True,
+                )
+
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                pooled_prompt_embeds = prompt_embeds[0]
+                prompt_embeds = prompt_embeds.hidden_states[-2]
+
+                bs_embed, seq_len, _ = prompt_embeds.shape
+                # duplicate text embeddings for each generation per prompt, using mps friendly method
+                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+                prompt_embeds_list.append(prompt_embeds)
+
+            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+        # get unconditional embeddings for classifier free guidance
+        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+        elif do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            uncond_tokens: List[str]
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+            
+
+            negative_prompt_embeds_list = []
+            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                # textual inversion: procecss multi-vector tokens if necessary
+                if isinstance(self, TextualInversionLoaderMixin):
+                    uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+                max_length = prompt_embeds.shape[1]
+                uncond_input = tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                negative_prompt_embeds = text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                if do_classifier_free_guidance:
+                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+                    seq_len = negative_prompt_embeds.shape[1]
+
+                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                    negative_prompt_embeds = negative_prompt_embeds.view(
+                        batch_size * num_images_per_prompt, seq_len, -1
+                    )
+
+                    # For classifier free guidance, we need to do two forward passes.
+                    # Here we concatenate the unconditional and text embeddings into a single batch
+                    # to avoid doing two forward passes
+
+                negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+        bs_embed = pooled_prompt_embeds.shape[0]
+        # ADD START
+        if sample_ref_match is not None:
+            new_negative_prompt_embeds=torch.zeros_like(prompt_embeds)
+            new_negative_pooled_prompt_embeds=torch.zeros_like(pooled_prompt_embeds)
+            for key,value in sample_ref_match.items():
+                new_negative_prompt_embeds[key]=negative_prompt_embeds[value].clone()
+                new_negative_pooled_prompt_embeds[key]=negative_pooled_prompt_embeds[value].clone()
+            negative_prompt_embeds=new_negative_prompt_embeds
+            negative_pooled_prompt_embeds=new_negative_pooled_prompt_embeds
+        else:
+            if negative_pooled_prompt_embeds.shape[0]==1 and bs_embed!=1:
+                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.repeat(bs_embed,1)
+            if negative_prompt_embeds.shape[0]==1 and bs_embed!=1:
+                negative_prompt_embeds=negative_prompt_embeds.repeat(bs_embed,1,1)
+        # ADD END
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+          
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    def encode_prompt_not_zero_uncond(
+        self,
+        prompt,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+             prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        if prompt_embeds is None:
+            # textual inversion: procecss multi-vector tokens if necessary
+            prompt_embeds_list = []
+            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                text_inputs = tokenizer(
+                    prompt,
+                    padding="max_length",
+                    max_length=tokenizer.model_max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                text_input_ids = text_inputs.input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                    logger.warning(
+                        "The following part of your input was truncated because CLIP can only handle sequences up to"
+                        f" {tokenizer.model_max_length} tokens: {removed_text}"
+                    )
+
+                prompt_embeds = text_encoder(text_input_ids.to(device),output_hidden_states=True)
+
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                pooled_prompt_embeds = prompt_embeds[0]
+                prompt_embeds = prompt_embeds.hidden_states[-2]
+
+                bs_embed, seq_len, _ = prompt_embeds.shape
+                # duplicate text embeddings for each generation per prompt, using mps friendly method
+                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+                prompt_embeds_list.append(prompt_embeds)
+
+            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            uncond_tokens: List[str]
+            if prompt is not None and isinstance(prompt,List) and negative_prompt == "":
+                negative_prompt = ["" for i in range(len(prompt))]
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            negative_prompt_embeds_list = []
+            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                # textual inversion: procecss multi-vector tokens if necessary
+                if isinstance(self, TextualInversionLoaderMixin):
+                    uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+
+                max_length = prompt_embeds.shape[1]
+                uncond_input = tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                negative_prompt_embeds = text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                if do_classifier_free_guidance:
+                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+                    seq_len = negative_prompt_embeds.shape[1]
+
+                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+                    negative_prompt_embeds = negative_prompt_embeds.view(
+                        batch_size * num_images_per_prompt, seq_len, -1
+                    )
+
+                    # For classifier free guidance, we need to do two forward passes.
+                    # Here we concatenate the unconditional and text embeddings into a single batch
+                    # to avoid doing two forward passes
+
+                negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+        bs_embed = pooled_prompt_embeds.shape[0]
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    def prox_regularization(
+        self,
+        noise_pred_uncond,
+        noise_pred_text,
+        i,
+        t,
+        prox_guidance=False,
+        prox=None,
+        quantile=0.75,
+        recon_t=400,
+        dilate_radius=2,
+    ):
+        if prox_guidance is True:
+            mask_edit = None
+            if prox == 'l1':
+                score_delta = (noise_pred_text - noise_pred_uncond).float()
+                if quantile > 0:
+                    threshold = score_delta.abs().quantile(quantile)
+                else:
+                    threshold = -quantile  # if quantile is negative, use it as a fixed threshold
+                score_delta -= score_delta.clamp(-threshold, threshold)
+                score_delta = torch.where(score_delta > 0, score_delta-threshold, score_delta)
+                score_delta = torch.where(score_delta < 0, score_delta+threshold, score_delta)
+                if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+                    mask_edit = (score_delta.abs() > threshold).float()
+                    if dilate_radius > 0:
+                        radius = int(dilate_radius)
+                        mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+            elif prox == 'l0':
+                score_delta = (noise_pred_text - noise_pred_uncond).float()
+                if quantile > 0:
+                    threshold = score_delta.abs().quantile(quantile)
+                else:
+                    threshold = -quantile  # if quantile is negative, use it as a fixed threshold
+                score_delta -= score_delta.clamp(-threshold, threshold)
+                if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+                    mask_edit = (score_delta.abs() > threshold).float()
+                    if dilate_radius > 0:
+                        radius = int(dilate_radius)
+                        mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+            elif prox==None:
+                score_delta = (noise_pred_text - noise_pred_uncond).float()
+                if quantile > 0:
+                    threshold = score_delta.abs().quantile(quantile)
+                else:
+                    threshold = -quantile  # if quantile is negative, use it as a fixed threshold
+                if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+                    mask_edit = (score_delta.abs() > threshold).float()
+                    if dilate_radius > 0:
+                        radius = int(dilate_radius)
+                        mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+            else:
+                raise NotImplementedError
+            return score_delta,mask_edit
+        else:
+            return noise_pred_text - noise_pred_uncond,None
+
+    def proximal_guidance(
+        self,
+        i,
+        t,
+        latents,
+        mask_edit,
+        dtype,
+        prox_guidance=False,
+        recon_t=400,
+        recon_end=0,
+        recon_lr=0.1,
+        x_stars=None, 
+        controller=None,
+        sample_ref_match=None,
+        inv_batch_size=1,
+        only_inversion_align=False,
+    ):
+        if mask_edit is not None and prox_guidance and (recon_t > recon_end and t < recon_t) or (recon_t < -recon_end and t > -recon_t):
+            if controller.layer_fusion.remove_mask is not None:
+                fix_mask = copy.deepcopy(controller.layer_fusion.remove_mask)
+                mask_edit[1] = (mask_edit[1]+fix_mask).clamp(0,1) 
+                if mask_edit.shape[0] > 2:
+                    mask_edit[2].fill_(1) 
+            recon_mask = 1 - mask_edit
+            target_latents=x_stars[len(x_stars)-i-2]
+            new_target_latents=torch.zeros_like(latents)
+            for key,value in sample_ref_match.items():
+                new_target_latents[key]=target_latents[value].clone() 
+            latents = latents - recon_lr * (latents - new_target_latents) * recon_mask
+        return latents.to(dtype)  
+    
+def slerp(val, low, high):
+    """ taken from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4
+    """
+    low_norm = low/torch.norm(low, dim=1, keepdim=True)
+    high_norm = high/torch.norm(high, dim=1, keepdim=True)
+    omega = torch.acos((low_norm*high_norm).sum(1))
+    so = torch.sin(omega)
+    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
+    return res
+
+
+def slerp_tensor(val, low, high):
+    shape = low.shape
+    res = slerp(val, low.flatten(1), high.flatten(1))
+    return res.reshape(shape)
+
+
+def dilate(image, kernel_size, stride=1, padding=0):
+    """
+    Perform dilation on a binary image using a square kernel.
+    """
+    # Ensure the image is binary
+    assert image.max() <= 1 and image.min() >= 0
+    
+    # Get the maximum value in each neighborhood
+    dilated_image = F.max_pool2d(image, kernel_size, stride, padding)
+    
+    return dilated_image
+
+def exec_classifier_free_guidance(model,latents,controller,t,guidance_scale,
+                                  do_classifier_free_guidance,noise_pred,guidance_rescale,
+                                  prox=None, quantile=0.75,image_enc=None, recon_lr=0.1, recon_t=400,recon_end_t=0,
+                                  inversion_guidance=False, reconstruction_guidance=False,x_stars=None, i=0,
+                                    use_localblend_mask=False,
+                                  save_heatmap=False,**kwargs):
+    # perform guidance
+    if do_classifier_free_guidance:
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        #noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        if prox is None and inversion_guidance is True:
+            prox = 'l1'
+        step_kwargs = {
+            'ref_image': None,
+            'recon_lr': 0,
+            'recon_mask': None,
+        }
+        mask_edit = None
+        if prox is not None:
+            if prox == 'l1':
+                score_delta = (noise_pred_text - noise_pred_uncond).float()
+                if quantile > 0:
+                    threshold = score_delta.abs().quantile(quantile)
+                else:
+                    threshold = -quantile  # if quantile is negative, use it as a fixed threshold
+                score_delta -= score_delta.clamp(-threshold, threshold)
+                score_delta = torch.where(score_delta > 0, score_delta-threshold, score_delta)
+                score_delta = torch.where(score_delta < 0, score_delta+threshold, score_delta)
+                if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+                    step_kwargs['ref_image'] = image_enc
+                    step_kwargs['recon_lr'] = recon_lr
+                    score_delta_norm=score_delta.abs()
+                    score_delta_norm=(score_delta_norm - score_delta_norm.min ()) / (score_delta_norm.max () - score_delta_norm.min ())
+                    mask_edit = (score_delta.abs() > threshold).float()
+                    if save_heatmap and i%10==0:
+                        for kk in range(4):
+                            sns.heatmap(mask_edit[1][kk].clone().cpu(), cmap='coolwarm')
+                            plt.savefig(f'./vis/prox_inv/heatmap1_mask_{i}_{kk}.png')
+                            plt.clf()
+                    if kwargs.get('dilate_mask', 2) > 0:
+                        radius = int(kwargs.get('dilate_mask', 2))
+                        mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+                    if save_heatmap and i%10==0:
+                        for kk in range(4):
+                            sns.heatmap(mask_edit[1][kk].clone().cpu(), cmap='coolwarm')
+                            plt.savefig(f'./vis/prox_inv/heatmap1_mask_dilate_{i}_{kk}.png')
+                            plt.clf()
+                    step_kwargs['recon_mask'] = 1 - mask_edit
+            elif prox == 'l0':
+                score_delta = (noise_pred_text - noise_pred_uncond).float()
+                if quantile > 0:
+                    threshold = score_delta.abs().quantile(quantile)
+                else:
+                    threshold = -quantile  # if quantile is negative, use it as a fixed threshold
+                score_delta -= score_delta.clamp(-threshold, threshold)
+                if (recon_t > 0 and t < recon_t) or (recon_t < 0 and t > -recon_t):
+                    step_kwargs['ref_image'] = image_enc
+                    step_kwargs['recon_lr'] = recon_lr
+                    mask_edit = (score_delta.abs() > threshold).float()
+                    if kwargs.get('dilate_mask', 2) > 0:
+                        radius = int(kwargs.get('dilate_mask', 2))
+                        mask_edit = dilate(mask_edit.float(), kernel_size=2*radius+1, padding=radius)
+                    step_kwargs['recon_mask'] = 1 - mask_edit
+            else:
+                raise NotImplementedError
+            noise_pred = (noise_pred_uncond + guidance_scale * score_delta).to(model.unet.dtype)
+        else:
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+            
+    if do_classifier_free_guidance and guidance_rescale > 0.0:
+    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+    if reconstruction_guidance:
+        kwargs.update(step_kwargs)
+    latents = model.scheduler.step(noise_pred, t, latents, **kwargs, return_dict=False)[0]
+    if mask_edit is not None and inversion_guidance and (recon_t > recon_end_t and t < recon_t) or (recon_t < recon_end_t and t > -recon_t):
+        if use_localblend_mask:
+            assert hasattr(controller,"layer_fusion")
+            if save_heatmap and i%10==0:
+                sns.heatmap(controller.layer_fusion.mask[0][0].clone().cpu(), cmap='coolwarm')
+                plt.savefig(f'./vis/prox_inv/heatmap0_localblendmask_{i}.png')
+                plt.clf()
+                sns.heatmap(controller.layer_fusion.mask[1][0].clone().cpu(), cmap='coolwarm')
+                plt.savefig(f'./vis/prox_inv/heatmap1_localblendmask_{i}.png')
+                plt.clf()
+            layer_fusion_mask=controller.layer_fusion.mask.float()
+            layer_fusion_mask[0]=layer_fusion_mask[1]
+            recon_mask=1-layer_fusion_mask.expand_as(latents)
+        else:
+            recon_mask = 1 - mask_edit
+        target_latents=x_stars[len(x_stars)-i-2].expand_as(latents)
+        # if target_latents有四维
+        if len(target_latents.shape)==4:
+            target_latents=target_latents[0]
+        latents = latents - recon_lr * (latents - target_latents) * recon_mask
+    # controller
+    if controller is not None:
+        latents = controller.step_callback(latents)
+    return latents.to(model.unet.dtype)
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd71b6a15f9b9376b0ef161d40d00a78df790726
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,363 @@
+import cv2
+from matplotlib import pyplot as plt
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from datetime import datetime
+import os
+from typing import List, Dict
+
+def convert_and_resize_mask(mask):
+    if mask.ndim == 3:
+        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+    resized_mask = cv2.resize(mask, (1024, 1024))       
+    return resized_mask
+
+def add_masks_resized(masks):
+    final_mask = np.zeros((1024, 1024), dtype=np.uint8)         
+    for mask in masks:
+        if mask is not None:
+            resized_mask = convert_and_resize_mask(mask)
+            resized_mask = resized_mask.astype(np.uint8)
+            final_mask = cv2.add(final_mask, resized_mask)
+    return final_mask
+
+def attend_mask(mask_file, attend_scale=10, save=False):
+    if isinstance(mask_file, str):
+        if mask_file == '':
+            return torch.zeros([1, 1, 128, 128], dtype=torch.float32).cuda()
+        else:
+            image_with_mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
+    elif len(mask_file.shape) == 3: # convert RGB to gray
+        image_with_mask = cv2.cvtColor(mask_file, cv2.COLOR_BGR2GRAY)
+    
+    else:
+        image_with_mask = mask_file
+
+    if attend_scale != 0:
+        kernel = np.ones((abs(attend_scale), abs(attend_scale)), np.uint8)        
+        if attend_scale > 0:
+            image_with_mask = cv2.dilate(image_with_mask, kernel, iterations=1)
+        else:
+            image_with_mask = cv2.erode(image_with_mask, kernel, iterations=1)
+        
+        if save and isinstance(mask_file, str):
+            new_mask_file_name = mask_file[:-4]+'_'+str(attend_scale)+'.jpg'
+            cv2.imwrite(new_mask_file_name, image_with_mask)
+            print("new_mask is saved in ", new_mask_file_name)
+
+    dilated_image= cv2.resize(image_with_mask, (128, 128), interpolation=cv2.INTER_NEAREST)
+    dilated_image = torch.from_numpy(dilated_image).to(torch.float32).unsqueeze(0).unsqueeze(0).cuda() / 255 
+    return dilated_image
+
+
+def panning(img_path=None, op_list=[['left', 0.2]], save=False, save_dir=None):
+    if isinstance(img_path, str):
+        img = cv2.imread(img_path)
+    else:
+        img = img_path
+    img_new = img.copy()
+    img_height, img_width, _ = img.shape
+    w_mask = 255 * np.ones((img_height, img_width), dtype=np.uint8)
+    h_mask = 255 * np.ones((img_height, img_width), dtype=np.uint8)
+
+    for op in op_list:
+        scale = op[1]
+        if op[0] in ['right', 'left']:
+            K = int(scale*img_width)
+        elif op[0] in ['up', 'down']:
+            K = int(scale*img_height)
+      
+        if op[0] == 'right':
+            img_new[:, K:, :] = img[:, 0:img_width-K, :]
+            w_mask[:, K:] = 0
+        elif op[0] == 'left':
+            img_new[:, 0:img_width-K, :] = img[:, K:, :]
+            w_mask[:, 0:img_width-K] = 0
+        elif op[0] == 'down':
+            img_new[K:, :, :] = img[0:img_height-K, :, :]
+            h_mask[K:, :] = 0
+        elif op[0] == 'up':
+            img_new[0:img_height-K, :, :] = img[K:, :, :]
+            h_mask[0:img_height-K, :] = 0
+        img = img_new
+    
+    mask = w_mask + h_mask
+    mask[mask>0] = 255
+    
+    if save:
+        if save_dir is None:
+            base_dir = os.path.dirname(img_path)
+            save_dir = os.path.join(base_dir, 'preprocess')
+        elif not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        resized_img_name = f"{save_dir}/resized_image.png"
+        resized_mask_name = f"{save_dir}/resized_mask.png"
+        cv2.imwrite(resized_img_name, img_new)
+        cv2.imwrite(resized_mask_name, mask)
+        return resized_img_name, resized_mask_name
+    else:
+        return img_new, mask
+
+def zooming(img_path=None, scale=[0.8, 0.8], save=False, save_dir=None):
+    if isinstance(img_path, str):
+        img = cv2.imread(img_path)
+    else:
+        img = img_path
+    img_new = img.copy()
+    img_height, img_width, _ = img.shape
+    mask = 255 * np.ones((img_height, img_width), dtype=np.uint8)
+
+    new_height = int(img_height*scale[0])
+    new_width = int(img_width*scale[1])
+    resized_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_AREA)
+    x_offset = (img_width - new_width) // 2
+    y_offset = (img_height - new_height) // 2
+
+    img_new[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized_img
+    mask[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = 0
+
+    if save:
+        if save_dir is None:
+            base_dir = os.path.dirname(img_path)
+            save_dir = os.path.join(base_dir, 'preprocess')
+        elif not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+
+        resized_img_name = f"{save_dir}/resized_image.png"
+        resized_mask_name = f"{save_dir}/resized_mask.png"
+        cv2.imwrite(resized_img_name, img_new)
+        cv2.imwrite(resized_mask_name, mask)
+        return resized_img_name, resized_mask_name
+    else:
+        return img_new, mask
+
+def get_box(mask, bias = 2):
+    nonzero_indices = torch.nonzero(mask)
+    H, W = mask.shape[-2:]
+    min_x = max(min(nonzero_indices[:, 1]) - bias, 0)
+    min_y = max(min(nonzero_indices[:, 0]) - bias, 0)
+    max_x = min(max(nonzero_indices[:, 1]) + bias, W)
+    max_y = min(max(nonzero_indices[:, 0]) + bias, H)
+    return (min_x, min_y, max_x, max_y)
+
+
+def draw_axis(img,grid_dict,x_len,y_len):
+    if grid_dict is not None and grid_dict is not False:
+        assert isinstance(grid_dict,Dict)
+        assert "x_title" in grid_dict
+        assert "y_title" in grid_dict
+        assert "x_text_list" in grid_dict
+        assert "y_text_list" in grid_dict
+        x_title=grid_dict["x_title"]
+        y_title=grid_dict["y_title"]
+        x_text_list=grid_dict['x_text_list']
+        y_text_list=grid_dict['y_text_list']
+        assert len(y_text_list)==y_len
+        assert len(x_text_list)==x_len
+        assert "font_size" in grid_dict
+        font_size=grid_dict["font_size"]
+        if "x_color" in grid_dict:
+            color_x=grid_dict['x_color']
+        else:
+            color_x="black"
+        if "y_color" in grid_dict:
+            color_y=grid_dict['y_color']
+        else:
+            color_y="black"
+        if "num_decimals" in grid_dict:
+            num_decimals=grid_dict['num_decimals']
+        else:
+            num_decimals=2
+        if "shift_x" in grid_dict:
+            shift_x_x,shift_x_y=grid_dict['shift_x']
+        else:
+            shift_x_x=shift_x_y=0
+        if "shift_y" in grid_dict:
+            shift_y_x,shift_y_y=grid_dict['shift_y']
+        else:
+            shift_y_x=shift_y_y=0
+        if "title" in grid_dict:
+            title=grid_dict['title']
+            if isinstance(title,List):
+                all_title=""
+                for s in title:
+                    all_title=all_title+s+"\n"
+                title=all_title
+        else:
+            title=''
+        width, height = img.size
+        num_x=x_len
+        num_y=y_len
+
+        new_img = Image.new("RGB", (width + width // num_x+width // (num_x*2), height + height // num_y+height // (num_y*2)), color=(255, 255, 255))
+        width,height=(width + width // num_x, height + height // num_y)
+        num_x=num_x+1
+        num_y=num_y+1
+        new_img.paste(img, (width // num_x, height // num_y))
+
+        draw = ImageDraw.Draw(new_img)
+
+        font = ImageFont.truetype("DejaVuSansMono.ttf", font_size)
+        for i in range(2, num_x+1):
+            x = (i - 1) * width // num_x + width // (num_x * 2)-width *0.2// num_x+shift_x_x
+            y = height // (num_y * 2)+shift_x_y
+            k=i-1
+            if  isinstance(x_text_list[i-2],str):
+                draw.text((x, y), x_text_list[i-2], font=font,fill=color_x,align="center")
+            else:
+                draw.text((x, y), "{:.{}f}".format(x_text_list[i-2],num_decimals), font=font,fill=color_x,align="center")
+
+        for i in range(2, num_y+1):
+            x = width // (num_x * 2)-width *0.1// num_x+shift_y_x
+            y = (i - 1) * height // num_y + height // (num_y * 2)-height*0.1//num_y+shift_y_y
+            k = i - 1
+            if isinstance(y_text_list[i-2],str):
+                draw.text((x, y), y_text_list[i-2], font=font,fill=color_y,align="center")
+            else:
+                draw.text((x, y), "{:.{}f}".format(y_text_list[i-2],num_decimals), font=font,fill=color_y,align="center")
+        i=1
+        x = (i - 1) * width // num_x + width // (num_x * 2)-height*0.1//num_y+shift_y_x
+        y = height // (num_y * 2)+width *0.2// num_x+shift_y_y
+        draw.text((x, y), y_title, font=font, fill=color_y,align="center")
+        x = width // (num_x * 2)+width *0.2// num_x+shift_x_x
+        y = (i - 1) * height // num_y + height // (num_y * 2)+shift_x_y
+        draw.text((x, y), x_title, font=font, fill=color_x,align="left")
+        x = width // 4
+        y = (i - 1) * height // num_y + height // (num_y * 10)
+        draw.text((x, y), title, font=font, fill='blue',align="left")
+    else:
+
+        new_img=img
+    return new_img
+
+def view_images(images, num_rows=1, offset_ratio=0.02,text="",folder=None,Notimestamp=False,
+grid_dict=None,subfolder=None,verbose=True,output_dir=None,timestamp=None,**kwargs):
+    if type(images) is list:
+        num_empty = len(images) % num_rows
+    elif images.ndim == 4:
+        num_empty = images.shape[0] % num_rows
+    else:
+        images = [images]
+        num_empty = 0
+    origin_size=kwargs.get("origin_size",None)
+    images_copy=images.copy()
+    for i, per_image in enumerate(images_copy):
+        if isinstance(per_image, Image.Image) and origin_size is not None:
+            images[i] = np.array(per_image.resize((origin_size[1],origin_size[0])))
+        else:
+            images[i] = np.array(per_image)
+        
+    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
+    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
+    num_items = len(images)
+
+    h, w, c = images[0].shape
+    offset = int(h * offset_ratio)
+    num_cols = num_items // num_rows
+    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
+                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
+    for i in range(num_rows):
+        for j in range(num_cols):
+            image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
+                i * num_cols + j]
+
+    pil_img = Image.fromarray(image_)
+
+    pil_img_=draw_axis(pil_img,grid_dict,num_cols,num_rows)
+    if pil_img_.size[0]==pil_img_.size[1]:
+        pil_img_.resize((2048,2048))
+    else:
+        longer_side = max(pil_img.size)
+        ratio = 2048/longer_side
+        new_size = tuple([int(x*ratio) for x in pil_img.size])
+        pil_img = pil_img.resize(new_size)
+
+    if verbose is False:
+        return pil_img
+    now = datetime.now()
+    if timestamp is None:
+        if Notimestamp is False:
+            timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
+        else:
+            timestamp=""
+    if output_dir is None:
+        if timestamp != "":
+            date, time = timestamp.split('_')
+        else:
+            date, time = "",""
+        if folder is not None:
+            dirname="./"+folder
+            filename = text+f"img_{timestamp}.jpg"
+        else:
+            if subfolder is not None:
+                dirname=os.path.join("./img", subfolder,date)
+                dirname=os.path.join(dirname,time)            
+                filename =text+f"img_{timestamp}.jpg"
+            else:
+                dirname=os.path.join("./img",date)
+                dirname=os.path.join(dirname,time)
+                filename =text+f"img_{timestamp}.jpg"
+    else:
+        dirname=output_dir
+        filename =text+f"img_{timestamp}.jpg"
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+    if verbose is True:
+        for i, img in enumerate(images):
+            im = Image.fromarray(img)
+            im.save(os.path.join(dirname,f"{i}.jpg"))
+    print(f"Output dir: {dirname}")
+    pil_img.save(os.path.join(dirname, filename))
+    if grid_dict is not None and grid_dict is not False:
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+        pil_img_.save(os.path.join(dirname, filename[:-4]+"_2048x.jpg"))
+
+def resize_image_with_mask(img, mask, scale):
+    if scale == 1:
+        return img, mask, None
+    img_blackboard = img.copy() # canvas
+    mask_blackboard = np.zeros_like(mask)
+
+    M = cv2.moments(mask)
+    cx = int(M["m10"] / M["m00"])
+    cy = int(M["m01"] / M["m00"])
+
+    scale_factor = [scale, scale]
+    resized_img = cv2.resize(img, None, fx=scale_factor[0], fy=scale_factor[1], interpolation=cv2.INTER_AREA)
+    resized_mask = cv2.resize(mask, None, fx=scale_factor[0], fy=scale_factor[1], interpolation=cv2.INTER_AREA)
+    new_cx, new_cy = cx * scale_factor[0], cy * scale_factor[1]
+
+    for y in range(resized_mask.shape[0]):
+        for x in range(resized_mask.shape[1]):
+            if 0 <= cy - (new_cy - y) < img.shape[0] and 0 <= cx - (new_cx - x) < img.shape[1]:
+                mask_blackboard[int(cy - (new_cy - y)), int(cx - (new_cx - x))] = resized_mask[y, x]
+                img_blackboard[int(cy - (new_cy - y)), int(cx - (new_cx - x))] = resized_img[y, x]
+    return img_blackboard, mask_blackboard, (cx, cy)
+
+def flip_image_with_mask(img, mask, flip_code=None):
+    if flip_code is None:
+        return img, mask, None
+    M = cv2.moments(mask)
+    if M["m00"] == 0:  
+        return img, mask
+    cx = int(M["m10"] / M["m00"])
+    cy = int(M["m01"] / M["m00"])
+    
+    h, w = img.shape[:2]
+    img_center = (w // 2, h // 2)
+
+    tx = img_center[0] - cx
+    ty = img_center[1] - cy
+
+    M_translate = np.float32([[1, 0, tx], [0, 1, ty]])
+    img_translated = cv2.warpAffine(img, M_translate, (w, h))
+    mask_translated = cv2.warpAffine(mask, M_translate, (w, h))
+    flipped_img = cv2.flip(img_translated, flip_code)
+    flipped_mask = cv2.flip(mask_translated, flip_code)
+    M_translate_back = np.float32([[1, 0, -tx], [0, 1, -ty]])
+    flipped_img_back = cv2.warpAffine(flipped_img, M_translate_back, (w, h))
+    flipped_mask_back = cv2.warpAffine(flipped_mask, M_translate_back, (w, h))
+
+    return flipped_img_back, flipped_mask_back, (cx, cy)