Spaces:
Sleeping
Sleeping
Duplicate from Towsif7/genious_bgremover
Browse files
Co-authored-by: Towsif Labib <Towsif7@users.noreply.huggingface.co>
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- README.md +13 -0
- app.py +36 -0
- carvekit/__init__.py +1 -0
- carvekit/__main__.py +149 -0
- carvekit/__pycache__/__init__.cpython-38.pyc +0 -0
- carvekit/api/__init__.py +0 -0
- carvekit/api/__pycache__/__init__.cpython-38.pyc +0 -0
- carvekit/api/__pycache__/high.cpython-38.pyc +0 -0
- carvekit/api/__pycache__/interface.cpython-38.pyc +0 -0
- carvekit/api/high.py +100 -0
- carvekit/api/interface.py +77 -0
- carvekit/ml/__init__.py +4 -0
- carvekit/ml/__pycache__/__init__.cpython-38.pyc +0 -0
- carvekit/ml/arch/__init__.py +0 -0
- carvekit/ml/arch/__pycache__/__init__.cpython-38.pyc +0 -0
- carvekit/ml/arch/basnet/__init__.py +0 -0
- carvekit/ml/arch/basnet/__pycache__/__init__.cpython-38.pyc +0 -0
- carvekit/ml/arch/basnet/__pycache__/basnet.cpython-38.pyc +0 -0
- carvekit/ml/arch/basnet/basnet.py +478 -0
- carvekit/ml/arch/fba_matting/__init__.py +0 -0
- carvekit/ml/arch/fba_matting/__pycache__/__init__.cpython-38.pyc +0 -0
- carvekit/ml/arch/fba_matting/__pycache__/layers_WS.cpython-38.pyc +0 -0
- carvekit/ml/arch/fba_matting/__pycache__/models.cpython-38.pyc +0 -0
- carvekit/ml/arch/fba_matting/__pycache__/resnet_GN_WS.cpython-38.pyc +0 -0
- carvekit/ml/arch/fba_matting/__pycache__/resnet_bn.cpython-38.pyc +0 -0
- carvekit/ml/arch/fba_matting/__pycache__/transforms.cpython-38.pyc +0 -0
- carvekit/ml/arch/fba_matting/layers_WS.py +57 -0
- carvekit/ml/arch/fba_matting/models.py +341 -0
- carvekit/ml/arch/fba_matting/resnet_GN_WS.py +151 -0
- carvekit/ml/arch/fba_matting/resnet_bn.py +169 -0
- carvekit/ml/arch/fba_matting/transforms.py +45 -0
- carvekit/ml/arch/tracerb7/__init__.py +0 -0
- carvekit/ml/arch/tracerb7/__pycache__/__init__.cpython-38.pyc +0 -0
- carvekit/ml/arch/tracerb7/__pycache__/att_modules.cpython-38.pyc +0 -0
- carvekit/ml/arch/tracerb7/__pycache__/conv_modules.cpython-38.pyc +0 -0
- carvekit/ml/arch/tracerb7/__pycache__/effi_utils.cpython-38.pyc +0 -0
- carvekit/ml/arch/tracerb7/__pycache__/efficientnet.cpython-38.pyc +0 -0
- carvekit/ml/arch/tracerb7/__pycache__/tracer.cpython-38.pyc +0 -0
- carvekit/ml/arch/tracerb7/att_modules.py +290 -0
- carvekit/ml/arch/tracerb7/conv_modules.py +88 -0
- carvekit/ml/arch/tracerb7/effi_utils.py +579 -0
- carvekit/ml/arch/tracerb7/efficientnet.py +325 -0
- carvekit/ml/arch/tracerb7/tracer.py +97 -0
- carvekit/ml/arch/u2net/__init__.py +0 -0
- carvekit/ml/arch/u2net/__pycache__/__init__.cpython-38.pyc +0 -0
- carvekit/ml/arch/u2net/__pycache__/u2net.cpython-38.pyc +0 -0
- carvekit/ml/arch/u2net/u2net.py +172 -0
- carvekit/ml/files/__init__.py +7 -0
- carvekit/ml/files/__pycache__/__init__.cpython-38.pyc +0 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Genious Bgremover
|
3 |
+
emoji: 👁
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: blue
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.25.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
duplicated_from: Towsif7/genious_bgremover
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from carvekit.api.interface import Interface
|
3 |
+
from carvekit.ml.wrap.fba_matting import FBAMatting
|
4 |
+
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
|
5 |
+
from carvekit.pipelines.postprocessing import MattingMethod
|
6 |
+
from carvekit.pipelines.preprocessing import PreprocessingStub
|
7 |
+
from carvekit.trimap.generator import TrimapGenerator
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
@st.cache_resource
def load_interface():
    """Build the CarveKit background-removal pipeline once and reuse it.

    Streamlit re-executes this script on every widget interaction, so
    constructing the segmentation and matting networks inline would rebuild
    them for each uploaded image. ``st.cache_resource`` keeps a single
    ``Interface`` instance alive for the server process instead.

    Returns:
        An ``Interface`` wired with a TracerUniversalB7 segmentation network,
        a ``PreprocessingStub`` and FBA-matting postprocessing, all on CPU.
    """
    seg_net = TracerUniversalB7(device="cpu", batch_size=1)
    fba = FBAMatting(device="cpu", input_tensor_size=2048, batch_size=1)
    trimap = TrimapGenerator()
    preprocessing = PreprocessingStub()
    postprocessing = MattingMethod(
        matting_module=fba, trimap_generator=trimap, device="cpu"
    )
    return Interface(
        pre_pipe=preprocessing, post_pipe=postprocessing, seg_pipe=seg_net
    )


# Create Streamlit app title
st.title("Image Background Remover")

# Create a file uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"])

if uploaded_file is not None:
    # Load the image
    image = Image.open(uploaded_file)

    # Process the image with the cached pipeline
    # NOTE(review): the cached Interface is shared between sessions; confirm
    # that concurrent inference is acceptable for this deployment.
    with st.spinner("Removing background..."):
        processed_bg = load_interface()([image])[0]

    # Display original and processed images side by side
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_column_width=True)
    with col2:
        st.image(processed_bg, caption="Background Removed", use_column_width=True)
|
carvekit/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Version string of the bundled carvekit package.
version = "4.1.0"
|
carvekit/__main__.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import click
|
4 |
+
import tqdm
|
5 |
+
|
6 |
+
from carvekit.utils.image_utils import ALLOWED_SUFFIXES
|
7 |
+
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
|
8 |
+
from carvekit.web.schemas.config import MLConfig
|
9 |
+
from carvekit.web.utils.init_utils import init_interface
|
10 |
+
from carvekit.utils.fs_utils import save_file
|
11 |
+
|
12 |
+
|
13 |
+
@click.command(
    "removebg",
    help="Performs background removal on specified photos using console interface.",
)
@click.option("-i", required=True, type=str, help="Path to input file or dir")
@click.option("-o", default="none", type=str, help="Path to output file or dir")
@click.option("--pre", default="none", type=str, help="Preprocessing method")
@click.option("--post", default="fba", type=str, help="Postprocessing method.")
@click.option("--net", default="tracer_b7", type=str, help="Segmentation Network")
@click.option(
    "--recursive",
    default=False,
    type=bool,
    help="Enables recursive search for images in a folder",
)
@click.option(
    "--batch_size",
    default=10,
    type=int,
    help="Batch Size for list of images to be loaded to RAM",
)
@click.option(
    "--batch_size_seg",
    default=5,
    type=int,
    help="Batch size for list of images to be processed by segmentation network",
)
@click.option(
    "--batch_size_mat",
    default=1,
    type=int,
    help="Batch size for list of images to be processed by matting network",
)
@click.option(
    "--seg_mask_size",
    default=640,
    type=int,
    help="The size of the input image for the segmentation neural network.",
)
@click.option(
    "--matting_mask_size",
    default=2048,
    type=int,
    help="The size of the input image for the matting neural network.",
)
@click.option(
    "--trimap_dilation",
    default=30,
    type=int,
    help="The size of the offset radius from the object mask in "
    "pixels when forming an unknown area",
)
@click.option(
    "--trimap_erosion",
    default=5,
    type=int,
    help="The number of iterations of erosion that the object's "
    "mask will be subjected to before forming an unknown area",
)
@click.option(
    "--trimap_prob_threshold",
    default=231,
    type=int,
    help="Probability threshold at which the prob_filter "
    "and prob_as_unknown_area operations will be "
    "applied",
)
@click.option("--device", default="cpu", type=str, help="Processing Device.")
@click.option(
    "--fp16", default=False, type=bool, help="Enables mixed precision processing."
)
def removebg(
    i: str,
    o: str,
    pre: str,
    post: str,
    net: str,
    recursive: bool,
    batch_size: int,
    batch_size_seg: int,
    batch_size_mat: int,
    seg_mask_size: int,
    matting_mask_size: int,
    device: str,
    fp16: bool,
    trimap_dilation: int,
    trimap_erosion: int,
    trimap_prob_threshold: int,
):
    """Remove the background from one image or from every image in a folder.

    Args:
        i: Input file or directory.
        o: Output file or directory; the literal default ``"none"`` is passed
            through to ``save_file`` unchanged.
        pre: Preprocessing method name forwarded to ``MLConfig``.
        post: Postprocessing method name forwarded to ``MLConfig``.
        net: Segmentation network name forwarded to ``MLConfig``.
        recursive: Search sub-directories when ``i`` is a directory.
        batch_size: Number of images handed to the interface per iteration.
        batch_size_seg: Segmentation batch size forwarded to ``MLConfig``.
        batch_size_mat: Matting batch size forwarded to ``MLConfig``.
        seg_mask_size: Segmentation input size forwarded to ``MLConfig``.
        matting_mask_size: Matting input size forwarded to ``MLConfig``.
        device: Processing device forwarded to ``MLConfig``.
        fp16: Mixed-precision switch forwarded to ``MLConfig``.
        trimap_dilation: Trimap dilation forwarded to ``MLConfig``.
        trimap_erosion: Trimap erosion iterations forwarded to ``MLConfig``.
        trimap_prob_threshold: Trimap threshold forwarded to ``MLConfig``.

    Raises:
        click.BadParameter: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A zero batch size previously crashed with ZeroDivisionError below.
        raise click.BadParameter(
            "batch size must be a positive integer", param_hint="--batch_size"
        )

    out_path = Path(o)
    input_path = Path(i)
    if input_path.is_dir():
        candidates = input_path.rglob("*.*") if recursive else input_path.glob("*.*")
        # Skip files whose name marks them as earlier results of this tool.
        all_images = [
            path
            for path in candidates
            if path.suffix.lower() in ALLOWED_SUFFIXES
            and "_bg_removed" not in path.name
        ]
    else:
        all_images = [input_path]

    interface_config = MLConfig(
        segmentation_network=net,
        preprocessing_method=pre,
        postprocessing_method=post,
        device=device,
        batch_size_seg=batch_size_seg,
        batch_size_matting=batch_size_mat,
        seg_mask_size=seg_mask_size,
        matting_mask_size=matting_mask_size,
        fp16=fp16,
        trimap_dilation=trimap_dilation,
        trimap_erosion=trimap_erosion,
        trimap_prob_threshold=trimap_prob_threshold,
    )

    interface = init_interface(interface_config)

    # Ceiling division so a trailing partial batch is counted in the progress
    # bar (assumes batch_generator yields that partial batch -- TODO confirm).
    total_batches = (len(all_images) + batch_size - 1) // batch_size

    for image_batch in tqdm.tqdm(
        batch_generator(all_images, n=batch_size),
        total=total_batches,
        desc="Removing background",
        unit=" image batch",
        colour="blue",
    ):
        images_without_background = interface(image_batch)  # Remove background
        thread_pool_processing(
            lambda x: save_file(out_path, image_batch[x], images_without_background[x]),
            range(len(image_batch)),
        )  # Drop images to fs


if __name__ == "__main__":
    removebg()
|
carvekit/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (187 Bytes). View file
|
|
carvekit/api/__init__.py
ADDED
File without changes
|
carvekit/api/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (174 Bytes). View file
|
|
carvekit/api/__pycache__/high.cpython-38.pyc
ADDED
Binary file (3.71 kB). View file
|
|
carvekit/api/__pycache__/interface.cpython-38.pyc
ADDED
Binary file (2.87 kB). View file
|
|
carvekit/api/high.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
|
3 |
+
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
from carvekit.api.interface import Interface
|
9 |
+
from carvekit.ml.wrap.fba_matting import FBAMatting
|
10 |
+
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
|
11 |
+
from carvekit.ml.wrap.u2net import U2NET
|
12 |
+
from carvekit.pipelines.postprocessing import MattingMethod
|
13 |
+
from carvekit.trimap.generator import TrimapGenerator
|
14 |
+
|
15 |
+
|
16 |
+
class HiInterface(Interface):
    """High-level interface: segmentation network plus FBA matting refinement."""

    def __init__(
        self,
        object_type: str = "object",
        batch_size_seg=2,
        batch_size_matting=1,
        device="cpu",
        seg_mask_size=640,
        matting_mask_size=2048,
        trimap_prob_threshold=231,
        trimap_dilation=30,
        trimap_erosion_iters=5,
        fp16=False,
    ):
        """
        Initializes High Level interface.

        Args:
            object_type: Interest object type. Can be "object" or "hairs-like".
                Any other value emits a warning and falls back to "object".
            matting_mask_size: The size of the input image for the matting neural network.
            seg_mask_size: The size of the input image for the segmentation neural network.
            batch_size_seg: Number of images processed per one segmentation neural network call.
            batch_size_matting: Number of images processed per one matting neural network call.
            device: Processing device
            fp16: Use half precision. Reduce memory usage and increase speed. Experimental support
            trimap_prob_threshold: Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied
            trimap_dilation: The size of the offset radius from the object mask in pixels when forming an unknown area
            trimap_erosion_iters: The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area

        Notes:
            1. Changing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also
            result in reduced precision. I do not recommend changing this value. You can change matting_mask_size in
            range from (1024 to 4096) to improve object edge refining quality, but it will cause extra large RAM and
            video memory consume. Also, you can change batch size to accelerate background removal, but it also causes
            extra large video memory consume, if value is too big.

            2. Changing trimap_prob_threshold, trimap_dilation, trimap_erosion_iters may improve object edge
            refining quality.
        """
        if object_type == "hairs-like":
            seg_net_cls = U2NET
        else:
            if object_type != "object":
                warnings.warn(
                    f"Unknown object type: {object_type}. Using default object type: object"
                )
            seg_net_cls = TracerUniversalB7

        # The attribute keeps its historical name ``u2net`` even when it holds
        # a TracerUniversalB7 instance, so existing attribute access still works.
        self.u2net = seg_net_cls(
            device=device,
            batch_size=batch_size_seg,
            input_image_size=seg_mask_size,
            fp16=fp16,
        )

        self.fba = FBAMatting(
            batch_size=batch_size_matting,
            device=device,
            input_tensor_size=matting_mask_size,
            fp16=fp16,
        )
        self.trimap_generator = TrimapGenerator(
            prob_threshold=trimap_prob_threshold,
            kernel_size=trimap_dilation,
            erosion_iters=trimap_erosion_iters,
        )
        super().__init__(
            pre_pipe=None,
            seg_pipe=self.u2net,
            post_pipe=MattingMethod(
                matting_module=self.fba,
                trimap_generator=self.trimap_generator,
                device=device,
            ),
            device=device,
        )
|
carvekit/api/interface.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
|
3 |
+
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Union, List, Optional
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
from carvekit.ml.wrap.basnet import BASNET
|
12 |
+
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
|
13 |
+
from carvekit.ml.wrap.u2net import U2NET
|
14 |
+
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
|
15 |
+
from carvekit.pipelines.preprocessing import PreprocessingStub
|
16 |
+
from carvekit.pipelines.postprocessing import MattingMethod
|
17 |
+
from carvekit.utils.image_utils import load_image
|
18 |
+
from carvekit.utils.mask_utils import apply_mask
|
19 |
+
from carvekit.utils.pool_utils import thread_pool_processing
|
20 |
+
|
21 |
+
|
22 |
+
class Interface:
    """Callable that chains preprocessing/segmentation and postprocessing."""

    def __init__(
        self,
        seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7],
        pre_pipe: Optional[Union[PreprocessingStub]] = None,
        post_pipe: Optional[Union[MattingMethod]] = None,
        device="cpu",
    ):
        """
        Stores the pipeline components used by ``__call__``.

        Args:
            seg_pipe: Initialized segmentation network object
            pre_pipe: Initialized pre-processing pipeline object
            post_pipe: Initialized postprocessing pipeline object
            device: The processing device that will be used to apply the masks to the images.
        """
        self.device = device
        self.preprocessing_pipeline = pre_pipe
        self.segmentation_pipeline = seg_pipe
        self.postprocessing_pipeline = post_pipe

    def __call__(
        self, images: List[Union[str, Path, Image.Image]]
    ) -> List[Image.Image]:
        """
        Removes the background from the specified images.

        Args:
            images: list of input images

        Returns:
            List of images without background as PIL.Image.Image instances
        """
        loaded = thread_pool_processing(load_image, images)

        # With a preprocessing pipeline configured, it produces the masks and
        # receives this interface; otherwise the segmentation network is
        # called directly.
        preprocess = self.preprocessing_pipeline
        if preprocess is None:
            masks = self.segmentation_pipeline(images=loaded)
        else:
            masks = preprocess(interface=self, images=loaded)

        postprocess = self.postprocessing_pipeline
        if postprocess is None:
            # No postprocessing: apply each mask to its image by position.
            return [
                apply_mask(image=picture, mask=masks[idx], device=self.device)
                for idx, picture in enumerate(loaded)
            ]
        return postprocess(images=loaded, masks=masks)
|
carvekit/ml/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from carvekit.utils.models_utils import fix_seed, suppress_warnings
|
2 |
+
|
3 |
+
# Runs when carvekit.ml is imported; presumably seeds the RNGs -- confirm in carvekit.utils.models_utils.
fix_seed()
|
4 |
+
# Also runs at import time; presumably silences library warnings -- confirm in carvekit.utils.models_utils.
suppress_warnings()
|
carvekit/ml/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (277 Bytes). View file
|
|
carvekit/ml/arch/__init__.py
ADDED
File without changes
|
carvekit/ml/arch/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (178 Bytes). View file
|
|
carvekit/ml/arch/basnet/__init__.py
ADDED
File without changes
|
carvekit/ml/arch/basnet/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (185 Bytes). View file
|
|
carvekit/ml/arch/basnet/__pycache__/basnet.cpython-38.pyc
ADDED
Binary file (10 kB). View file
|
|
carvekit/ml/arch/basnet/basnet.py
ADDED
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source url: https://github.com/NathanUA/BASNet
|
3 |
+
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
4 |
+
License: MIT License
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torchvision import models
|
9 |
+
|
10 |
+
|
11 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
|
16 |
+
|
17 |
+
|
18 |
+
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + batch-norm layers with a skip connection."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: Number of input channels.
            planes: Number of output channels of both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional module applied to the input to form the
                skip branch; when ``None`` the input itself is used.
        """
        super().__init__()
        # Sub-modules are registered in the original order so the parameter
        # and state-dict ordering is unchanged.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the skip branch, then ReLU."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)
|
48 |
+
|
49 |
+
|
50 |
+
class BasicBlockDe(nn.Module):
    """Residual block whose skip branch is itself a 3x3 conv + batch-norm + ReLU."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: Number of input channels.
            planes: Number of output channels.
            stride: Stride of the skip-branch convolution and of the first
                main-branch convolution.
            downsample: Optional module; when given, its output replaces the
                convolutional skip branch in ``forward``.
        """
        super().__init__()

        # Skip branch (registered first, as in the original layout).
        self.convRes = conv3x3(inplanes, planes, stride)
        self.bnRes = nn.BatchNorm2d(planes)
        self.reluRes = nn.ReLU(inplace=True)

        # Main branch.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Sum the main branch and the skip branch, then apply ReLU."""
        # The convolutional skip branch is always evaluated, even when
        # ``downsample`` overrides its result below.
        shortcut = self.reluRes(self.bnRes(self.convRes(x)))

        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        if self.downsample is not None:
            shortcut = self.downsample(x)

        main += shortcut
        return self.relu(main)
|
87 |
+
|
88 |
+
|
89 |
+
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: Number of input channels.
            planes: Channel width of the first two convolutions.
            stride: Stride of the 3x3 convolution.
            downsample: Optional module applied to the input to form the
                skip branch; when ``None`` the input itself is used.
        """
        super().__init__()
        widened = planes * 4
        # Sub-modules are registered in the original order so the parameter
        # and state-dict ordering is unchanged.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv stages, add the skip branch, then apply ReLU."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))

        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)
|
127 |
+
|
128 |
+
|
129 |
+
class RefUnet(nn.Module):
    """Small encoder-decoder that adds a predicted one-channel residual to its input."""

    def __init__(self, in_ch, inc_ch):
        """
        Args:
            in_ch: Number of channels of the input tensor.
            inc_ch: Number of channels produced by the first convolution.
        """
        super().__init__()

        self.conv0 = nn.Conv2d(in_ch, inc_ch, 3, padding=1)

        # Encoder: five conv-bn-relu stages, the first four each followed by a
        # 2x max-pool. Attribute names and registration order match the
        # hand-written layout (conv1, bn1, relu1, pool1, conv2, ...).
        for stage in range(1, 6):
            stage_in = inc_ch if stage == 1 else 64
            setattr(self, f"conv{stage}", nn.Conv2d(stage_in, 64, 3, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm2d(64))
            setattr(self, f"relu{stage}", nn.ReLU(inplace=True))
            if stage < 5:
                setattr(self, f"pool{stage}", nn.MaxPool2d(2, 2, ceil_mode=True))

        # Decoder: four conv-bn-relu stages (d4 down to d1), each consuming the
        # upsampled features concatenated with the matching encoder output.
        for stage in range(4, 0, -1):
            setattr(self, f"conv_d{stage}", nn.Conv2d(128, 64, 3, padding=1))
            setattr(self, f"bn_d{stage}", nn.BatchNorm2d(64))
            setattr(self, f"relu_d{stage}", nn.ReLU(inplace=True))

        self.conv_d0 = nn.Conv2d(64, 1, 3, padding=1)

        self.upscore2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

    def _conv_bn_relu(self, tensor, suffix):
        """Apply the conv/bn/relu triple whose attribute names end in *suffix*."""
        conv = getattr(self, f"conv{suffix}")
        norm = getattr(self, f"bn{suffix}")
        act = getattr(self, f"relu{suffix}")
        return act(norm(conv(tensor)))

    def forward(self, x):
        """Return ``x`` plus the residual predicted by the encoder-decoder."""
        features = self.conv0(x)

        # Encoder pass, remembering each stage's output for the skip links.
        skips = []
        for stage in range(1, 5):
            stage_out = self._conv_bn_relu(features, str(stage))
            skips.append(stage_out)
            features = getattr(self, f"pool{stage}")(stage_out)

        features = self._conv_bn_relu(features, "5")

        # Decoder pass: upsample, concatenate with the skip, conv-bn-relu.
        for stage in range(4, 0, -1):
            merged = torch.cat((self.upscore2(features), skips[stage - 1]), 1)
            features = self._conv_bn_relu(merged, f"_d{stage}")

        return x + self.conv_d0(features)
|
219 |
+
|
220 |
+
|
221 |
+
class BASNet(nn.Module):
    """Encoder / bridge / decoder network with side outputs and a refinement head.

    ``forward`` returns a tuple of eight tensors, each passed through
    ``torch.sigmoid``: the ``RefUnet``-refined map first, followed by the side
    outputs ``d1`` .. ``d6`` and the bridge output ``db``.

    Args:
        n_channels: Number of channels of the input tensor (in-channels of
            ``inconv``).
        n_classes: Accepted but never read inside this class; every output
            convolution below has a fixed single output channel.

    NOTE(review): attribute names double as ``state_dict`` keys, so renaming
    any of them would presumably break loading of existing checkpoints.
    """

    def __init__(self, n_channels, n_classes):
        super(BASNet, self).__init__()

        # Only used as a donor for layer1..layer4; built with pretrained=False.
        resnet = models.resnet34(pretrained=False)

        # NOTE(review): the bare numbers in trailing comments look like
        # nominal feature-map sizes; the ones here and the ones in forward()
        # do not agree with each other (e.g. 224 vs 256) - confirm before
        # relying on them.

        # -------------Encoder--------------

        self.inconv = nn.Conv2d(n_channels, 64, 3, padding=1)
        self.inbn = nn.BatchNorm2d(64)
        self.inrelu = nn.ReLU(inplace=True)

        # stage 1
        self.encoder1 = resnet.layer1  # 224
        # stage 2
        self.encoder2 = resnet.layer2  # 112
        # stage 3
        self.encoder3 = resnet.layer3  # 56
        # stage 4
        self.encoder4 = resnet.layer4  # 28

        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)

        # stage 5
        self.resb5_1 = BasicBlock(512, 512)
        self.resb5_2 = BasicBlock(512, 512)
        self.resb5_3 = BasicBlock(512, 512)  # 14

        self.pool5 = nn.MaxPool2d(2, 2, ceil_mode=True)

        # stage 6
        self.resb6_1 = BasicBlock(512, 512)
        self.resb6_2 = BasicBlock(512, 512)
        self.resb6_3 = BasicBlock(512, 512)  # 7

        # -------------Bridge--------------

        # stage Bridge: three dilated 3x3 conv-bn-relu units at 512 channels
        self.convbg_1 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)  # 7
        self.bnbg_1 = nn.BatchNorm2d(512)
        self.relubg_1 = nn.ReLU(inplace=True)
        self.convbg_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
        self.bnbg_m = nn.BatchNorm2d(512)
        self.relubg_m = nn.ReLU(inplace=True)
        self.convbg_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
        self.bnbg_2 = nn.BatchNorm2d(512)
        self.relubg_2 = nn.ReLU(inplace=True)

        # -------------Decoder--------------
        # Each decoder stage is three conv-bn-relu units (_1, _m, _2).  The
        # first unit of every stage takes a concatenation of two tensors,
        # hence its doubled in-channel count.

        # stage 6d
        self.conv6d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 16
        self.bn6d_1 = nn.BatchNorm2d(512)
        self.relu6d_1 = nn.ReLU(inplace=True)

        self.conv6d_m = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
        self.bn6d_m = nn.BatchNorm2d(512)
        self.relu6d_m = nn.ReLU(inplace=True)

        self.conv6d_2 = nn.Conv2d(512, 512, 3, dilation=2, padding=2)
        self.bn6d_2 = nn.BatchNorm2d(512)
        self.relu6d_2 = nn.ReLU(inplace=True)

        # stage 5d
        self.conv5d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 16
        self.bn5d_1 = nn.BatchNorm2d(512)
        self.relu5d_1 = nn.ReLU(inplace=True)

        self.conv5d_m = nn.Conv2d(512, 512, 3, padding=1)
        self.bn5d_m = nn.BatchNorm2d(512)
        self.relu5d_m = nn.ReLU(inplace=True)

        self.conv5d_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn5d_2 = nn.BatchNorm2d(512)
        self.relu5d_2 = nn.ReLU(inplace=True)

        # stage 4d
        self.conv4d_1 = nn.Conv2d(1024, 512, 3, padding=1)  # 32
        self.bn4d_1 = nn.BatchNorm2d(512)
        self.relu4d_1 = nn.ReLU(inplace=True)

        self.conv4d_m = nn.Conv2d(512, 512, 3, padding=1)
        self.bn4d_m = nn.BatchNorm2d(512)
        self.relu4d_m = nn.ReLU(inplace=True)

        self.conv4d_2 = nn.Conv2d(512, 256, 3, padding=1)
        self.bn4d_2 = nn.BatchNorm2d(256)
        self.relu4d_2 = nn.ReLU(inplace=True)

        # stage 3d
        self.conv3d_1 = nn.Conv2d(512, 256, 3, padding=1)  # 64
        self.bn3d_1 = nn.BatchNorm2d(256)
        self.relu3d_1 = nn.ReLU(inplace=True)

        self.conv3d_m = nn.Conv2d(256, 256, 3, padding=1)
        self.bn3d_m = nn.BatchNorm2d(256)
        self.relu3d_m = nn.ReLU(inplace=True)

        self.conv3d_2 = nn.Conv2d(256, 128, 3, padding=1)
        self.bn3d_2 = nn.BatchNorm2d(128)
        self.relu3d_2 = nn.ReLU(inplace=True)

        # stage 2d

        self.conv2d_1 = nn.Conv2d(256, 128, 3, padding=1)  # 128
        self.bn2d_1 = nn.BatchNorm2d(128)
        self.relu2d_1 = nn.ReLU(inplace=True)

        self.conv2d_m = nn.Conv2d(128, 128, 3, padding=1)
        self.bn2d_m = nn.BatchNorm2d(128)
        self.relu2d_m = nn.ReLU(inplace=True)

        self.conv2d_2 = nn.Conv2d(128, 64, 3, padding=1)
        self.bn2d_2 = nn.BatchNorm2d(64)
        self.relu2d_2 = nn.ReLU(inplace=True)

        # stage 1d
        self.conv1d_1 = nn.Conv2d(128, 64, 3, padding=1)  # 256
        self.bn1d_1 = nn.BatchNorm2d(64)
        self.relu1d_1 = nn.ReLU(inplace=True)

        self.conv1d_m = nn.Conv2d(64, 64, 3, padding=1)
        self.bn1d_m = nn.BatchNorm2d(64)
        self.relu1d_m = nn.ReLU(inplace=True)

        self.conv1d_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn1d_2 = nn.BatchNorm2d(64)
        self.relu1d_2 = nn.ReLU(inplace=True)

        # -------------Bilinear Upsampling--------------
        # upscoreN upsamples by a fixed factor; forward() uses upscore2 inside
        # the decoder and the larger factors only for the side outputs.
        self.upscore6 = nn.Upsample(
            scale_factor=32, mode="bilinear", align_corners=False
        )
        self.upscore5 = nn.Upsample(
            scale_factor=16, mode="bilinear", align_corners=False
        )
        self.upscore4 = nn.Upsample(
            scale_factor=8, mode="bilinear", align_corners=False
        )
        self.upscore3 = nn.Upsample(
            scale_factor=4, mode="bilinear", align_corners=False
        )
        self.upscore2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        # -------------Side Output--------------
        # One single-channel 3x3 conv per decoder stage plus one for the bridge.
        self.outconvb = nn.Conv2d(512, 1, 3, padding=1)
        self.outconv6 = nn.Conv2d(512, 1, 3, padding=1)
        self.outconv5 = nn.Conv2d(512, 1, 3, padding=1)
        self.outconv4 = nn.Conv2d(256, 1, 3, padding=1)
        self.outconv3 = nn.Conv2d(128, 1, 3, padding=1)
        self.outconv2 = nn.Conv2d(64, 1, 3, padding=1)
        self.outconv1 = nn.Conv2d(64, 1, 3, padding=1)

        # -------------Refine Module-------------
        self.refunet = RefUnet(1, 64)

    def forward(self, x):
        """Run the network and return eight sigmoid-activated maps.

        Returns:
            Tuple ``(dout, d1, d2, d3, d4, d5, d6, db)`` after ``torch.sigmoid``,
            where ``dout`` is ``d1`` refined by ``self.refunet``.
        """
        hx = x

        # -------------Encoder-------------
        hx = self.inconv(hx)
        hx = self.inbn(hx)
        hx = self.inrelu(hx)

        # h1..h6 are kept for the decoder's skip concatenations.
        h1 = self.encoder1(hx)  # 256
        h2 = self.encoder2(h1)  # 128
        h3 = self.encoder3(h2)  # 64
        h4 = self.encoder4(h3)  # 32

        hx = self.pool4(h4)  # 16

        hx = self.resb5_1(hx)
        hx = self.resb5_2(hx)
        h5 = self.resb5_3(hx)

        hx = self.pool5(h5)  # 8

        hx = self.resb6_1(hx)
        hx = self.resb6_2(hx)
        h6 = self.resb6_3(hx)

        # -------------Bridge-------------
        hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6)))  # 8
        hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx)))
        hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx)))

        # -------------Decoder-------------
        # Stage 6d fuses the bridge output with h6 (no upsampling); every
        # later stage upsamples the previous stage x2 and fuses it with the
        # encoder tensor of the same index.

        hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg, h6), 1))))
        hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx)))
        hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx)))

        hx = self.upscore2(hd6)  # 8 -> 16

        hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx, h5), 1))))
        hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx)))
        hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx)))

        hx = self.upscore2(hd5)  # 16 -> 32

        hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx, h4), 1))))
        hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx)))
        hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx)))

        hx = self.upscore2(hd4)  # 32 -> 64

        hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx, h3), 1))))
        hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx)))
        hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx)))

        hx = self.upscore2(hd3)  # 64 -> 128

        hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx, h2), 1))))
        hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx)))
        hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx)))

        hx = self.upscore2(hd2)  # 128 -> 256

        hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx, h1), 1))))
        hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx)))
        hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx)))

        # -------------Side Output-------------
        # Each side map is produced at its stage's resolution and then
        # upsampled by the matching fixed factor.
        db = self.outconvb(hbg)
        db = self.upscore6(db)  # 8->256

        d6 = self.outconv6(hd6)
        d6 = self.upscore6(d6)  # 8->256

        d5 = self.outconv5(hd5)
        d5 = self.upscore5(d5)  # 16->256

        d4 = self.outconv4(hd4)
        d4 = self.upscore4(d4)  # 32->256

        d3 = self.outconv3(hd3)
        d3 = self.upscore3(d3)  # 64->256

        d2 = self.outconv2(hd2)
        d2 = self.upscore2(d2)  # 128->256

        d1 = self.outconv1(hd1)  # 256

        # -------------Refine Module-------------
        # The refinement network consumes the un-activated d1 map.
        dout = self.refunet(d1)  # 256

        return (
            torch.sigmoid(dout),
            torch.sigmoid(d1),
            torch.sigmoid(d2),
            torch.sigmoid(d3),
            torch.sigmoid(d4),
            torch.sigmoid(d5),
            torch.sigmoid(d6),
            torch.sigmoid(db),
        )
|
carvekit/ml/arch/fba_matting/__init__.py
ADDED
File without changes
|
carvekit/ml/arch/fba_matting/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (190 Bytes). View file
|
|
carvekit/ml/arch/fba_matting/__pycache__/layers_WS.cpython-38.pyc
ADDED
Binary file (1.6 kB). View file
|
|
carvekit/ml/arch/fba_matting/__pycache__/models.cpython-38.pyc
ADDED
Binary file (8.24 kB). View file
|
|
carvekit/ml/arch/fba_matting/__pycache__/resnet_GN_WS.cpython-38.pyc
ADDED
Binary file (4.45 kB). View file
|
|
carvekit/ml/arch/fba_matting/__pycache__/resnet_bn.cpython-38.pyc
ADDED
Binary file (4.69 kB). View file
|
|
carvekit/ml/arch/fba_matting/__pycache__/transforms.cpython-38.pyc
ADDED
Binary file (1.58 kB). View file
|
|
carvekit/ml/arch/fba_matting/layers_WS.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
3 |
+
Source url: https://github.com/MarcoForte/FBA_Matting
|
4 |
+
License: MIT License
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel is standardised on every forward pass.

    Before the convolution each output filter has its mean subtracted and is
    divided by its standard deviation, both taken over the in-channel, height
    and width axes.  The stored ``weight`` parameter itself is left untouched.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Convolve ``x`` with the standardised version of ``self.weight``."""
        kernel = self.weight

        # Per-filter mean, reduced one axis at a time (dims 1, 2, 3).
        filter_mean = kernel
        for axis in (1, 2, 3):
            filter_mean = filter_mean.mean(dim=axis, keepdim=True)
        centred = kernel - filter_mean

        # Per-filter std; the two epsilons keep the division finite when a
        # filter is constant.
        flat = centred.view(centred.size(0), -1)
        filter_std = torch.sqrt(torch.var(flat, dim=1) + 1e-12).view(-1, 1, 1, 1)
        filter_std = filter_std + 1e-5

        standardised = centred / filter_std.expand_as(centred)
        return F.conv2d(
            x,
            standardised,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
|
54 |
+
|
55 |
+
|
56 |
+
def BatchNorm2d(num_features):
    """Return a 32-group ``nn.GroupNorm`` over ``num_features`` channels.

    Despite the name, no batch normalisation layer is created here.
    """
    group_count = 32
    return nn.GroupNorm(group_count, num_features)
|
carvekit/ml/arch/fba_matting/models.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
3 |
+
Source url: https://github.com/MarcoForte/FBA_Matting
|
4 |
+
License: MIT License
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import carvekit.ml.arch.fba_matting.resnet_GN_WS as resnet_GN_WS
|
9 |
+
import carvekit.ml.arch.fba_matting.layers_WS as L
|
10 |
+
import carvekit.ml.arch.fba_matting.resnet_bn as resnet_bn
|
11 |
+
from functools import partial
|
12 |
+
|
13 |
+
|
14 |
+
class FBA(nn.Module):
    """Encoder/decoder pair built from an architecture name.

    Args:
        encoder: Architecture name handed to ``build_encoder``.  The decoder
            is created with ``batch_norm=True`` exactly when the name
            contains the substring ``"BN"``.
    """

    def __init__(self, encoder: str):
        super().__init__()
        self.encoder = build_encoder(arch=encoder)
        self.decoder = fba_decoder(batch_norm="BN" in encoder)

    def forward(self, image, two_chan_trimap, image_n, trimap_transformed):
        """Encode the concatenated inputs and decode them.

        The encoder input is ``image_n``, ``trimap_transformed`` and
        ``two_chan_trimap`` concatenated along dim 1, in that order.  The
        decoder additionally receives ``image``, the pooling indices returned
        by the encoder and ``two_chan_trimap``.
        """
        encoder_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
        feature_maps, pool_indices = self.encoder(
            encoder_input, return_feature_maps=True
        )
        return self.decoder(feature_maps, image, pool_indices, two_chan_trimap)
|
24 |
+
|
25 |
+
|
26 |
+
class ResnetDilatedBN(nn.Module):
    """Wrap a three-conv-stem ResNet, replacing strides by dilation.

    ``dilate_scale == 8`` dilates ``layer3`` (rate 2) and ``layer4`` (rate 4);
    ``dilate_scale == 16`` dilates only ``layer4`` (rate 2); any other value
    leaves the layers as they are.  The wrapped network's sub-modules are
    modified in place and re-exposed under the same attribute names.
    """

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()

        if dilate_scale == 8:
            plan = (("layer3", 2), ("layer4", 4))
        elif dilate_scale == 16:
            plan = (("layer4", 2),)
        else:
            plan = ()
        for layer_name, rate in plan:
            getattr(orig_resnet, layer_name).apply(
                partial(self._nostride_dilate, dilate=rate)
            )

        # Take over everything from the original network except AvgPool and FC.
        reused = (
            "conv1", "bn1", "relu1",
            "conv2", "bn2", "relu2",
            "conv3", "bn3", "relu3",
            "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        )
        for attr in reused:
            setattr(self, attr, getattr(orig_resnet, attr))

    def _nostride_dilate(self, m, dilate):
        """Turn a conv module's stride into dilation; ignore other modules."""
        if "Conv" not in m.__class__.__name__:
            return

        if m.stride == (2, 2):
            # Strided convolution: drop the stride and use half the rate.
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            # Every other convolution gets the full rate.
            rate = dilate

        if m.kernel_size == (3, 3):
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x, return_feature_maps=False):
        """Run the backbone.

        Returns ``(feature_maps, pool_indices)`` when ``return_feature_maps``
        is true, where ``feature_maps`` is the input followed by the stem
        output and the outputs of ``layer1`` .. ``layer4``; otherwise a
        one-element list holding the ``layer4`` output.
        """
        feature_maps = [x]

        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, bn, act in stem:
            x = act(bn(conv(x)))
        feature_maps.append(x)

        x, pool_indices = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            feature_maps.append(x)

        if not return_feature_maps:
            return [x]
        return feature_maps, pool_indices
|
86 |
+
|
87 |
+
|
88 |
+
class Resnet(nn.Module):
    """Expose a three-conv-stem ResNet as a feature extractor (no dilation).

    All sub-modules of ``orig_resnet`` except AvgPool and FC are reused under
    the same attribute names.
    """

    def __init__(self, orig_resnet):
        super().__init__()

        reused = (
            "conv1", "bn1", "relu1",
            "conv2", "bn2", "relu2",
            "conv3", "bn3", "relu3",
            "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        )
        for attr in reused:
            setattr(self, attr, getattr(orig_resnet, attr))

    def forward(self, x, return_feature_maps=False):
        """Run the backbone.

        Returns the list of the stem output and the ``layer1`` .. ``layer4``
        outputs when ``return_feature_maps`` is true, otherwise a one-element
        list holding the ``layer4`` output.  ``self.maxpool`` must return a
        ``(values, indices)`` pair; the indices are discarded here.
        """
        feature_maps = []

        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, bn, act in stem:
            x = act(bn(conv(x)))
        feature_maps.append(x)

        x, _ = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            feature_maps.append(x)

        return feature_maps if return_feature_maps else [x]
|
129 |
+
|
130 |
+
|
131 |
+
class ResnetDilated(nn.Module):
    """Wrap a single-conv-stem ResNet, replacing strides by dilation.

    ``dilate_scale == 8`` dilates ``layer3`` (rate 2) and ``layer4`` (rate 4);
    ``dilate_scale == 16`` dilates only ``layer4`` (rate 2); any other value
    leaves the layers as they are.  The wrapped network's sub-modules are
    modified in place and re-exposed under the same attribute names.
    """

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()

        if dilate_scale == 8:
            plan = (("layer3", 2), ("layer4", 4))
        elif dilate_scale == 16:
            plan = (("layer4", 2),)
        else:
            plan = ()
        for layer_name, rate in plan:
            getattr(orig_resnet, layer_name).apply(
                partial(self._nostride_dilate, dilate=rate)
            )

        # Take over everything from the original network except AvgPool and FC.
        reused = (
            "conv1", "bn1", "relu", "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        )
        for attr in reused:
            setattr(self, attr, getattr(orig_resnet, attr))

    def _nostride_dilate(self, m, dilate):
        """Turn a conv module's stride into dilation; ignore other modules."""
        if "Conv" not in m.__class__.__name__:
            return

        if m.stride == (2, 2):
            # Strided convolution: drop the stride and use half the rate.
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            # Every other convolution gets the full rate.
            rate = dilate

        if m.kernel_size == (3, 3):
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x, return_feature_maps=False):
        """Run the backbone.

        Returns ``(feature_maps, pool_indices)`` when ``return_feature_maps``
        is true, where ``feature_maps`` is the input followed by the stem
        output and the outputs of ``layer1`` .. ``layer4``; otherwise a
        one-element list holding the ``layer4`` output.
        """
        feature_maps = [x]

        x = self.relu(self.bn1(self.conv1(x)))
        feature_maps.append(x)

        x, pool_indices = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            feature_maps.append(x)

        if not return_feature_maps:
            return [x]
        return feature_maps, pool_indices
|
183 |
+
|
184 |
+
|
185 |
+
def norm(dim, bn=False):
    """Return the normalisation layer for ``dim`` channels.

    A 32-group ``nn.GroupNorm`` is returned only when ``bn`` is the object
    ``False`` itself; any other value (including other falsy ones) selects
    ``nn.BatchNorm2d``.
    """
    use_group_norm = bn is False
    return nn.GroupNorm(32, dim) if use_group_norm else nn.BatchNorm2d(dim)
|
190 |
+
|
191 |
+
|
192 |
+
def fba_fusion(alpha, img, F, B):
    """Jointly refine alpha, foreground ``F`` and background ``B``.

    ``F`` is updated first and the *updated* ``F`` is used in the update of
    ``B``.  Both are clamped to [0, 1]; alpha is then re-estimated from
    ``img``, ``F`` and ``B`` (summing over dim 1, regularised by 0.1) and
    clamped to [0, 1] as well.

    Returns:
        Tuple ``(alpha, F, B)`` of the refined tensors.
    """
    reg = 0.1
    mix = alpha * (1 - alpha)

    F = alpha * img + (1 - alpha**2) * F - mix * B
    B = (1 - alpha) * img + (2 * alpha - alpha**2) * B - mix * F

    F = torch.clamp(F, 0, 1)
    B = torch.clamp(B, 0, 1)

    fg_minus_bg = F - B
    numerator = alpha * reg + torch.sum((img - B) * fg_minus_bg, 1, keepdim=True)
    denominator = torch.sum(fg_minus_bg * fg_minus_bg, 1, keepdim=True) + reg
    alpha = torch.clamp(numerator / denominator, 0, 1)

    return alpha, F, B
|
204 |
+
|
205 |
+
|
206 |
+
class fba_decoder(nn.Module):
    """Decoder producing a 7-channel map: alpha (1), F (3) and B (3).

    Args:
        batch_norm: Forwarded to ``norm`` for every normalisation layer and
            used to pick the in-channel count of ``conv_up3`` (256 + 128 when
            truthy, 256 + 64 otherwise).
    """

    def __init__(self, batch_norm=False):
        super().__init__()
        pool_scales = (1, 2, 3, 6)
        self.batch_norm = batch_norm

        # One pooling branch per scale: adaptive average pool -> 1x1 conv ->
        # norm -> LeakyReLU.
        self.ppm = nn.ModuleList(
            [
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(scale),
                    L.Conv2d(2048, 256, kernel_size=1, bias=True),
                    norm(256, self.batch_norm),
                    nn.LeakyReLU(),
                )
                for scale in pool_scales
            ]
        )

        pyramid_channels = 2048 + len(pool_scales) * 256
        self.conv_up1 = nn.Sequential(
            L.Conv2d(pyramid_channels, 256, kernel_size=3, padding=1, bias=True),
            norm(256, self.batch_norm),
            nn.LeakyReLU(),
            L.Conv2d(256, 256, kernel_size=3, padding=1),
            norm(256, self.batch_norm),
            nn.LeakyReLU(),
        )

        self.conv_up2 = nn.Sequential(
            L.Conv2d(256 + 256, 256, kernel_size=3, padding=1, bias=True),
            norm(256, self.batch_norm),
            nn.LeakyReLU(),
        )

        skip_channels = 128 if self.batch_norm else 64
        self.conv_up3 = nn.Sequential(
            L.Conv2d(256 + skip_channels, 64, kernel_size=3, padding=1, bias=True),
            norm(64, self.batch_norm),
            nn.LeakyReLU(),
        )

        # Registered but not called from forward() below.
        self.unpool = nn.MaxUnpool2d(2, stride=2)

        self.conv_up4 = nn.Sequential(
            nn.Conv2d(64 + 3 + 3 + 2, 32, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True),
        )

    def forward(self, conv_out, img, indices, two_chan_trimap):
        """Decode encoder features into ``cat((alpha, F, B), 1)``.

        Args:
            conv_out: Sequence of encoder feature maps; entries ``-1``, ``-4``,
                ``-5`` and ``-6`` are read.
            img: Image tensor, concatenated before ``conv_up4`` and passed to
                ``fba_fusion``.
            indices: Accepted but not used in this method.
            two_chan_trimap: Concatenated before ``conv_up4``.
        """
        deepest = conv_out[-1]
        full_size = deepest.size()
        target_hw = (full_size[2], full_size[3])

        # Pyramid pooling: each branch is resized back to the deepest map's
        # spatial size and stacked after it along dim 1.
        pyramid = [deepest]
        pyramid.extend(
            nn.functional.interpolate(
                branch(deepest), target_hw, mode="bilinear", align_corners=False
            )
            for branch in self.ppm
        )
        x = self.conv_up1(torch.cat(pyramid, 1))

        # Two "upsample x2 -> concatenate skip -> conv" steps.
        for skip_index, fuse in ((-4, self.conv_up2), (-5, self.conv_up3)):
            x = torch.nn.functional.interpolate(
                x, scale_factor=2, mode="bilinear", align_corners=False
            )
            x = fuse(torch.cat((x, conv_out[skip_index]), 1))

        x = torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=False
        )
        x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)

        raw = self.conv_up4(x)

        alpha = torch.clamp(raw[:, 0][:, None], 0, 1)
        F = torch.sigmoid(raw[:, 1:4])
        B = torch.sigmoid(raw[:, 4:7])

        # FBA Fusion
        alpha, F, B = fba_fusion(alpha, img, F, B)

        return torch.cat((alpha, F, B), 1)
|
309 |
+
|
310 |
+
|
311 |
+
def build_encoder(arch="resnet50_GN_WS"):
    """Build a dilated ResNet-50 encoder whose first conv takes 11 channels.

    Args:
        arch: ``"resnet50_GN_WS"`` (wrapped in ``ResnetDilated``) or
            ``"resnet50_BN"`` (wrapped in ``ResnetDilatedBN``).  The default
            used to be ``"resnet50_GN"``, which matched neither branch and
            therefore always raised; it now names a supported architecture.

    Returns:
        The wrapped encoder.  Its ``conv1`` weight is replaced by a tensor
        with 11 input channels whose first three channels hold the original
        weights and whose remaining channels are zero.

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """
    if arch == "resnet50_GN_WS":
        net_encoder = ResnetDilated(resnet_GN_WS.l_resnet50(), dilate_scale=8)
    elif arch == "resnet50_BN":
        net_encoder = ResnetDilatedBN(resnet_bn.l_resnet50(), dilate_scale=8)
    else:
        raise ValueError(
            f"Architecture undefined: {arch!r} "
            "(expected 'resnet50_GN_WS' or 'resnet50_BN')"
        )

    # presumably image (3) + transformed trimap (6) + two-channel trimap (2),
    # matching the tensors concatenated in FBA.forward - TODO confirm
    num_channels = 3 + 6 + 2

    # Widen conv1 in place: keep the existing weights for the first three
    # input channels and start the added channels at zero.
    conv1 = net_encoder.conv1
    original_weight = conv1.weight.data
    c_out, _, kernel_h, kernel_w = original_weight.size()
    widened = torch.zeros(c_out, num_channels, kernel_h, kernel_w)
    widened[:, :3, :, :] = original_weight

    conv1.in_channels = num_channels
    conv1.weight = torch.nn.Parameter(widened)
    return net_encoder
|
carvekit/ml/arch/fba_matting/resnet_GN_WS.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
3 |
+
Source url: https://github.com/MarcoForte/FBA_Matting
|
4 |
+
License: MIT License
|
5 |
+
"""
|
6 |
+
import torch.nn as nn
|
7 |
+
import carvekit.ml.arch.fba_matting.layers_WS as L
|
8 |
+
|
9 |
+
__all__ = ["ResNet", "l_resnet50"]
|
10 |
+
|
11 |
+
|
12 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``L.Conv2d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return L.Conv2d(in_planes, out_planes, **conv_options)
|
17 |
+
|
18 |
+
|
19 |
+
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``L.Conv2d`` with the given stride."""
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return L.Conv2d(in_planes, out_planes, **conv_options)
|
22 |
+
|
23 |
+
|
24 |
+
class BasicBlock(nn.Module):
    """Two-convolution residual block (``expansion = 1``).

    Args:
        inplanes: In-channels of the first 3x3 convolution.
        planes: Out-channels of both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to form the shortcut;
            when ``None`` the input itself is the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = L.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = L.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
54 |
+
|
55 |
+
|
56 |
+
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels.

    Args:
        inplanes: In-channels of the first 1x1 convolution.
        planes: Channel width of the 3x3 convolution.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to form the shortcut;
            when ``None`` the input itself is the shortcut.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * self.expansion
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = L.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = L.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, widened)
        self.bn3 = L.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
92 |
+
|
93 |
+
|
94 |
+
class ResNet(nn.Module):
    """ResNet built from ``L.Conv2d`` / ``L.BatchNorm2d`` layers.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, downsample)``.
        layers: Four integers giving the number of blocks in ``layer1`` ..
            ``layer4``.
        num_classes: Out-features of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        # Running in-channel count; updated by every _make_layer call.
        self.inplanes = 64
        self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = L.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # return_indices=True makes this module return (values, indices).
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, return_indices=True
        )
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first one may downsample."""
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # Project the shortcut so it matches the first block's output.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                L.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the output of ``self.fc`` for input ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # self.maxpool is built with return_indices=True and therefore yields
        # (values, indices); only the pooled values continue through the
        # network.  Passing the tuple on unchanged made layer1 fail.
        pooled = self.maxpool(x)
        x = pooled[0] if isinstance(pooled, tuple) else pooled

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
|
143 |
+
|
144 |
+
|
145 |
+
def l_resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 model (``Bottleneck`` blocks, layout [3, 4, 6, 3]).

    Args:
        pretrained (bool): Accepted but not read by this function; no
            weights are loaded here.
        **kwargs: Forwarded unchanged to ``ResNet``.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model
|
carvekit/ml/arch/fba_matting/resnet_bn.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
3 |
+
Source url: https://github.com/MarcoForte/FBA_Matting
|
4 |
+
License: MIT License
|
5 |
+
"""
|
6 |
+
import torch.nn as nn
|
7 |
+
import math
|
8 |
+
from torch.nn import BatchNorm2d
|
9 |
+
|
10 |
+
__all__ = ["ResNet"]
|
11 |
+
|
12 |
+
|
13 |
+
def conv3x3(in_planes, out_planes, stride=1):
    """Create a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution (default 1).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
|
18 |
+
|
19 |
+
|
20 |
+
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # First conv carries the stride; the second keeps the resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        # In-place add, as in the original implementation.
        out += shortcut
        return self.relu(out)
|
50 |
+
|
51 |
+
|
52 |
+
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with batch norm.

    The last convolution outputs ``planes * 4`` channels. The stride is
    applied by the 3x3 convolution.

    Args:
        inplanes: Number of input channels.
        planes: Width of the two inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * 4

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        # Only this norm layer uses a non-default momentum.
        self.bn2 = BatchNorm2d(planes, momentum=0.01)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        # In-place add, as in the original implementation.
        out += shortcut
        return self.relu(out)
|
90 |
+
|
91 |
+
|
92 |
+
class ResNet(nn.Module):
    """ResNet variant with a three-convolution stem.

    The stem is three 3x3 convolutions (3->64 with stride 2, 64->64,
    64->128), each followed by batch norm and ReLU, then a max pool built
    with ``return_indices=True``. Four residual stages, a 7x7 average pool
    and a linear classifier follow.

    Args:
        block: Residual block class; must expose ``expansion`` and accept
            ``(inplanes, planes, stride, downsample)``.
        layers: Sequence of four ints, the number of blocks per stage.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # Channel count entering the first residual stage (the stem output).
        self.inplanes = 128

        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, return_indices=True
        )

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialise conv weights with N(0, sqrt(2 / fan)) and norms to identity."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2.0 / fan))
            elif isinstance(module, BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks and update ``self.inplanes``.

        Only the first block receives ``stride``. It also gets a 1x1
        conv + batch-norm shortcut whenever the stride is not 1 or the
        channel count changes.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the ``fc`` output for the input batch ``x``."""
        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, norm, act in stem:
            x = act(norm(conv(x)))

        # The pool also returns argmax indices; they are not used here.
        x, _ = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
|
161 |
+
|
162 |
+
|
163 |
+
def l_resnet50():
    """Build a ResNet-50 style network (``Bottleneck`` blocks, depths 3-4-6-3).

    Takes no arguments and loads no weights; the model is returned as
    initialised by the ``ResNet`` constructor.

    Returns:
        The newly constructed ``ResNet`` instance.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths)
|
carvekit/ml/arch/fba_matting/transforms.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
3 |
+
Source url: https://github.com/MarcoForte/FBA_Matting
|
4 |
+
License: MIT License
|
5 |
+
"""
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
# Per-channel mean and standard deviation used by groupnorm_normalise_image
# (index 0, 1, 2 are the three image channels, in that order).
group_norm_mean = [0.485, 0.456, 0.406]
group_norm_std = [0.229, 0.224, 0.225]
|
11 |
+
|
12 |
+
|
13 |
+
def dt(a):
    """Return OpenCV's L2 distance transform of ``a``.

    ``a`` is multiplied by 255 and cast to ``uint8`` before being passed to
    ``cv2.distanceTransform`` with ``cv2.DIST_L2`` and mask size 0.
    """
    as_uint8 = (a * 255).astype(np.uint8)
    return cv2.distanceTransform(as_uint8, cv2.DIST_L2, 0)
|
15 |
+
|
16 |
+
|
17 |
+
def trimap_transform(trimap):
    """Encode the first two trimap channels as six Gaussian distance maps.

    For each of channels 0 and 1 that contains at least one non-zero value,
    the squared distance transform of ``1 - channel`` is turned into three
    maps ``exp(-d^2 / (2 * (f * 320)^2))`` with ``f`` in (0.02, 0.08, 0.16).
    Channel ``k`` fills output planes ``3k .. 3k + 2``; planes of an
    all-zero channel stay zero.

    Args:
        trimap: Array indexed as ``trimap[:, :, k]`` for ``k`` in 0 and 1.

    Returns:
        A new ``(h, w, 6)`` array created with ``np.zeros``.
    """
    height, width = trimap.shape[0], trimap.shape[1]

    L = 320
    # Denominators of the three Gaussian widths, shared by both channels.
    denominators = [2 * ((fraction * L) ** 2) for fraction in (0.02, 0.08, 0.16)]

    clicks = np.zeros((height, width, 6))
    for k in range(2):
        channel = trimap[:, :, k]
        if not np.count_nonzero(channel) > 0:
            continue
        dt_mask = -dt(1 - channel) ** 2
        for offset, denominator in enumerate(denominators):
            clicks[:, :, 3 * k + offset] = np.exp(dt_mask / denominator)

    return clicks
|
30 |
+
|
31 |
+
|
32 |
+
def groupnorm_normalise_image(img, format="nhwc"):
    """Normalise the first three channels of ``img`` in place and return it.

    Each channel ``i`` becomes ``(value - group_norm_mean[i]) / group_norm_std[i]``.
    The docstring of the original states that ``img`` is RGB in the range 0..1.

    Args:
        img: Image (or batch) to normalise; it is modified in place.
        format: ``"nhwc"`` means the channel is the last axis; any other
            value means the channel is the third axis from the end.

    Returns:
        The same ``img`` object that was passed in.
    """
    channels_last = format == "nhwc"
    for i in range(3):
        mean, std = group_norm_mean[i], group_norm_std[i]
        if channels_last:
            img[..., i] = (img[..., i] - mean) / std
        else:
            img[..., i, :, :] = (img[..., i, :, :] - mean) / std

    return img
|
carvekit/ml/arch/tracerb7/__init__.py
ADDED
File without changes
|
carvekit/ml/arch/tracerb7/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (187 Bytes). View file
|
|
carvekit/ml/arch/tracerb7/__pycache__/att_modules.cpython-38.pyc
ADDED
Binary file (7.42 kB). View file
|
|
carvekit/ml/arch/tracerb7/__pycache__/conv_modules.cpython-38.pyc
ADDED
Binary file (2.44 kB). View file
|
|
carvekit/ml/arch/tracerb7/__pycache__/effi_utils.cpython-38.pyc
ADDED
Binary file (14.9 kB). View file
|
|
carvekit/ml/arch/tracerb7/__pycache__/efficientnet.cpython-38.pyc
ADDED
Binary file (8.02 kB). View file
|
|
carvekit/ml/arch/tracerb7/__pycache__/tracer.cpython-38.pyc
ADDED
Binary file (2.83 kB). View file
|
|
carvekit/ml/arch/tracerb7/att_modules.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source url: https://github.com/Karel911/TRACER
|
3 |
+
Author: Min Seok Lee and Wooseok Shin
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from carvekit.ml.arch.tracerb7.conv_modules import BasicConv2d, DWConv, DWSConv
|
11 |
+
|
12 |
+
|
13 |
+
class RFB_Block(nn.Module):
    """Four parallel convolution branches fused by a 3x3 conv plus a 1x1 residual.

    ``branch0`` is a single 1x1 conv. ``branch1``..``branch3`` each apply a
    1x1 conv, a 1xk conv, a kx1 conv and a dilated 3x3 conv, with
    k = dilation = 3, 5 and 7 respectively. The branch outputs are
    concatenated on dim 1, reduced by ``conv_cat``, added to
    ``conv_res(x)`` and passed through ReLU.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of channels of every branch and of the output.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        def asymmetric_branch(size):
            # 1x1 -> 1xsize -> sizex1 -> 3x3 dilated by `size`; all paddings
            # are chosen so the spatial size is preserved.
            half = size // 2
            return nn.Sequential(
                BasicConv2d(in_channel, out_channel, 1),
                BasicConv2d(
                    out_channel, out_channel, kernel_size=(1, size), padding=(0, half)
                ),
                BasicConv2d(
                    out_channel, out_channel, kernel_size=(size, 1), padding=(half, 0)
                ),
                BasicConv2d(out_channel, out_channel, 3, padding=size, dilation=size),
            )

        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = asymmetric_branch(3)
        self.branch2 = asymmetric_branch(5)
        self.branch3 = asymmetric_branch(7)
        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    def forward(self, x):
        """Return ``relu(conv_cat(cat(branches(x))) + conv_res(x))``."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        stacked = torch.cat(tuple(branch(x) for branch in branches), 1)
        fused = self.conv_cat(stacked)

        return self.relu(fused + self.conv_res(x))
|
51 |
+
|
52 |
+
|
53 |
+
class GlobalAvgPool(nn.Module):
    """Average every channel over all trailing dimensions.

    Args:
        flatten: When true the result has shape ``(N, C)``; otherwise it is
            reshaped to ``(N, C, 1, 1)``.
    """

    def __init__(self, flatten=False):
        super().__init__()
        self.flatten = flatten

    def forward(self, x):
        """Return the per-channel mean of ``x``."""
        batch, channels = x.size(0), x.size(1)
        # Merge everything after the channel axis, then average over it.
        pooled = x.view(batch, channels, -1).mean(dim=2)
        if self.flatten:
            return pooled
        return pooled.view(batch, channels, 1, 1)
|
68 |
+
|
69 |
+
|
70 |
+
class UnionAttentionModule(nn.Module):
    """Channel attention followed by masked spatial attention.

    ``Channel_Tracer`` computes a per-channel gate from globally pooled
    features. ``masking`` zeroes the channels whose gate falls at or below
    the ``confidence_ratio`` quantile. ``forward`` then runs a single-channel
    self-attention over the masked features.

    Args:
        n_channels: Number of channels of the input feature map.
        only_channel_tracing: When it is exactly ``False`` the spatial
            query/key/value projections are created. Note that ``forward``
            uses those projections unconditionally.
    """

    def __init__(self, n_channels, only_channel_tracing=False):
        super().__init__()
        self.GAP = GlobalAvgPool()
        self.confidence_ratio = 0.1
        self.bn = nn.BatchNorm2d(n_channels)
        self.norm = nn.Sequential(
            nn.BatchNorm2d(n_channels), nn.Dropout3d(self.confidence_ratio)
        )

        def pointwise(out_channels):
            # Bias-free 1x1 projection from n_channels to out_channels.
            return nn.Conv2d(
                in_channels=n_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            )

        self.channel_q = pointwise(n_channels)
        self.channel_k = pointwise(n_channels)
        self.channel_v = pointwise(n_channels)

        self.fc = pointwise(n_channels)

        if only_channel_tracing is False:
            self.spatial_q = pointwise(1)
            self.spatial_k = pointwise(1)
            self.spatial_v = pointwise(1)
        self.sigmoid = nn.Sigmoid()

    def masking(self, x, mask):
        """Zero the low-confidence channels of ``x``.

        ``mask`` is squeezed on dims 3 and 2, and entries at or below its
        per-row ``confidence_ratio`` quantile are set to 0 in place. The
        mask is then expanded to the shape of ``x`` and multiplied with it.
        """
        gate = mask.squeeze(3).squeeze(2)
        cutoff = torch.quantile(
            gate.float(), self.confidence_ratio, dim=-1, keepdim=True
        )
        # In-place write: `gate` may share storage with the caller's mask.
        gate[gate <= cutoff] = 0.0

        gate = gate.unsqueeze(2).unsqueeze(3)
        gate = gate.expand(-1, x.shape[1], x.shape[2], x.shape[3]).contiguous()

        return x * gate

    def Channel_Tracer(self, x):
        """Return ``(x * att + x, att.clone())`` for the channel gate ``att``.

        ``att`` is ``sigmoid(fc(softmax(Q K^T) V))`` where Q, K and V are 1x1
        projections of the normalised, globally pooled input.
        """
        pooled = self.norm(self.GAP(x))

        q, k, v = (
            project(pooled).squeeze(-1)
            for project in (self.channel_q, self.channel_k, self.channel_v)
        )

        # softmax(Q*K^T)
        weights = F.softmax(torch.matmul(q, k.transpose(1, 2)), dim=-1)

        # a*v
        att = torch.matmul(weights, v).unsqueeze(-1)
        att = self.sigmoid(self.fc(att))

        return (x * att) + x, att.clone()

    def forward(self, x):
        """Return the spatially attended single-channel map for ``x``."""
        traced, alpha_mask = self.Channel_Tracer(x)
        traced = self.bn(traced)
        x_drop = self.masking(traced, alpha_mask)

        q, k, v = (
            project(x_drop).squeeze(1)
            for project in (self.spatial_q, self.spatial_k, self.spatial_v)
        )

        # softmax(Q*K^T)
        weights = F.softmax(torch.matmul(q, k.transpose(1, 2)), dim=-1)

        # Attention output plus the value map, with dim 1 restored.
        return torch.matmul(weights, v).unsqueeze(1) + v.unsqueeze(1)
|
190 |
+
|
191 |
+
|
192 |
+
class aggregation(nn.Module):
    """Fuse three encoder feature maps and pass the result through ``UAM``.

    Args:
        channel: Indexable of three channel counts. ``channel[2]`` is the
            channel count of ``e4``, ``channel[1]`` of ``e3`` and
            ``channel[0]`` of ``e2``, as implied by the convolutions below.
    """

    def __init__(self, channel):
        super().__init__()
        c_low, c_mid, c_high = channel[0], channel[1], channel[2]

        self.relu = nn.ReLU(True)

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv_upsample1 = BasicConv2d(c_high, c_mid, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(c_high, c_low, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(c_mid, c_low, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(c_high, c_high, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(
            c_high + c_mid, c_high + c_mid, 3, padding=1
        )

        self.conv_concat2 = BasicConv2d(
            (c_high + c_mid), (c_high + c_mid), 3, padding=1
        )
        self.conv_concat3 = BasicConv2d(
            (c_low + c_mid + c_high),
            (c_low + c_mid + c_high),
            3,
            padding=1,
        )

        self.UAM = UnionAttentionModule(c_low + c_mid + c_high)

    def forward(self, e4, e3, e2):
        """Merge ``e4``, ``e3`` and ``e2`` and return ``UAM`` of the fusion.

        ``e4`` is upsampled x2 to meet ``e3`` and x4 to meet ``e2``; the
        upsampled maps gate the finer ones by multiplication before the
        concatenations.
        """
        up_e4 = self.upsample(e4)

        gated_e3 = self.conv_upsample1(up_e4) * e3
        gated_e2 = (
            self.conv_upsample2(self.upsample(up_e4))
            * self.conv_upsample3(self.upsample(e3))
            * e2
        )

        mid = torch.cat((gated_e3, self.conv_upsample4(up_e4)), 1)
        mid = self.conv_concat2(mid)

        full = torch.cat((gated_e2, self.conv_upsample5(self.upsample(mid))), 1)
        fused = self.conv_concat3(full)

        return self.UAM(fused)
|
236 |
+
|
237 |
+
|
238 |
+
class ObjectAttention(nn.Module):
    """Refine a decoder map using an encoder feature map.

    Args:
        channel: Number of channels of the encoder map.
        kernel_size: Kernel size of the first depthwise-separable conv
            (which is always built with ``padding=1``).
    """

    def __init__(self, channel, kernel_size):
        super().__init__()
        self.channel = channel
        half, eighth = channel // 2, channel // 8

        def dilated_branch(kernel, padding, dilation):
            # Depthwise conv at the given dilation, then a 1x1 reduction to
            # channel // 8 so four branches concatenate back to channel // 2.
            return nn.Sequential(
                DWConv(half, half, kernel=kernel, padding=padding, dilation=dilation),
                BasicConv2d(half, eighth, 1),
            )

        self.DWSConv = DWSConv(
            channel, half, kernel=kernel_size, padding=1, kernels_per_layer=1
        )
        self.DWConv1 = dilated_branch(kernel=1, padding=0, dilation=1)
        self.DWConv2 = dilated_branch(kernel=3, padding=1, dilation=1)
        self.DWConv3 = dilated_branch(kernel=3, padding=3, dilation=3)
        self.DWConv4 = dilated_branch(kernel=3, padding=5, dilation=5)
        self.conv1 = BasicConv2d(half, 1, 1)

    def forward(self, decoder_map, encoder_map):
        """
        Args:
            decoder_map: decoder representation (B, 1, H, W).
            encoder_map: encoder block output (B, C, H, W).
        Returns:
            decoder representation: (B, 1, H, W)
        """
        mask_ob = torch.sigmoid(decoder_map)  # object attention
        mask_bg = 1 - mask_ob  # reversed (background) attention

        attended = mask_ob.expand(-1, self.channel, -1, -1).mul(encoder_map)

        # Background weights above 0.93 are dropped; the rest are added back
        # as an extra weighting of the encoder map.
        edge = mask_bg.clone()
        edge[edge > 0.93] = 0
        attended = attended + (edge * encoder_map)

        features = self.DWSConv(attended)
        skip = features.clone()
        branches = (self.DWConv1, self.DWConv2, self.DWConv3, self.DWConv4)
        features = torch.cat([branch(features) for branch in branches], dim=1) + skip
        refined = torch.relu(self.conv1(features))

        return refined + decoder_map
|
carvekit/ml/arch/tracerb7/conv_modules.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source url: https://github.com/Karel911/TRACER
|
3 |
+
Author: Min Seok Lee and Wooseok Shin
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and ``SELU``.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Kernel size passed to ``nn.Conv2d``.
        stride: Stride passed to ``nn.Conv2d``.
        padding: Padding passed to ``nn.Conv2d``.
        dilation: Dilation passed to ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        stride=(1, 1),
        padding=(0, 0),
        dilation=(1, 1),
    ):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.conv = nn.Conv2d(in_channel, out_channel, **conv_options)
        self.bn = nn.BatchNorm2d(out_channel)
        self.selu = nn.SELU()

    def forward(self, x):
        """Return ``selu(bn(conv(x)))``."""
        return self.selu(self.bn(self.conv(x)))
|
38 |
+
|
39 |
+
|
40 |
+
class DWConv(nn.Module):
    """Grouped convolution (``groups=in_channel``) followed by batch norm and SELU.

    Args:
        in_channel: Number of input channels; also the number of groups.
        out_channel: Number of output channels.
        kernel: Kernel size of the convolution.
        dilation: Dilation of the convolution.
        padding: Padding of the convolution.
    """

    def __init__(self, in_channel, out_channel, kernel, dilation, padding):
        super().__init__()
        self.out_channel = out_channel
        conv_options = dict(
            kernel_size=kernel,
            padding=padding,
            groups=in_channel,
            dilation=dilation,
            bias=False,
        )
        self.DWConv = nn.Conv2d(in_channel, out_channel, **conv_options)
        self.bn = nn.BatchNorm2d(out_channel)
        self.selu = nn.SELU()

    def forward(self, x):
        """Return ``selu(bn(DWConv(x)))``."""
        return self.selu(self.bn(self.DWConv(x)))
|
61 |
+
|
62 |
+
|
63 |
+
class DWSConv(nn.Module):
    """Depthwise conv then pointwise (1x1) conv, each followed by batch norm and SELU.

    Args:
        in_channel: Number of input channels; also the number of groups of
            the depthwise convolution.
        out_channel: Number of channels produced by the pointwise conv.
        kernel: Kernel size of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        kernels_per_layer: Channel multiplier of the depthwise convolution.
    """

    def __init__(self, in_channel, out_channel, kernel, padding, kernels_per_layer):
        super().__init__()
        self.out_channel = out_channel
        depthwise_channels = in_channel * kernels_per_layer

        self.DWConv = nn.Conv2d(
            in_channel,
            depthwise_channels,
            kernel_size=kernel,
            padding=padding,
            groups=in_channel,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(depthwise_channels)
        self.selu = nn.SELU()
        self.PWConv = nn.Conv2d(
            depthwise_channels, out_channel, kernel_size=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        """Return ``selu(bn2(PWConv(selu(bn(DWConv(x))))))``."""
        depthwise = self.selu(self.bn(self.DWConv(x)))
        return self.selu(self.bn2(self.PWConv(depthwise)))
|
carvekit/ml/arch/tracerb7/effi_utils.py
ADDED
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Original author: lukemelas (github username)
|
3 |
+
Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
|
4 |
+
With adjustments and added comments by workingcoder (github username).
|
5 |
+
License: Apache License 2.0
|
6 |
+
Reimplemented: Min Seok Lee and Wooseok Shin
|
7 |
+
"""
|
8 |
+
|
9 |
+
import collections
|
10 |
+
import re
|
11 |
+
from functools import partial
|
12 |
+
|
13 |
+
import math
|
14 |
+
import torch
|
15 |
+
from torch import nn
|
16 |
+
from torch.nn import functional as F
|
17 |
+
|
18 |
+
# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple(
    "GlobalParams",
    "width_coefficient depth_coefficient image_size dropout_rate num_classes "
    "batch_norm_momentum batch_norm_epsilon drop_connect_rate depth_divisor "
    "min_depth include_top",
)

# Parameters for an individual model block
BlockArgs = collections.namedtuple(
    "BlockArgs",
    "num_repeat kernel_size stride expand_ratio input_filters output_filters "
    "se_ratio id_skip",
)

# Make every field of both tuples optional, defaulting to None.
for _params_type in (GlobalParams, BlockArgs):
    _params_type.__new__.__defaults__ = (None,) * len(_params_type._fields)
del _params_type
|
54 |
+
|
55 |
+
|
56 |
+
# An ordinary implementation of Swish function
|
57 |
+
class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``, built from ordinary tensor ops."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
|
60 |
+
|
61 |
+
|
62 |
+
# A memory-efficient implementation of Swish function
|
63 |
+
class SwishImplementation(torch.autograd.Function):
    """Swish as a custom autograd function.

    Only the input is saved for the backward pass; the sigmoid is
    recomputed there.
    """

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/di [i * s(i)] = s(i) * (1 + i * (1 - s(i)))
        local_grad = sigmoid_i * (1 + i * (1 - sigmoid_i))
        return grad_output * local_grad
|
75 |
+
|
76 |
+
|
77 |
+
class MemoryEfficientSwish(nn.Module):
    """Module wrapper that applies ``SwishImplementation`` to its input."""

    def forward(self, x):
        swish = SwishImplementation.apply
        return swish(x)
|
80 |
+
|
81 |
+
|
82 |
+
def round_filters(filters, global_params):
    """Scale a filter count by the width multiplier and round it to the divisor.

    Reads ``width_coefficient``, ``depth_divisor`` and ``min_depth`` from
    ``global_params``.

    Args:
        filters (int): Filters number to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new_filters: ``filters`` unchanged when no width multiplier is set;
        otherwise the scaled count rounded to a multiple of the divisor,
        never below ``min_depth`` (or the divisor when ``min_depth`` is
        unset) and never more than 10% below the scaled value.
    """
    width = global_params.width_coefficient
    if not width:
        return filters

    divisor = global_params.depth_divisor
    # A missing (or zero) min_depth falls back to the divisor itself.
    lower_bound = global_params.min_depth or divisor

    scaled = filters * width
    # Round to the nearest multiple of `divisor` (formula taken from the
    # official TensorFlow implementation).
    rounded = max(lower_bound, int(scaled + divisor / 2) // divisor * divisor)
    if rounded < 0.9 * scaled:  # prevent rounding down by more than 10%
        rounded += divisor
    return int(rounded)
|
105 |
+
|
106 |
+
|
107 |
+
def round_repeats(repeats, global_params):
    """Scale a block's repeat count by the depth multiplier, rounding up.

    Reads ``depth_coefficient`` from ``global_params``.

    Args:
        repeats (int): num_repeat to be calculated.
        global_params (namedtuple): Global params of the model.

    Returns:
        new repeat: ``repeats`` unchanged when no depth multiplier is set,
        otherwise ``ceil(depth_coefficient * repeats)`` as an int.
    """
    depth = global_params.depth_coefficient
    if depth:
        # Formula taken from the official TensorFlow implementation.
        return int(math.ceil(depth * repeats))
    return repeats
|
123 |
+
|
124 |
+
|
125 |
+
def drop_connect(inputs, p, training):
    """Drop connect: randomly zero whole samples of a batch during training.

    Args:
        inputs (tensor: BCWH): Input of this structure.
        p (float: 0.0~1.0): Probability of drop connection.
        training (bool): The running mode.

    Returns:
        output: ``inputs`` itself when ``training`` is false. Otherwise a
        tensor in which every sample is either zeroed (probability ``p``)
        or scaled by ``1 / (1 - p)``. With ``p == 1`` the result is all zeros.
    """
    assert 0 <= p <= 1, "p must be in range of [0,1]"

    if not training:
        return inputs

    keep_prob = 1 - p
    if keep_prob == 0:
        # Every sample is dropped. The general formula below would compute
        # inputs / 0 * 0, which is NaN, so return the zeros directly.
        return torch.zeros_like(inputs)

    batch_size = inputs.shape[0]

    # generate binary_tensor mask according to probability (p for 0, 1-p for 1):
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0.
    random_tensor = keep_prob + torch.rand(
        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
    )
    binary_tensor = torch.floor(random_tensor)

    # Rescale the kept samples by 1 / keep_prob.
    return inputs / keep_prob * binary_tensor
|
153 |
+
|
154 |
+
|
155 |
+
def get_width_and_height_from_size(x):
    """Normalise a size specification to a pair.

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: ``(x, x)`` for an int; a list or tuple is returned as is.

    Raises:
        TypeError: If ``x`` is not an int, list or tuple.
    """
    if isinstance(x, int):
        return x, x
    if not isinstance(x, (list, tuple)):
        raise TypeError()
    return x
|
170 |
+
|
171 |
+
|
172 |
+
def calculate_output_image_size(input_image_size, stride):
    """Calculates the output image size when using Conv2dSamePadding with a stride.

    Args:
        input_image_size (int, tuple or list): Size of input image.
        stride (int, tuple or list): Conv2d operation's stride. For a
            non-int stride only its first element is used, for both axes.

    Returns:
        output_image_size: ``None`` when ``input_image_size`` is ``None``,
        otherwise the list ``[ceil(H / stride), ceil(W / stride)]``.
    """
    if input_image_size is None:
        return None

    height, width = get_width_and_height_from_size(input_image_size)
    step = stride if isinstance(stride, int) else stride[0]
    return [int(math.ceil(extent / step)) for extent in (height, width)]
|
190 |
+
|
191 |
+
|
192 |
+
# Note:
|
193 |
+
# The following 'SamePadding' functions make output size equal ceil(input size/stride).
|
194 |
+
# Only when stride equals 1, can the output size be the same as input size.
|
195 |
+
# Don't be confused by their function names ! ! !
|
196 |
+
|
197 |
+
|
198 |
+
def get_same_padding_conv2d(image_size=None):
    """Pick the 'same'-padding convolution class to use.

    Args:
        image_size (int or tuple): Size of the image.

    Returns:
        ``Conv2dDynamicSamePadding`` when ``image_size`` is ``None``;
        otherwise ``Conv2dStaticSamePadding`` with ``image_size`` already
        bound through ``functools.partial``. The original docstring notes
        that static padding is necessary for ONNX export.
    """
    if image_size is not None:
        return partial(Conv2dStaticSamePadding, image_size=image_size)
    return Conv2dDynamicSamePadding
|
212 |
+
|
213 |
+
|
214 |
+
class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style 'SAME' padding for any input size.

    The amount of zero padding is recomputed from the input tensor on every
    forward pass.
    """

    # How the 'SAME' padding is derived. With
    #     i: width or height
    #     s: stride
    #     k: kernel size
    #     d: dilation
    #     p: padding
    # the output of Conv2d is
    #     o = floor((i + p - ((k - 1) * d + 1)) / s + 1)
    # and requiring o == i gives
    #     p = (i - 1) * s + ((k - 1) * d + 1) - i

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        # The built-in padding is fixed to 0; padding is applied by hand in forward().
        super().__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
        )
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

    def forward(self, x):
        """Zero-pad ``x`` so the output size is ceil(input / stride), then convolve."""
        in_h, in_w = x.size()[-2:]
        ker_h, ker_w = self.weight.size()[-2:]
        step_h, step_w = self.stride
        # The target output size follows the stride, not the input size.
        out_h = math.ceil(in_h / step_h)
        out_w = math.ceil(in_w / step_w)
        pad_h = max((out_h - 1) * step_h + (ker_h - 1) * self.dilation[0] + 1 - in_h, 0)
        pad_w = max((out_w - 1) * step_w + (ker_w - 1) * self.dilation[1] + 1 - in_w, 0)
        if pad_h > 0 or pad_w > 0:
            # An odd amount of padding puts the extra pixel on the right/bottom.
            left, top = pad_w // 2, pad_h // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
|
268 |
+
|
269 |
+
|
270 |
+
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style 'SAME' padding for a known input size.

    The padding module is built once in the constructor from ``image_size``
    and then reused in every forward pass.
    """

    # The padding amount is computed exactly as in Conv2dDynamicSamePadding.

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        image_size=None,
        **kwargs
    ):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

        # Derive the padding from the image size and store it as a module.
        assert image_size is not None
        if isinstance(image_size, int):
            in_h = in_w = image_size
        else:
            in_h, in_w = image_size
        ker_h, ker_w = self.weight.size()[-2:]
        step_h, step_w = self.stride
        out_h = math.ceil(in_h / step_h)
        out_w = math.ceil(in_w / step_w)
        pad_h = max((out_h - 1) * step_h + (ker_h - 1) * self.dilation[0] + 1 - in_h, 0)
        pad_w = max((out_w - 1) * step_w + (ker_w - 1) * self.dilation[1] + 1 - in_w, 0)
        if pad_h > 0 or pad_w > 0:
            left, top = pad_w // 2, pad_h // 2
            self.static_padding = nn.ZeroPad2d((left, pad_w - left, top, pad_h - top))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        """Apply the precomputed padding, then the convolution."""
        padded = self.static_padding(x)
        return F.conv2d(
            padded,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
|
316 |
+
|
317 |
+
|
318 |
+
def get_same_padding_maxPool2d(image_size=None):
    """Pick the 'SAME'-padding max-pooling class matching the available size info.

    A known image size selects static padding (necessary for ONNX export of
    models); without one, padding is computed dynamically.

    Args:
        image_size (int or tuple): Size of the image, or None if unknown.

    Returns:
        MaxPool2dDynamicSamePadding when ``image_size`` is None, otherwise
        MaxPool2dStaticSamePadding with ``image_size`` already bound.
    """
    if image_size is not None:
        return partial(MaxPool2dStaticSamePadding, image_size=image_size)
    return MaxPool2dDynamicSamePadding
|
332 |
+
|
333 |
+
|
334 |
+
class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
    """2D max pooling with TensorFlow-style 'SAME' padding for any input size.

    The amount of zero padding is recomputed from the input tensor on every
    forward pass.
    """

    def __init__(
        self,
        kernel_size,
        stride,
        padding=0,
        dilation=1,
        return_indices=False,
        ceil_mode=False,
    ):
        super().__init__(
            kernel_size, stride, padding, dilation, return_indices, ceil_mode
        )
        # Expand scalar settings to [h, w] pairs so forward() can unpack them.
        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2
        if isinstance(self.dilation, int):
            self.dilation = [self.dilation] * 2

    def forward(self, x):
        """Zero-pad ``x`` so the output size is ceil(input / stride), then pool."""
        in_h, in_w = x.size()[-2:]
        ker_h, ker_w = self.kernel_size
        step_h, step_w = self.stride
        out_h = math.ceil(in_h / step_h)
        out_w = math.ceil(in_w / step_w)
        pad_h = max((out_h - 1) * step_h + (ker_h - 1) * self.dilation[0] + 1 - in_h, 0)
        pad_w = max((out_w - 1) * step_w + (ker_w - 1) * self.dilation[1] + 1 - in_w, 0)
        if pad_h > 0 or pad_w > 0:
            # An odd amount of padding puts the extra pixel on the right/bottom.
            left, top = pad_w // 2, pad_h // 2
            x = F.pad(x, [left, pad_w - left, top, pad_h - top])
        return F.max_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )
|
381 |
+
|
382 |
+
|
383 |
+
class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """2D max pooling with TensorFlow-style 'SAME' padding for a known input size.

    The padding module is built once in the constructor from ``image_size``
    and then reused in every forward pass.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        # Expand scalar settings to [h, w] pairs so they can be unpacked below.
        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2
        if isinstance(self.dilation, int):
            self.dilation = [self.dilation] * 2

        # Derive the padding from the image size and store it as a module.
        assert image_size is not None
        if isinstance(image_size, int):
            in_h = in_w = image_size
        else:
            in_h, in_w = image_size
        ker_h, ker_w = self.kernel_size
        step_h, step_w = self.stride
        out_h = math.ceil(in_h / step_h)
        out_w = math.ceil(in_w / step_w)
        pad_h = max((out_h - 1) * step_h + (ker_h - 1) * self.dilation[0] + 1 - in_h, 0)
        pad_w = max((out_w - 1) * step_w + (ker_w - 1) * self.dilation[1] + 1 - in_w, 0)
        if pad_h > 0 or pad_w > 0:
            left, top = pad_w // 2, pad_h // 2
            self.static_padding = nn.ZeroPad2d((left, pad_w - left, top, pad_h - top))
        else:
            self.static_padding = nn.Identity()

    def forward(self, x):
        """Apply the precomputed padding, then the max pooling."""
        padded = self.static_padding(x)
        return F.max_pool2d(
            padded,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )
|
427 |
+
|
428 |
+
|
429 |
+
class BlockDecoder(object):
    """Converts between the compact string notation of a block and BlockArgs.

    The notation comes straight from the official TensorFlow repository,
    e.g. ``'r1_k3_s11_e1_i32_o16_se0.25_noskip'``.
    """

    @staticmethod
    def _decode_block_string(block_string):
        """Get a block through a string notation of arguments.

        Args:
            block_string (str): A string notation of arguments.
                Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'.

        Returns:
            BlockArgs: The namedtuple defined at the top of this file.

        Raises:
            AssertionError: If ``block_string`` is not a string, or the stride
                option is missing or is not one digit / two equal digits.
        """
        assert isinstance(block_string, str)

        options = {}
        for op in block_string.split("_"):
            # Split each option into its alphabetic key and the value that
            # starts at the first digit; options without digits are skipped.
            splits = re.split(r"(\d.*)", op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        # Check stride: it must be present and either a single digit or two
        # identical digits. A missing stride fails the assertion instead of
        # raising KeyError.
        stride = options.get("s", "")
        assert len(stride) == 1 or (len(stride) == 2 and stride[0] == stride[1])

        return BlockArgs(
            num_repeat=int(options["r"]),
            kernel_size=int(options["k"]),
            stride=[int(options["s"][0])],
            expand_ratio=int(options["e"]),
            input_filters=int(options["i"]),
            output_filters=int(options["o"]),
            se_ratio=float(options["se"]) if "se" in options else None,
            id_skip=("noskip" not in block_string),
        )

    @staticmethod
    def _encode_block_string(block):
        """Encode a block to a string.

        Args:
            block (namedtuple): A BlockArgs type argument. ``stride`` may be an
                int or a sequence of one or two ints; ``se_ratio`` may be None.

        Returns:
            block_string: A String form of BlockArgs.
        """
        # The field is called ``stride``; ``strides`` is only a fallback for
        # objects that still expose the name this method used to read.
        stride = getattr(block, "stride", None)
        if stride is None:
            stride = block.strides
        if isinstance(stride, int):
            stride = [stride]
        # A one-element stride (as produced by decoding) applies to both axes.
        stride_h, stride_w = stride[0], stride[-1]

        args = [
            "r%d" % block.num_repeat,
            "k%d" % block.kernel_size,
            "s%d%d" % (stride_h, stride_w),
            "e%s" % block.expand_ratio,
            "i%d" % block.input_filters,
            "o%d" % block.output_filters,
        ]
        # se_ratio is None when the block string had no 'se' option.
        se_ratio = block.se_ratio
        if se_ratio is not None and 0 < se_ratio <= 1:
            args.append("se%s" % se_ratio)
        if block.id_skip is False:
            args.append("noskip")
        return "_".join(args)

    @staticmethod
    def decode(string_list):
        """Decode a list of string notations to specify blocks inside the network.

        Args:
            string_list (list[str]): A list of strings, each string is a notation of block.

        Returns:
            blocks_args: A list of BlockArgs namedtuples of block args.
        """
        assert isinstance(string_list, list)
        return [
            BlockDecoder._decode_block_string(block_string)
            for block_string in string_list
        ]

    @staticmethod
    def encode(blocks_args):
        """Encode a list of BlockArgs to a list of strings.

        Args:
            blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args.

        Returns:
            block_strings: A list of strings, each string is a notation of block.
        """
        return [BlockDecoder._encode_block_string(block) for block in blocks_args]
|
525 |
+
|
526 |
+
|
527 |
+
def create_block_args(
    width_coefficient=None,
    depth_coefficient=None,
    image_size=None,
    dropout_rate=0.2,
    drop_connect_rate=0.2,
    num_classes=1000,
    include_top=True,
):
    """Build the per-stage BlockArgs and the GlobalParams of an EfficientNet.

    Args:
        width_coefficient (float)
        depth_coefficient (float)
        image_size (int)
        dropout_rate (float)
        drop_connect_rate (float)
        num_classes (int)
        include_top (bool)

        Each value is stored unchanged in the returned GlobalParams.

    Returns:
        blocks_args, global_params.
    """
    # Stage layout of the whole model (efficientnet-b0 by default). It is
    # adjusted per model during the construction of the EfficientNet class.
    stage_strings = [
        "r1_k3_s11_e1_i32_o16_se0.25",
        "r2_k3_s22_e6_i16_o24_se0.25",
        "r2_k5_s22_e6_i24_o40_se0.25",
        "r3_k3_s22_e6_i40_o80_se0.25",
        "r3_k5_s11_e6_i80_o112_se0.25",
        "r4_k5_s22_e6_i112_o192_se0.25",
        "r1_k3_s11_e6_i192_o320_se0.25",
    ]
    decoded_stages = BlockDecoder.decode(stage_strings)

    # Everything not exposed as a parameter is fixed here.
    settings = GlobalParams(
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        image_size=image_size,
        dropout_rate=dropout_rate,
        num_classes=num_classes,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=drop_connect_rate,
        depth_divisor=8,
        min_depth=None,
        include_top=include_top,
    )

    return decoded_stages, settings
|
carvekit/ml/arch/tracerb7/efficientnet.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source url: https://github.com/lukemelas/EfficientNet-PyTorch
|
3 |
+
Modified by Min Seok Lee, Wooseok Shin, Nikita Selin
|
4 |
+
License: Apache License 2.0
|
5 |
+
Changes:
|
6 |
+
- Added support for extracting edge features
|
7 |
+
- Added support for extracting object features at different levels
|
8 |
+
- Refactored the code
|
9 |
+
"""
|
10 |
+
from typing import Any, List
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn import functional as F
|
15 |
+
|
16 |
+
from carvekit.ml.arch.tracerb7.effi_utils import (
|
17 |
+
get_same_padding_conv2d,
|
18 |
+
calculate_output_image_size,
|
19 |
+
MemoryEfficientSwish,
|
20 |
+
drop_connect,
|
21 |
+
round_filters,
|
22 |
+
round_repeats,
|
23 |
+
Swish,
|
24 |
+
create_block_args,
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Args:
        block_args (namedtuple): BlockArgs describing this block.
        global_params (namedtuple): GlobalParams shared by the network.
        image_size (tuple or list): [image_height, image_width].

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # PyTorch's batch-norm momentum is the complement of TensorFlow's.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        se_ratio = self._block_args.se_ratio
        self.has_se = (se_ratio is not None) and (0 < se_ratio <= 1)
        # Whether to use the skip connection (and drop connect).
        self.id_skip = block_args.id_skip

        in_filters = self._block_args.input_filters
        expanded_filters = in_filters * self._block_args.expand_ratio

        # Expansion phase (inverted bottleneck), skipped when the ratio is 1.
        if self._block_args.expand_ratio != 1:
            expand_conv_cls = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = expand_conv_cls(
                in_channels=in_filters,
                out_channels=expanded_filters,
                kernel_size=1,
                bias=False,
            )
            self._bn0 = nn.BatchNorm2d(
                num_features=expanded_filters, momentum=self._bn_mom, eps=self._bn_eps
            )
            # image_size is not recomputed here: a stride of 1 would leave it as is.

        # Depthwise convolution phase; groups == channels makes it depthwise.
        stride = self._block_args.stride
        depthwise_conv_cls = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = depthwise_conv_cls(
            in_channels=expanded_filters,
            out_channels=expanded_filters,
            groups=expanded_filters,
            kernel_size=self._block_args.kernel_size,
            stride=stride,
            bias=False,
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=expanded_filters, momentum=self._bn_mom, eps=self._bn_eps
        )
        image_size = calculate_output_image_size(image_size, stride)

        # Squeeze-and-excitation layers, if requested; they act on a 1x1 map.
        if self.has_se:
            se_conv_cls = get_same_padding_conv2d(image_size=(1, 1))
            squeezed_filters = max(
                1, int(self._block_args.input_filters * self._block_args.se_ratio)
            )
            self._se_reduce = se_conv_cls(
                in_channels=expanded_filters,
                out_channels=squeezed_filters,
                kernel_size=1,
            )
            self._se_expand = se_conv_cls(
                in_channels=squeezed_filters,
                out_channels=expanded_filters,
                kernel_size=1,
            )

        # Pointwise (projection) convolution phase.
        out_filters = self._block_args.output_filters
        project_conv_cls = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = project_conv_cls(
            in_channels=expanded_filters,
            out_channels=out_filters,
            kernel_size=1,
            bias=False,
        )
        self._bn2 = nn.BatchNorm2d(
            num_features=out_filters, momentum=self._bn_mom, eps=self._bn_eps
        )
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float): Drop connect rate (between 0 and 1);
                a falsy value disables drop connect.

        Returns:
            Output of this block after processing.
        """
        args = self._block_args

        # Expansion and depthwise convolution.
        x = inputs
        if args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze and excitation: gate the channels with a pooled descriptor.
        if self.has_se:
            gate = F.adaptive_avg_pool2d(x, 1)
            gate = self._swish(self._se_reduce(gate))
            gate = self._se_expand(gate)
            x = torch.sigmoid(gate) * x

        # Pointwise convolution.
        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect; together they give stochastic depth.
        shapes_match = args.stride == 1 and args.input_filters == args.output_filters
        if self.id_skip and shapes_match:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            self._swish = Swish()
|
169 |
+
|
170 |
+
|
171 |
+
class EfficientNet(nn.Module):
    """EfficientNet backbone made of a stem convolution and a list of MBConv blocks.

    The constructor builds only the stem and the blocks; it does not create a
    head (``_conv_head`` / ``_bn1``).

    Args:
        blocks_args (list[namedtuple]): BlockArgs, one entry per stage.
        global_params (namedtuple): GlobalParams shared by all blocks.
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), "blocks_args should be a list"
        assert len(blocks_args) > 0, "block args must be greater than 0"
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters (PyTorch momentum is 1 - TensorFlow momentum).
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size.
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:

            # Scale the stage's filters and repeat count with the global parameters.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(
                MBConvBlock(block_args, self._global_params, image_size=image_size)
            )
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            for _ in range(block_args.num_repeat - 1):
                # Repeated blocks use stride 1, so image_size stays unchanged.
                self._blocks.append(
                    MBConvBlock(block_args, self._global_params, image_size=image_size)
                )

        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Collect the feature maps seen right before each spatial reduction.

        Args:
            inputs (tensor): Input tensor.

        Returns:
            dict: Maps ``'reduction_<i>'`` to tensors. An entry is stored each
            time a block shrinks the height of its input. The last entry is
            the head output when the instance has ``_conv_head`` and ``_bn1``,
            and the output of the final block otherwise.
        """
        endpoints = dict()

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                # scale drop connect_rate
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
            prev_x = x

        # Head: this class never creates one, so apply it only when it has
        # been attached; unconditional access raised AttributeError.
        conv_head = getattr(self, "_conv_head", None)
        bn_head = getattr(self, "_bn1", None)
        if conv_head is not None and bn_head is not None:
            x = self._swish(bn_head(conv_head(x)))
        endpoints["reduction_{}".format(len(endpoints) + 1)] = x

        return endpoints

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )
|
280 |
+
|
281 |
+
|
282 |
+
class EfficientEncoderB7(EfficientNet):
    """EfficientNet encoder built with width 2.0, depth 3.1 and image size 600.

    ``forward`` returns copies of the feature maps produced by the blocks
    whose indices are listed in ``block_idx``.
    """

    def __init__(self):
        blocks_args, global_params = create_block_args(
            width_coefficient=2.0,
            depth_coefficient=3.1,
            dropout_rate=0.5,
            image_size=600,
        )
        super().__init__(blocks_args, global_params)
        self._change_in_channels(3)
        # Indices of the blocks whose outputs are returned, and the channel
        # counts recorded alongside them.
        self.block_idx = [10, 17, 37, 54]
        self.channels = [48, 80, 224, 640]

    def initial_conv(self, inputs):
        """Run the stem: convolution, batch norm and swish."""
        stem = self._conv_stem(inputs)
        stem = self._bn0(stem)
        return self._swish(stem)

    def get_blocks(self, x, H, W, block_idx):
        """Run all blocks and collect the outputs at the four given indices.

        Args:
            x (tensor): Output of the stem.
            H: Not used by this method.
            W: Not used by this method.
            block_idx (list[int]): Four block indices whose outputs are kept.

        Returns:
            list: Cloned block outputs, in the order the blocks are executed.
        """
        features = []
        num_blocks = len(self._blocks)
        for idx, block in enumerate(self._blocks):
            rate = self._global_params.drop_connect_rate
            if rate:
                # Scale the drop connect rate with the depth of the block.
                rate *= float(idx) / num_blocks
            x = block(x, drop_connect_rate=rate)
            for slot in range(4):
                if idx == block_idx[slot]:
                    features.append(x.clone())

        return features

    def forward(self, inputs: torch.Tensor) -> List[Any]:
        """Return the backbone features of a 4-D input tensor."""
        _, _, H, W = inputs.size()
        x = self.initial_conv(inputs)  # Prepare input for the backbone
        return self.get_blocks(x, H, W, block_idx=self.block_idx)
|
carvekit/ml/arch/tracerb7/tracer.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Source url: https://github.com/Karel911/TRACER
|
3 |
+
Author: Min Seok Lee and Wooseok Shin
|
4 |
+
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
5 |
+
License: Apache License 2.0
|
6 |
+
Changes:
|
7 |
+
- Refactored code
|
8 |
+
- Removed unused code
|
9 |
+
- Added comments
|
10 |
+
"""
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from typing import List, Optional, Tuple
|
16 |
+
|
17 |
+
from torch import Tensor
|
18 |
+
|
19 |
+
from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7
|
20 |
+
from carvekit.ml.arch.tracerb7.att_modules import (
|
21 |
+
RFB_Block,
|
22 |
+
aggregation,
|
23 |
+
ObjectAttention,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
class TracerDecoder(nn.Module):
    """Tracer Decoder"""

    def __init__(
        self,
        encoder: EfficientEncoderB7,
        features_channels: Optional[List[int]] = None,
        rfb_channel: Optional[List[int]] = None,
    ):
        """
        Initialize the tracer decoder.

        Args:
            encoder: The encoder to use.
            features_channels: The channels of the backbone features at different stages. default: [48, 80, 224, 640]
            rfb_channel: The channels of the RFB features. default: [32, 64, 128]
        """
        super().__init__()
        rfb_widths = [32, 64, 128] if rfb_channel is None else rfb_channel
        if features_channels is None:
            features_channels = [48, 80, 224, 640]
        self.encoder = encoder
        self.features_channels = features_channels
        backbone_widths = self.features_channels

        # Receptive Field Blocks for the three deepest backbone stages
        self.rfb2 = RFB_Block(backbone_widths[1], rfb_widths[0])
        self.rfb3 = RFB_Block(backbone_widths[2], rfb_widths[1])
        self.rfb4 = RFB_Block(backbone_widths[3], rfb_widths[2])

        # Multi-level aggregation of the RFB outputs
        self.agg = aggregation(rfb_widths)

        # Object Attention on the two shallowest backbone stages
        self.ObjectAttention2 = ObjectAttention(
            channel=backbone_widths[1], kernel_size=3
        )
        self.ObjectAttention1 = ObjectAttention(
            channel=backbone_widths[0], kernel_size=3
        )

    def forward(self, inputs: torch.Tensor) -> Tensor:
        """
        Forward pass of the tracer decoder.

        Args:
            inputs: Preprocessed images.

        Returns:
            Sigmoid of the average of the three upsampled decoder maps.
        """
        stage1, stage2, stage3, stage4 = (self.encoder(inputs)[i] for i in range(4))
        rfb_low = self.rfb2(stage2)
        rfb_mid = self.rfb3(stage3)
        rfb_high = self.rfb4(stage4)

        coarse = self.agg(rfb_high, rfb_mid, rfb_low)
        coarse_up = F.interpolate(coarse, scale_factor=8, mode="bilinear")

        refined = self.ObjectAttention2(coarse, stage2)
        refined_up = F.interpolate(refined, scale_factor=8, mode="bilinear")

        # The refined map is doubled in size before attending to the first stage.
        refined_x2 = F.interpolate(refined, scale_factor=2, mode="bilinear")
        fine = self.ObjectAttention1(refined_x2, stage1)
        fine_up = F.interpolate(fine, scale_factor=4, mode="bilinear")

        averaged = (fine_up + refined_up + coarse_up) / 3

        return torch.sigmoid(averaged)
|
carvekit/ml/arch/u2net/__init__.py
ADDED
File without changes
|
carvekit/ml/arch/u2net/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (184 Bytes). View file
|
|
carvekit/ml/arch/u2net/__pycache__/u2net.cpython-38.pyc
ADDED
Binary file (6.13 kB). View file
|
|
carvekit/ml/arch/u2net/u2net.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
|
3 |
+
Source url: https://github.com/xuebinqin/U-2-Net
|
4 |
+
License: Apache License 2.0
|
5 |
+
"""
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
import math
|
12 |
+
|
13 |
+
__all__ = ["U2NETArchitecture"]
|
14 |
+
|
15 |
+
|
16 |
+
def _upsample_like(x, size):
|
17 |
+
return nn.Upsample(size=size, mode="bilinear", align_corners=False)(x)
|
18 |
+
|
19 |
+
|
20 |
+
def _size_map(x, height):
|
21 |
+
# {height: size} for Upsample
|
22 |
+
size = list(x.shape[-2:])
|
23 |
+
sizes = {}
|
24 |
+
for h in range(1, height):
|
25 |
+
sizes[h] = size
|
26 |
+
size = [math.ceil(w / 2) for w in size]
|
27 |
+
return sizes
|
28 |
+
|
29 |
+
|
30 |
+
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalisation and an in-place ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        dilate: Dilation of the convolution; the padding is set to the same
            value.
    """

    def __init__(self, in_ch=3, out_ch=3, dilate=1):
        super().__init__()
        # Attribute names are kept as-is because they form the state_dict keys.
        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, kernel_size=3, padding=dilate, dilation=dilate
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU to ``x``, in that order."""
        out = self.conv_s1(x)
        out = self.bn_s1(out)
        return self.relu_s1(out)
|
42 |
+
|
43 |
+
|
44 |
+
class RSU(nn.Module):
    """Residual U-block: a small encoder-decoder whose output is added to
    the result of its input convolution.

    Args:
        name: Label stored on the instance as ``self.name``.
        height: Number of encoder levels; also the index of the deepest
            ``rebnconv`` layer.
        in_ch: Channels of the incoming tensor.
        mid_ch: Channels used inside the block.
        out_ch: Channels produced by the block.
        dilated: When true, the block neither pools nor upsamples and uses
            growing dilation rates instead.
    """

    def __init__(self, name, height, in_ch, mid_ch, out_ch, dilated=False):
        super().__init__()
        self.name = name
        self.height = height
        self.dilated = dilated
        self._make_layers(height, in_ch, mid_ch, out_ch, dilated)

    def forward(self, x):
        """Return ``rebnconvin(x)`` plus the encoder-decoder output for it."""
        sizes = _size_map(x, self.height)
        x = self.rebnconvin(x)
        return x + self._unet(x, 1, sizes)

    def _unet(self, feat, level, sizes):
        """Run the encoder-decoder recursively, starting at ``level``."""
        encoded = getattr(self, f"rebnconv{level}")(feat)
        if level >= self.height:
            # Deepest level: a single convolution, no decoder counterpart.
            return encoded

        # Pool before descending, except in dilated blocks and on the step
        # into the deepest level.
        pool = not self.dilated and level < self.height - 1
        deeper_in = self.downsample(encoded) if pool else encoded
        deeper = self._unet(deeper_in, level + 1, sizes)

        decoded = getattr(self, f"rebnconv{level}d")(
            torch.cat((deeper, encoded), 1)
        )
        if self.dilated or level <= 1:
            return decoded
        # Resize to the size recorded for the level above.
        return _upsample_like(decoded, sizes[level - 1])

    def _make_layers(self, height, in_ch, mid_ch, out_ch, dilated=False):
        """Register the block's sub-modules under their ``rebnconv*`` names."""
        self.add_module("rebnconvin", REBNCONV(in_ch, out_ch))
        self.add_module(
            "downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True)
        )

        self.add_module("rebnconv1", REBNCONV(out_ch, mid_ch))
        self.add_module("rebnconv1d", REBNCONV(mid_ch * 2, out_ch))

        for level in range(2, height):
            rate = 2 ** (level - 1) if dilated else 1
            self.add_module(
                f"rebnconv{level}", REBNCONV(mid_ch, mid_ch, dilate=rate)
            )
            self.add_module(
                f"rebnconv{level}d",
                REBNCONV(mid_ch * 2, mid_ch, dilate=rate),
            )

        # The deepest layer is always dilated: by 2 in the plain variant.
        bottom_rate = 2 ** (height - 1) if dilated else 2
        self.add_module(
            f"rebnconv{height}", REBNCONV(mid_ch, mid_ch, dilate=bottom_rate)
        )
|
92 |
+
|
93 |
+
|
94 |
+
class U2NETArchitecture(nn.Module):
    """U^2-Net: an encoder-decoder assembled from RSU blocks, with one side
    output per level and a 1x1 convolution that fuses them.

    Args:
        cfg_type: Either the preset name ``"full"`` or a dict of the form
            ``{stage: [name, (height, in_ch, mid_ch, out_ch[, dilated]), side]}``.
            ``side`` is the input channel count of that stage's side
            convolution; a value ``<= 0`` means no side layer. The side layer
            is registered as ``side<last character of name>``.
        out_ch: Number of channels of every output map.

    Raises:
        ValueError: If ``cfg_type`` is an unknown preset name, or is neither
            a ``str`` nor a ``dict``.
    """

    def __init__(self, cfg_type: Union[dict, str] = "full", out_ch: int = 1):
        super().__init__()
        if isinstance(cfg_type, str):
            if cfg_type != "full":
                raise ValueError("Unknown U^2-Net architecture conf. name")
            layers_cfgs = {
                # cfgs for building RSUs and sides
                # {stage : [name, (height(L), in_ch, mid_ch, out_ch, dilated), side]}
                "stage1": ["En_1", (7, 3, 32, 64), -1],
                "stage2": ["En_2", (6, 64, 32, 128), -1],
                "stage3": ["En_3", (5, 128, 64, 256), -1],
                "stage4": ["En_4", (4, 256, 128, 512), -1],
                "stage5": ["En_5", (4, 512, 256, 512, True), -1],
                "stage6": ["En_6", (4, 512, 256, 512, True), 512],
                "stage5d": ["De_5", (4, 1024, 256, 512, True), 512],
                "stage4d": ["De_4", (4, 1024, 128, 256), 256],
                "stage3d": ["De_3", (5, 512, 64, 128), 128],
                "stage2d": ["De_2", (6, 256, 32, 64), 64],
                "stage1d": ["De_1", (7, 128, 16, 64), 64],
            }
        elif isinstance(cfg_type, dict):
            layers_cfgs = cfg_type
        else:
            raise ValueError("Unknown U^2-Net architecture conf. type")
        self.out_ch = out_ch
        self._make_layers(layers_cfgs)

    def forward(self, x):
        """Compute the fused map and all side maps for ``x``.

        Returns:
            A list of sigmoid-activated maps: the fused map first, then the
            side maps ordered from level 1 to the deepest level.
        """
        sizes = _size_map(x, self.height)
        maps = []  # side outputs, collected deepest level first

        def side(x, h):
            # side output saliency map (before sigmoid), resized to input size
            x = getattr(self, f"side{h}")(x)
            maps.append(_upsample_like(x, sizes[1]))

        def unet(x, height=1):
            # The depth comes from the config (self.height) rather than a
            # hard-coded 6, so dict configs with another depth also work.
            if height < self.height:
                x1 = getattr(self, f"stage{height}")(x)
                x2 = unet(self.downsample(x1), height + 1)
                x = getattr(self, f"stage{height}d")(torch.cat((x2, x1), 1))
                side(x, height)
                return _upsample_like(x, sizes[height - 1]) if height > 1 else x
            x = getattr(self, f"stage{height}")(x)
            side(x, height)
            return _upsample_like(x, sizes[height - 1])

        unet(x)

        # fuse saliency probability maps
        maps.reverse()
        fused = self.outconv(torch.cat(maps, 1))
        return [torch.sigmoid(m) for m in [fused] + maps]

    def _make_layers(self, cfgs):
        """Register the pooling layer, RSU stages, side layers and fuse layer."""
        # An N-level network has N encoder and N - 1 decoder stages.
        self.height = int((len(cfgs) + 1) / 2)
        self.add_module("downsample", nn.MaxPool2d(2, stride=2, ceil_mode=True))
        for k, v in cfgs.items():
            # build rsu block
            self.add_module(k, RSU(v[0], *v[1]))
            if v[2] > 0:
                # build side layer
                self.add_module(
                    f"side{v[0][-1]}", nn.Conv2d(v[2], self.out_ch, 3, padding=1)
                )
        # build fuse layer
        self.add_module(
            "outconv", nn.Conv2d(int(self.height * self.out_ch), self.out_ch, 1)
        )
|
carvekit/ml/files/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
# Per-user cache root for carvekit, located under the home directory.
carvekit_dir = Path.home() / ".cache/carvekit"

# Created at import time, parents included; no error if it already exists.
carvekit_dir.mkdir(parents=True, exist_ok=True)

# Path for model checkpoints; this module does not create it.
checkpoints_dir = carvekit_dir / "checkpoints"
|
carvekit/ml/files/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (365 Bytes). View file
|
|