Spaces:
Runtime error
Runtime error
add: files.
Browse files- 1.png +0 -0
- 111.png +0 -0
- 748.png +0 -0
- a4541-DSC_0040-2.png +0 -0
- app.py +35 -74
- create_maxim_model.py +37 -0
- maxim/__init__.py +0 -0
- maxim/blocks/__init__.py +0 -0
- maxim/blocks/attentions.py +143 -0
- maxim/blocks/block_gating.py +67 -0
- maxim/blocks/bottleneck.py +54 -0
- maxim/blocks/grid_gating.py +68 -0
- maxim/blocks/misc_gating.py +213 -0
- maxim/blocks/others.py +56 -0
- maxim/blocks/unet.py +133 -0
- maxim/configs.py +80 -0
- maxim/layers.py +101 -0
- maxim/maxim.py +320 -0
1.png
ADDED
111.png
ADDED
748.png
ADDED
a4541-DSC_0040-2.png
ADDED
app.py
CHANGED
@@ -1,13 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from huggingface_hub.keras_mixin import from_pretrained_keras
|
2 |
-
|
3 |
from PIL import Image
|
4 |
|
5 |
-
import numpy as np
|
6 |
-
|
7 |
from create_maxim_model import Model
|
8 |
from maxim.configs import MAXIM_CONFIGS
|
9 |
|
10 |
-
|
11 |
_MODEL = from_pretrained_keras("sayakpaul/S-2_enhancement_lol")
|
12 |
|
13 |
|
@@ -23,63 +26,6 @@ def mod_padding_symmetric(image, factor=64):
|
|
23 |
image, [(padh // 2, padh // 2), (padw // 2, padw // 2), (0, 0)], mode="REFLECT"
|
24 |
)
|
25 |
return image
|
26 |
-
|
27 |
-
def _convert_input_type_range(img):
|
28 |
-
"""Convert the type and range of the input image.
|
29 |
-
|
30 |
-
It converts the input image to np.float32 type and range of [0, 1].
|
31 |
-
It is mainly used for pre-processing the input image in colorspace
|
32 |
-
convertion functions such as rgb2ycbcr and ycbcr2rgb.
|
33 |
-
Args:
|
34 |
-
img (ndarray): The input image. It accepts:
|
35 |
-
1. np.uint8 type with range [0, 255];
|
36 |
-
2. np.float32 type with range [0, 1].
|
37 |
-
Returns:
|
38 |
-
(ndarray): The converted image with type of np.float32 and range of
|
39 |
-
[0, 1].
|
40 |
-
"""
|
41 |
-
img_type = img.dtype
|
42 |
-
img = img.astype(np.float32)
|
43 |
-
if img_type == np.float32:
|
44 |
-
pass
|
45 |
-
elif img_type == np.uint8:
|
46 |
-
img /= 255.0
|
47 |
-
else:
|
48 |
-
raise TypeError(
|
49 |
-
"The img type should be np.float32 or np.uint8, " f"but got {img_type}"
|
50 |
-
)
|
51 |
-
return img
|
52 |
-
|
53 |
-
|
54 |
-
def _convert_output_type_range(img, dst_type):
|
55 |
-
"""Convert the type and range of the image according to dst_type.
|
56 |
-
|
57 |
-
It converts the image to desired type and range. If `dst_type` is np.uint8,
|
58 |
-
images will be converted to np.uint8 type with range [0, 255]. If
|
59 |
-
`dst_type` is np.float32, it converts the image to np.float32 type with
|
60 |
-
range [0, 1].
|
61 |
-
It is mainly used for post-processing images in colorspace convertion
|
62 |
-
functions such as rgb2ycbcr and ycbcr2rgb.
|
63 |
-
Args:
|
64 |
-
img (ndarray): The image to be converted with np.float32 type and
|
65 |
-
range [0, 255].
|
66 |
-
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
|
67 |
-
converts the image to np.uint8 type with range [0, 255]. If
|
68 |
-
dst_type is np.float32, it converts the image to np.float32 type
|
69 |
-
with range [0, 1].
|
70 |
-
Returns:
|
71 |
-
(ndarray): The converted image with desired type and range.
|
72 |
-
"""
|
73 |
-
if dst_type not in (np.uint8, np.float32):
|
74 |
-
raise TypeError(
|
75 |
-
"The dst_type should be np.float32 or np.uint8, " f"but got {dst_type}"
|
76 |
-
)
|
77 |
-
if dst_type == np.uint8:
|
78 |
-
img = img.round()
|
79 |
-
else:
|
80 |
-
img /= 255.0
|
81 |
-
|
82 |
-
return img.astype(dst_type)
|
83 |
|
84 |
|
85 |
def make_shape_even(image):
|
@@ -89,8 +35,8 @@ def make_shape_even(image):
|
|
89 |
padw = 1 if width % 2 != 0 else 0
|
90 |
image = tf.pad(image, [(0, padh), (0, padw), (0, 0)], mode="REFLECT")
|
91 |
return image
|
92 |
-
|
93 |
-
|
94 |
def process_image(image: Image):
|
95 |
input_img = np.asarray(image) / 255.0
|
96 |
height, width = input_img.shape[0], input_img.shape[1]
|
@@ -102,9 +48,9 @@ def process_image(image: Image):
|
|
102 |
# padding images to be multiples of 64
|
103 |
input_img = mod_padding_symmetric(input_img, factor=64)
|
104 |
input_img = tf.expand_dims(input_img, axis=0)
|
105 |
-
return input_img, height_even, width_even
|
106 |
-
|
107 |
-
|
108 |
def init_new_model(input_img):
|
109 |
configs = MAXIM_CONFIGS.get("S-2")
|
110 |
configs.update(
|
@@ -120,25 +66,40 @@ def init_new_model(input_img):
|
|
120 |
new_model = Model(**configs)
|
121 |
new_model.set_weights(_MODEL.get_weights())
|
122 |
return new_model
|
123 |
-
|
124 |
-
|
125 |
def infer(image):
|
126 |
-
preprocessed_image, height_even, width_even = process_image(image)
|
127 |
new_model = init_new_model(preprocessed_image)
|
128 |
-
|
129 |
preds = new_model.predict(preprocessed_image)
|
130 |
if isinstance(preds, list):
|
131 |
preds = preds[-1]
|
132 |
if isinstance(preds, list):
|
133 |
preds = preds[-1]
|
134 |
-
|
135 |
preds = np.array(preds[0], np.float32)
|
136 |
-
|
137 |
new_height, new_width = preds.shape[0], preds.shape[1]
|
138 |
h_start = new_height // 2 - height_even // 2
|
139 |
h_end = h_start + height
|
140 |
w_start = new_width // 2 - width_even // 2
|
141 |
w_end = w_start + width
|
142 |
preds = preds[h_start:h_end, w_start:w_end, :]
|
143 |
-
|
144 |
-
return Image.fromarray(np.array((np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Some preprocessing utilities have been taken from:
|
3 |
+
https://github.com/google-research/maxim/blob/main/maxim/run_eval.py
|
4 |
+
"""
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import tensorflow as tf
|
8 |
from huggingface_hub.keras_mixin import from_pretrained_keras
|
|
|
9 |
from PIL import Image
|
10 |
|
|
|
|
|
11 |
from create_maxim_model import Model
|
12 |
from maxim.configs import MAXIM_CONFIGS
|
13 |
|
|
|
14 |
_MODEL = from_pretrained_keras("sayakpaul/S-2_enhancement_lol")
|
15 |
|
16 |
|
|
|
26 |
image, [(padh // 2, padh // 2), (padw // 2, padw // 2), (0, 0)], mode="REFLECT"
|
27 |
)
|
28 |
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
|
31 |
def make_shape_even(image):
|
|
|
35 |
padw = 1 if width % 2 != 0 else 0
|
36 |
image = tf.pad(image, [(0, padh), (0, padw), (0, 0)], mode="REFLECT")
|
37 |
return image
|
38 |
+
|
39 |
+
|
40 |
def process_image(image: Image):
|
41 |
input_img = np.asarray(image) / 255.0
|
42 |
height, width = input_img.shape[0], input_img.shape[1]
|
|
|
48 |
# padding images to be multiples of 64
|
49 |
input_img = mod_padding_symmetric(input_img, factor=64)
|
50 |
input_img = tf.expand_dims(input_img, axis=0)
|
51 |
+
return input_img, height, width, height_even, width_even
|
52 |
+
|
53 |
+
|
54 |
def init_new_model(input_img):
|
55 |
configs = MAXIM_CONFIGS.get("S-2")
|
56 |
configs.update(
|
|
|
66 |
new_model = Model(**configs)
|
67 |
new_model.set_weights(_MODEL.get_weights())
|
68 |
return new_model
|
69 |
+
|
70 |
+
|
71 |
def infer(image):
    """Run the enhancement model on one image and return the restored result.

    The image is preprocessed with ``process_image`` (which pads it), a model
    is built for the padded batch with ``init_new_model``, and the prediction
    is cropped back to the original ``height`` x ``width`` before being
    converted to an 8-bit ``PIL.Image``.

    Args:
        image: The input picture, passed straight to ``process_image``.
            NOTE(review): Gradio's ``"image"`` input presumably delivers a
            numpy array here -- confirm.

    Returns:
        A ``PIL.Image`` built from the prediction clipped to [0, 1] and
        scaled to uint8.
    """
    batch, height, width, height_even, width_even = process_image(image)
    model = init_new_model(batch)

    outputs = model.predict(batch)
    # The prediction may be a list of lists; keep the last entry of up to two
    # nesting levels.
    for _ in range(2):
        if isinstance(outputs, list):
            outputs = outputs[-1]

    # Drop the batch axis.
    restored = np.array(outputs[0], np.float32)

    # Undo the padding: locate the window inside the padded prediction and
    # keep exactly `height` x `width` pixels from it.
    padded_height, padded_width = restored.shape[0], restored.shape[1]
    top = padded_height // 2 - height_even // 2
    left = padded_width // 2 - width_even // 2
    bottom = top + height
    right = left + width
    restored = restored[top:bottom, left:right, :]

    pixels = (np.clip(restored, 0.0, 1.0) * 255.0).astype(np.uint8)
    return Image.fromarray(np.array(pixels))
|
91 |
+
|
92 |
+
|
93 |
+
# Text shown in the Gradio UI.
title = "Enhance low-light images."
article = "Model based on [this](https://huggingface.co/sayakpaul/S-2_enhancement_lol)."

# One image in, one image out; each example row holds a single image path.
iface = gr.Interface(
    fn=infer,
    inputs="image",
    outputs="image",
    title=title,
    article=article,
    allow_flagging="never",
    examples=[
        [path]
        for path in ("1.png", "111.png", "748.png", "a4541-DSC_0040-2.png")
    ],
)
iface.launch(debug=True)
|
create_maxim_model.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tensorflow import keras
|
2 |
+
|
3 |
+
from maxim import maxim
|
4 |
+
from maxim.configs import MAXIM_CONFIGS
|
5 |
+
|
6 |
+
|
7 |
+
def Model(variant=None, input_resolution=(256, 256), **kw) -> keras.Model:
    """Factory function to easily create a Model variant like "S".

    Args:
        variant: UNet model variants. Options: 'S-1' | 'S-2' | 'S-3'
            | 'M-1' | 'M-2' | 'M-3'. When given, the entries of
            ``MAXIM_CONFIGS[variant]`` are used as defaults for ``**kw``;
            explicitly passed keyword arguments take precedence.
        input_resolution: Size ``(height, width)`` of the input images. The
            model input always has 3 channels.
        **kw: Other UNet config entries, forwarded to ``maxim.MAXIM`` after
            the bookkeeping keys ``variant``, ``input_resolution`` and
            ``name`` have been removed.

    Returns:
        The MAXIM model as a functional ``keras.Model`` named
        ``f"{name}_model"``. ``name`` falls back to ``"maxim"`` when neither
        the caller nor the variant config supplies one.

    Raises:
        KeyError: If ``variant`` is not a key of ``MAXIM_CONFIGS``.
    """
    if variant is not None:
        for key, value in MAXIM_CONFIGS[variant].items():
            kw.setdefault(key, value)

    # These keys can enter `kw` through the config defaults above; they are
    # stripped so they are not forwarded to `maxim.MAXIM`.
    kw.pop("variant", None)
    kw.pop("input_resolution", None)
    # Previously a missing "name" raised a bare KeyError; use a default so the
    # factory also works without a variant config.
    model_name = kw.pop("name", "maxim")

    maxim_model = maxim.MAXIM(**kw)

    inputs = keras.Input((*input_resolution, 3))
    outputs = maxim_model(inputs)
    final_model = keras.Model(inputs, outputs, name=f"{model_name}_model")

    return final_model
|
maxim/__init__.py
ADDED
File without changes
|
maxim/blocks/__init__.py
ADDED
File without changes
|
maxim/blocks/attentions.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import layers
|
5 |
+
|
6 |
+
from .others import MlpBlock
|
7 |
+
|
8 |
+
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
|
9 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
10 |
+
|
11 |
+
|
12 |
+
def CALayer(
    num_channels: int,
    reduction: int = 4,
    use_bias: bool = True,
    name: str = "channel_attention",
):
    """Build a squeeze-and-excitation style channel attention function.

    ref: https://arxiv.org/abs/1709.01507

    Args:
        num_channels: Number of filters of the excitation convolution.
        reduction: The squeeze convolution uses ``num_channels // reduction``
            filters.
        use_bias: Whether both 1x1 convolutions use a bias.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` that returns ``x`` scaled by the attention
        gate.
    """

    def apply(x):
        # 2D global average pooling; keepdims keeps the pooled axes so the
        # gate can be multiplied back onto `x`.
        pooled = layers.GlobalAvgPool2D(keepdims=True)(x)

        # Squeeze: reduce the channel count.
        squeezed = Conv1x1(
            filters=num_channels // reduction,
            use_bias=use_bias,
            name=f"{name}_Conv_0",
        )(pooled)
        squeezed = tf.nn.relu(squeezed)

        # Excitation: expand to `num_channels` and squash with a sigmoid.
        gate = Conv1x1(
            filters=num_channels,
            use_bias=use_bias,
            name=f"{name}_Conv_1",
        )(squeezed)
        gate = tf.nn.sigmoid(gate)

        return x * gate

    return apply
|
37 |
+
|
38 |
+
|
39 |
+
def RCAB(
    num_channels: int,
    reduction: int = 4,
    lrelu_slope: float = 0.2,
    use_bias: bool = True,
    name: str = "residual_ca",
):
    """Build a residual channel attention block.

    The block applies, in order: layer norm, 3x3 conv, leaky ReLU, 3x3 conv,
    channel attention (``CALayer``), and finally adds the block input back.

    Args:
        num_channels: Filters of both 3x3 convolutions and of the attention.
        reduction: Channel reduction passed to ``CALayer``.
        lrelu_slope: Negative slope of the leaky ReLU.
        use_bias: Whether the convolutions use a bias.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` returning the block output.
    """

    def apply(x):
        residual = x

        out = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        out = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv1")(out)
        out = tf.nn.leaky_relu(out, alpha=lrelu_slope)
        out = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv2")(out)

        attention = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )
        out = attention(out)

        return out + residual

    return apply
|
63 |
+
|
64 |
+
|
65 |
+
def RDCAB(
    num_channels: int,
    reduction: int = 16,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "rdcab",
):
    """Build a residual dense channel attention block (used in bottlenecks).

    The block applies layer norm, an ``MlpBlock`` and a ``CALayer`` to the
    input and adds the result to the input.

    Args:
        num_channels: Hidden size of the MLP and channel count of the
            attention.
        reduction: Channel reduction passed to ``CALayer``.
        use_bias: Whether the inner layers use a bias.
        dropout_rate: Dropout rate passed to ``MlpBlock``.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` returning the block output.
    """

    def apply(x):
        branch = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)

        channel_mixing = MlpBlock(
            mlp_dim=num_channels,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_channel_mixing",
        )
        branch = channel_mixing(branch)

        channel_attention = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )
        branch = channel_attention(branch)

        # Residual connection around the whole branch.
        return x + branch

    return apply
|
92 |
+
|
93 |
+
|
94 |
+
def SAM(
    num_channels: int,
    output_channels: int = 3,
    use_bias: bool = True,
    name: str = "sam",
):
    """Build a supervised attention module for multi-stage training.

    Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet

    Args:
        num_channels: Filters of the feature and attention convolutions.
        output_channels: Channels of the restored image. Only when this is 3
            is the input image added to the predicted image.
        use_bias: Whether the convolutions use a bias.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x, x_image)`` returning ``(x1, image)``.
    """

    def apply(x, x_image):
        """Apply the SAM module.

        Args:
            x: Feature map to attend over (presumably the UNet decoder
                output -- confirm against the caller).
            x_image: Input image; added to the predicted image when
                ``output_channels == 3``.

        Returns:
            A tuple ``(x1, image)``: ``x1`` are the attended features with a
            residual connection to ``x``, ``image`` is the restored image of
            the current stage.
        """
        # Feature branch.
        feats = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)

        # Restored image X_s; with 3 output channels it is predicted as a
        # residual on top of the input image.
        image = Conv3x3(
            filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
        )(x)
        if output_channels == 3:
            image = image + x_image

        # Attention maps computed from the restored image.
        attention = tf.nn.sigmoid(
            Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(image)
        )

        # Attend the features, then add the residual connection.
        attended = feats * attention
        attended = attended + x
        return attended, image

    return apply
|
maxim/blocks/block_gating.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import backend as K
|
3 |
+
from tensorflow.keras import layers
|
4 |
+
|
5 |
+
from ..layers import BlockImages, SwapAxes, UnblockImages
|
6 |
+
|
7 |
+
|
8 |
+
def BlockGatingUnit(use_bias: bool = True, name: str = "block_gating_unit"):
    """Build a spatial gating unit (gMLP paper) mixing along axis -2.

    The input is split in two halves along the last axis. One half is layer
    normalised and passed through a dense layer applied over axis -2 (via
    axis swaps); the other half is multiplied by that result plus one.

    Args:
        use_bias: Whether the dense layer uses a bias.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` returning the gated tensor.
    """

    def apply(x):
        passthrough, gate = tf.split(x, 2, axis=-1)
        gate = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )(gate)

        # Size of the axis being mixed (second to last).
        spatial_dim = K.int_shape(x)[-2]

        # Dense acts on the last axis, so bring the spatial axis there,
        # project, and swap back.
        gate = SwapAxes()(gate, -1, -2)
        gate = layers.Dense(spatial_dim, use_bias=use_bias, name=f"{name}_Dense_0")(gate)
        gate = SwapAxes()(gate, -1, -2)

        return passthrough * (gate + 1.0)

    return apply
|
27 |
+
|
28 |
+
|
29 |
+
def BlockGmlpLayer(
    block_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "block_gmlp",
):
    """Build a block gMLP layer that performs local mixing of tokens.

    Args:
        block_size: ``(fh, fw)`` patch size handed to ``BlockImages``.
        use_bias: Whether the dense layers use a bias.
        factor: Channel expansion of the input projection.
        dropout_rate: Dropout rate after the output projection.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)``; dims 1, 2, 3 of ``x``'s static shape are read
        as height, width and channels.
    """

    def apply(x):
        static_shape = K.int_shape(x)
        height, width, channels = static_shape[1], static_shape[2], static_shape[3]

        patch_h, patch_w = block_size
        grid_h, grid_w = height // patch_h, width // patch_w

        x = BlockImages()(x, patch_size=(patch_h, patch_w))

        # MLP2: Local (block) mixing part, provides within-block communication.
        mixed = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        mixed = layers.Dense(
            channels * factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(mixed)
        mixed = tf.nn.gelu(mixed, approximate=True)
        mixed = BlockGatingUnit(use_bias=use_bias, name=f"{name}_BlockGatingUnit")(mixed)
        mixed = layers.Dense(
            channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(mixed)
        mixed = layers.Dropout(dropout_rate)(mixed)

        # Residual connection in the blocked layout, then restore the image.
        x = x + mixed
        return UnblockImages()(
            x, grid_size=(grid_h, grid_w), patch_size=(patch_h, patch_w)
        )

    return apply
|
maxim/blocks/bottleneck.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
from tensorflow.keras import layers
|
4 |
+
|
5 |
+
from .attentions import RDCAB
|
6 |
+
from .misc_gating import ResidualSplitHeadMultiAxisGmlpLayer
|
7 |
+
|
8 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
9 |
+
|
10 |
+
|
11 |
+
def BottleneckBlock(
    features: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "bottleneck_block",
):
    """Build the bottleneck block: multi-axis gMLP blocks followed by RDCABs.

    Args:
        features: Filters of the input projection and channels of each RDCAB.
        block_size: Block size passed to the multi-axis gMLP layers.
        grid_size: Grid size passed to the multi-axis gMLP layers.
        num_groups: Number of (gMLP layer, RDCAB) pairs applied in sequence.
        block_gmlp_factor: Expansion factor of the block gMLP branch.
        grid_gmlp_factor: Expansion factor of the grid gMLP branch.
        input_proj_factor: Expansion factor of the gMLP input projection.
        channels_reduction: Channel reduction passed to ``RDCAB``.
        dropout_rate: Dropout rate passed to the multi-axis gMLP layers.
        use_bias: Whether the inner layers use a bias.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` returning the block output.
    """

    def apply(x):
        # Input projection; its output is also the source of the long skip.
        x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_input_proj")(x)
        long_skip = x

        for group_idx in range(num_groups):
            gmlp = ResidualSplitHeadMultiAxisGmlpLayer(
                grid_size=grid_size,
                block_size=block_size,
                grid_gmlp_factor=grid_gmlp_factor,
                block_gmlp_factor=block_gmlp_factor,
                input_proj_factor=input_proj_factor,
                use_bias=use_bias,
                dropout_rate=dropout_rate,
                name=f"{name}_SplitHeadMultiAxisGmlpLayer_{group_idx}",
            )
            x = gmlp(x)

            # Channel-mixing part, which provides within-patch communication.
            channel_attention = RDCAB(
                num_channels=features,
                reduction=channels_reduction,
                use_bias=use_bias,
                name=f"{name}_channel_attention_block_1_{group_idx}",
            )
            x = channel_attention(x)

        # Long skip-connection around all groups.
        return x + long_skip

    return apply
|
maxim/blocks/grid_gating.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from tensorflow.keras import backend as K
|
3 |
+
from tensorflow.keras import layers
|
4 |
+
|
5 |
+
from ..layers import BlockImages, SwapAxes, UnblockImages
|
6 |
+
|
7 |
+
|
8 |
+
def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"):
    """Build a spatial gating unit (gMLP paper) mixing along axis -3.

    The input is split in two halves along the last axis. One half is layer
    normalised and passed through a dense layer applied over axis -3 (via
    axis swaps); the other half is multiplied by that result plus one.

    Args:
        use_bias: Whether the dense layer uses a bias.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` returning the gated tensor.
    """

    def apply(x):
        passthrough, gate = tf.split(x, 2, axis=-1)
        gate = layers.LayerNormalization(
            epsilon=1e-06, name=f"{name}_intermediate_layernorm"
        )(gate)

        # Size of the axis being mixed (third from last).
        spatial_dim = K.int_shape(x)[-3]

        # Dense acts on the last axis, so bring the mixed axis there,
        # project, and swap back.
        gate = SwapAxes()(gate, -1, -3)
        gate = layers.Dense(spatial_dim, use_bias=use_bias, name=f"{name}_Dense_0")(gate)
        gate = SwapAxes()(gate, -1, -3)

        return passthrough * (gate + 1.0)

    return apply
|
27 |
+
|
28 |
+
|
29 |
+
def GridGmlpLayer(
    grid_size,
    use_bias: bool = True,
    factor: int = 2,
    dropout_rate: float = 0.0,
    name: str = "grid_gmlp",
):
    """Build a grid gMLP layer that performs global mixing of tokens.

    Args:
        grid_size: ``(gh, gw)`` grid; the patch size handed to ``BlockImages``
            is the input height/width integer-divided by it.
        use_bias: Whether the dense layers use a bias.
        factor: Channel expansion of the input projection.
        dropout_rate: Dropout rate after the output projection.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)``; dims 1, 2, 3 of ``x``'s static shape are read
        as height, width and channels.
    """

    def apply(x):
        static_shape = K.int_shape(x)
        height, width, channels = static_shape[1], static_shape[2], static_shape[3]

        grid_h, grid_w = grid_size
        patch_h, patch_w = height // grid_h, width // grid_w

        x = BlockImages()(x, patch_size=(patch_h, patch_w))

        # gMLP1: Global (grid) mixing part, provides global grid communication.
        mixed = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        mixed = layers.Dense(
            channels * factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(mixed)
        mixed = tf.nn.gelu(mixed, approximate=True)
        mixed = GridGatingUnit(use_bias=use_bias, name=f"{name}_GridGatingUnit")(mixed)
        mixed = layers.Dense(
            channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(mixed)
        mixed = layers.Dropout(dropout_rate)(mixed)

        # Residual connection in the blocked layout, then restore the image.
        x = x + mixed
        return UnblockImages()(
            x, grid_size=(grid_h, grid_w), patch_size=(patch_h, patch_w)
        )

    return apply
|
maxim/blocks/misc_gating.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import backend as K
|
5 |
+
from tensorflow.keras import layers
|
6 |
+
|
7 |
+
from ..layers import BlockImages, SwapAxes, UnblockImages
|
8 |
+
from .block_gating import BlockGmlpLayer
|
9 |
+
from .grid_gating import GridGmlpLayer
|
10 |
+
|
11 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
12 |
+
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
|
13 |
+
ConvT_up = functools.partial(
|
14 |
+
layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
|
15 |
+
)
|
16 |
+
Conv_down = functools.partial(
|
17 |
+
layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
def ResidualSplitHeadMultiAxisGmlpLayer(
    block_size,
    grid_size,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "residual_split_head_maxim",
):
    """Build the multi-axis gated MLP block.

    The input is normalised, projected and split in two halves along the
    channel axis. One half goes through a ``GridGmlpLayer``, the other through
    a ``BlockGmlpLayer``; the halves are concatenated, projected back to the
    input channel count and added to the input.

    Args:
        block_size: Block size for the block gMLP branch.
        grid_size: Grid size for the grid gMLP branch.
        block_gmlp_factor: Expansion factor of the block gMLP branch.
        grid_gmlp_factor: Expansion factor of the grid gMLP branch.
        input_proj_factor: Channel expansion of the input projection.
        use_bias: Whether the dense layers use a bias.
        dropout_rate: Dropout rate used in both branches and at the output.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` returning the block output.
    """

    def apply(x):
        residual = x
        # Channel count is read from dim 3 of the static shape.
        channels = K.int_shape(x)[3]

        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            int(channels) * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x, approximate=True)

        grid_half, block_half = tf.split(x, 2, axis=-1)

        # GridGMLPLayer
        grid_branch = GridGmlpLayer(
            grid_size=grid_size,
            factor=grid_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_GridGmlpLayer",
        )
        grid_half = grid_branch(grid_half)

        # BlockGMLPLayer
        block_branch = BlockGmlpLayer(
            block_size=block_size,
            factor=block_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_BlockGmlpLayer",
        )
        block_half = block_branch(block_half)

        merged = tf.concat([grid_half, block_half], axis=-1)
        merged = layers.Dense(
            channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(merged)
        merged = layers.Dropout(dropout_rate)(merged)

        return merged + residual

    return apply
|
82 |
+
|
83 |
+
|
84 |
+
def GetSpatialGatingWeights(
    features: int,
    block_size,
    grid_size,
    input_proj_factor: int = 2,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "spatial_gating",
):
    """Build the function computing gating weights for the cross-gating block.

    Args:
        features: Accepted for interface compatibility; not read by the body.
        block_size: ``(fh, fw)`` patch size for the block branch.
        grid_size: ``(gh, gw)`` grid for the grid branch.
        input_proj_factor: Channel expansion of the input projection.
        dropout_rate: Dropout rate after the output projection.
        use_bias: Whether the dense layers use a bias.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)``; dims 1, 2, 3 of ``x``'s static shape are read
        as height, width and channels.
    """

    def apply(x):
        static_shape = K.int_shape(x)
        height, width, channels = static_shape[1], static_shape[2], static_shape[3]

        # Input projection, then split into a grid half and a block half.
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            channels * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x, approximate=True)
        grid_half, block_half = tf.split(x, 2, axis=-1)

        # Grid MLP weights: dense projection over axis -3 of the blocked tensor.
        grid_h, grid_w = grid_size
        cell_h, cell_w = height // grid_h, width // grid_w
        grid_half = BlockImages()(grid_half, patch_size=(cell_h, cell_w))
        grid_dim = K.int_shape(grid_half)[-3]
        grid_half = SwapAxes()(grid_half, -1, -3)
        grid_half = layers.Dense(grid_dim, use_bias=use_bias, name=f"{name}_Dense_0")(grid_half)
        grid_half = SwapAxes()(grid_half, -1, -3)
        grid_half = UnblockImages()(
            grid_half, grid_size=(grid_h, grid_w), patch_size=(cell_h, cell_w)
        )

        # Block MLP weights: dense projection over axis -2 of the blocked tensor.
        patch_h, patch_w = block_size
        blocks_h, blocks_w = height // patch_h, width // patch_w
        block_half = BlockImages()(block_half, patch_size=(patch_h, patch_w))
        block_dim = K.int_shape(block_half)[-2]
        block_half = SwapAxes()(block_half, -1, -2)
        block_half = layers.Dense(block_dim, use_bias=use_bias, name=f"{name}_Dense_1")(block_half)
        block_half = SwapAxes()(block_half, -1, -2)
        block_half = UnblockImages()(
            block_half, grid_size=(blocks_h, blocks_w), patch_size=(patch_h, patch_w)
        )

        # Merge both halves and project back to the input channel count.
        merged = tf.concat([grid_half, block_half], axis=-1)
        merged = layers.Dense(channels, use_bias=use_bias, name=f"{name}_out_project")(merged)
        merged = layers.Dropout(dropout_rate)(merged)
        return merged

    return apply
|
140 |
+
|
141 |
+
|
142 |
+
def CrossGatingBlock(
    features: int,
    block_size,
    grid_size,
    dropout_rate: float = 0.0,
    input_proj_factor: int = 2,
    upsample_y: bool = True,
    use_bias: bool = True,
    name: str = "cross_gating",
):
    """Build the cross-gating MLP block.

    Args:
        features: Filters of the projection applied to ``x`` (and of the
            transposed convolution applied to ``y`` when ``upsample_y``).
        block_size: Block size passed to ``GetSpatialGatingWeights``.
        grid_size: Grid size passed to ``GetSpatialGatingWeights``.
        dropout_rate: Dropout rate for the gating weights and output
            projections.
        input_proj_factor: Accepted for interface compatibility; not read by
            the body.
        upsample_y: Whether ``y`` is first upsampled with a stride-2
            transposed convolution.
        use_bias: Whether the inner layers use a bias.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x, y)`` returning the tuple ``(x, y)`` of updated
        signals.
    """

    def apply(x, y):
        # Upscale Y signal, y is the gating signal.
        if upsample_y:
            y = ConvT_up(
                filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
            )(y)

        x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_Conv_0")(x)
        # Channel count is read from dim 3 of the projected x.
        channels = K.int_shape(x)[3]

        y = Conv1x1(filters=channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)

        skip_x = x
        skip_y = y

        # Gating weights derived from X.
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_x")(x)
        x = layers.Dense(channels, use_bias=use_bias, name=f"{name}_in_project_x")(x)
        x = tf.nn.gelu(x, approximate=True)
        gate_from_x = GetSpatialGatingWeights(
            features=channels,
            block_size=block_size,
            grid_size=grid_size,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_x",
        )(x)

        # Gating weights derived from Y.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_y")(y)
        y = layers.Dense(channels, use_bias=use_bias, name=f"{name}_in_project_y")(y)
        y = tf.nn.gelu(y, approximate=True)
        gate_from_y = GetSpatialGatingWeights(
            features=channels,
            block_size=block_size,
            grid_size=grid_size,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_y",
        )(y)

        # Cross gating: Y is gated by the weights from X ...
        y = y * gate_from_x
        y = layers.Dense(channels, use_bias=use_bias, name=f"{name}_out_project_y")(y)
        y = layers.Dropout(dropout_rate)(y)
        y = y + skip_y

        # ... and X is gated by the weights from Y.
        x = x * gate_from_y
        x = layers.Dense(channels, use_bias=use_bias, name=f"{name}_out_project_x")(x)
        x = layers.Dropout(dropout_rate)(x)
        x = x + y + skip_x  # aggregate all signals

        return x, y

    return apply
|
maxim/blocks/others.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import backend as K
|
5 |
+
from tensorflow.keras import layers
|
6 |
+
|
7 |
+
from ..layers import Resizing
|
8 |
+
|
9 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
10 |
+
|
11 |
+
|
12 |
+
def MlpBlock(
    mlp_dim: int,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "mlp_block",
):
    """Build a 1-hidden-layer MLP block, applied over the last dimension.

    Args:
        mlp_dim: Units of the hidden dense layer.
        dropout_rate: Dropout rate applied after the activation.
        use_bias: Whether both dense layers use a bias.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)`` whose output has the same last-dimension size
        as ``x``.
    """

    def apply(x):
        # The output layer projects back to the input's last-dimension size.
        out_dim = K.int_shape(x)[-1]

        hidden = layers.Dense(mlp_dim, use_bias=use_bias, name=f"{name}_Dense_0")(x)
        hidden = tf.nn.gelu(hidden, approximate=True)
        hidden = layers.Dropout(dropout_rate)(hidden)

        return layers.Dense(out_dim, use_bias=use_bias, name=f"{name}_Dense_1")(hidden)

    return apply
|
29 |
+
|
30 |
+
|
31 |
+
def UpSampleRatio(
    num_channels: int, ratio: float, use_bias: bool = True, name: str = "upsample"
):
    """Build a function that resizes features by ``ratio`` (> 0), then projects.

    Args:
        num_channels: Filters of the 1x1 convolution applied after resizing.
        ratio: Factor applied to the height and width; the target sizes are
            truncated with ``int()``.
        use_bias: Whether the convolution uses a bias.
        name: Prefix for the names of the created layers.

    Returns:
        A function ``apply(x)``; dims 1 and 2 of ``x``'s static shape are read
        as height and width.
    """

    def apply(x):
        static_shape = K.int_shape(x)
        height, width = static_shape[1], static_shape[2]

        # Following `jax.image.resize()`. The uid suffix gives every resizing
        # layer a distinct name.
        resize = Resizing(
            height=int(height * ratio),
            width=int(width * ratio),
            method="bilinear",
            antialias=True,
            name=f"{name}_resizing_{K.get_uid('Resizing')}",
        )
        x = resize(x)

        return Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)

    return apply
|
maxim/blocks/unet.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import layers
|
5 |
+
|
6 |
+
from .attentions import RCAB
|
7 |
+
from .misc_gating import CrossGatingBlock, ResidualSplitHeadMultiAxisGmlpLayer
|
8 |
+
|
9 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
10 |
+
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
|
11 |
+
ConvT_up = functools.partial(
|
12 |
+
layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
|
13 |
+
)
|
14 |
+
Conv_down = functools.partial(
|
15 |
+
layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
def UNetEncoderBlock(
    num_channels: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    lrelu_slope: float = 0.2,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    downsample: bool = True,
    use_global_mlp: bool = True,
    use_bias: bool = True,
    use_cross_gating: bool = False,
    name: str = "unet_encoder",
):
    """Encoder block in MAXIM.

    Args:
        num_channels: number of filters for the block's convolutions.
        block_size: block size forwarded to the gMLP / cross-gating layers.
        grid_size: grid size forwarded to the gMLP / cross-gating layers.
        num_groups: how many (optional gMLP + RCAB) groups are stacked.
        lrelu_slope: negative slope forwarded to RCAB.
        block_gmlp_factor: forwarded to the multi-axis gMLP layer.
        grid_gmlp_factor: forwarded to the multi-axis gMLP layer.
        input_proj_factor: forwarded to the gMLP and cross-gating layers.
        channels_reduction: reduction factor forwarded to RCAB.
        dropout_rate: dropout rate forwarded to the sub-layers.
        downsample: if True, also return a strided-conv downsampled tensor.
        use_global_mlp: whether each group starts with the multi-axis gMLP.
        use_bias: whether the conv/dense layers use a bias term.
        use_cross_gating: must be True when `enc` and `dec` are both passed
            to the returned callable.
        name: prefix for all layer names created by this block.

    Returns:
        A function `apply(x, skip=None, enc=None, dec=None)`. It returns
        `(x_down, x)` when `downsample` is True and `x` alone otherwise.

    Raises:
        AssertionError: (from `apply`) if `enc` and `dec` are both given but
            `use_cross_gating` is False.
    """

    def apply(x, skip=None, enc=None, dec=None):
        if skip is not None:
            # Fuse the skip signal along the channel axis.
            x = tf.concat([x, skip], axis=-1)

        # convolution-in
        x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
        shortcut_long = x

        for i in range(num_groups):
            if use_global_mlp:
                x = ResidualSplitHeadMultiAxisGmlpLayer(
                    grid_size=grid_size,
                    block_size=block_size,
                    grid_gmlp_factor=grid_gmlp_factor,
                    block_gmlp_factor=block_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    use_bias=use_bias,
                    dropout_rate=dropout_rate,
                    name=f"{name}_SplitHeadMultiAxisGmlpLayer_{i}",
                )(x)
            x = RCAB(
                num_channels=num_channels,
                reduction=channels_reduction,
                lrelu_slope=lrelu_slope,
                use_bias=use_bias,
                name=f"{name}_channel_attention_block_1{i}",
            )(x)

        # Long residual connection around all groups.
        x = x + shortcut_long

        if enc is not None and dec is not None:
            # Raised explicitly (rather than via `assert`) so the check is not
            # stripped when Python runs with -O; the exception type is kept.
            if not use_cross_gating:
                raise AssertionError(
                    "`enc` and `dec` were provided but `use_cross_gating` is False."
                )
            x, _ = CrossGatingBlock(
                features=num_channels,
                block_size=block_size,
                grid_size=grid_size,
                dropout_rate=dropout_rate,
                input_proj_factor=input_proj_factor,
                upsample_y=False,
                use_bias=use_bias,
                name=f"{name}_cross_gating_block",
            )(x, enc + dec)

        if downsample:
            x_down = Conv_down(
                filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1"
            )(x)
            return x_down, x
        else:
            return x

    return apply
|
90 |
+
|
91 |
+
|
92 |
+
def UNetDecoderBlock(
    num_channels: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    lrelu_slope: float = 0.2,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    downsample: bool = True,
    use_global_mlp: bool = True,
    use_bias: bool = True,
    name: str = "unet_decoder",
):
    """Decoder block in MAXIM.

    Upsamples the input with a stride-2 transposed convolution and then runs a
    non-downsampling `UNetEncoderBlock` on it, using `bridge` as the skip input.

    Args:
        num_channels: number of filters for the transposed conv and the
            inner encoder block.
        block_size: block size forwarded to the inner encoder block.
        grid_size: grid size forwarded to the inner encoder block.
        num_groups: forwarded to the inner encoder block.
        lrelu_slope: forwarded to the inner encoder block.
        block_gmlp_factor: forwarded to the inner encoder block.
        grid_gmlp_factor: forwarded to the inner encoder block.
        input_proj_factor: forwarded to the inner encoder block.
        channels_reduction: forwarded to the inner encoder block.
        dropout_rate: forwarded to the inner encoder block.
        downsample: not read; the inner encoder block always gets
            `downsample=False`.
        use_global_mlp: forwarded to the inner encoder block.
        use_bias: whether the conv/dense layers use a bias term.
        name: prefix for all layer names created by this block.

    Returns:
        A function `apply(x, bridge=None)` returning the decoded tensor.
    """

    def apply(x, bridge=None):
        x = ConvT_up(
            filters=num_channels, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
        )(x)
        x = UNetEncoderBlock(
            num_channels=num_channels,
            num_groups=num_groups,
            lrelu_slope=lrelu_slope,
            block_size=block_size,
            grid_size=grid_size,
            block_gmlp_factor=block_gmlp_factor,
            grid_gmlp_factor=grid_gmlp_factor,
            # Previously dropped, so a caller-supplied value was silently
            # replaced by the encoder's default.
            input_proj_factor=input_proj_factor,
            channels_reduction=channels_reduction,
            use_global_mlp=use_global_mlp,
            dropout_rate=dropout_rate,
            downsample=False,
            use_bias=use_bias,
            name=f"{name}_UNetEncoderBlock_0",
        )(x, skip=bridge)

        return x

    return apply
|
maxim/configs.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Hyper-parameters of the MAXIM variants, keyed by variant label. The variants
# differ only in `features`, `num_stages` and `name`; all other entries are
# shared. Each variant gets its own independent dict.
MAXIM_CONFIGS = {
    variant: {
        "features": width,
        "depth": 3,
        "num_stages": stages,
        "num_groups": 2,
        "num_bottleneck_blocks": 2,
        "block_gmlp_factor": 2,
        "grid_gmlp_factor": 2,
        "input_proj_factor": 2,
        "channels_reduction": 4,
        # "S-1" -> "s1", "M-3" -> "m3", ...
        "name": variant.replace("-", "").lower(),
    }
    for variant, width, stages in (
        # params: 6.108515000000001 M, GFLOPS: 93.163716608
        ("S-1", 32, 1),
        # params: 13.35383 M, GFLOPS: 206.743273472
        ("S-2", 32, 2),
        # params: 20.599145 M, GFLOPS: 320.32194560000005
        ("S-3", 32, 3),
        # params: 19.361219000000002 M, 308.495712256 GFLOPs
        ("M-1", 64, 1),
        # params: 40.83911 M, 675.25541888 GFLOPs
        ("M-2", 64, 2),
        # params: 62.317001 M, 1042.014666752 GFLOPs
        ("M-3", 64, 3),
    )
}
|
maxim/layers.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import einops
|
2 |
+
import tensorflow as tf
|
3 |
+
from tensorflow.experimental import numpy as tnp
|
4 |
+
from tensorflow.keras import backend as K
|
5 |
+
from tensorflow.keras import layers
|
6 |
+
|
7 |
+
|
8 |
+
@tf.keras.utils.register_keras_serializable("maxim")
class BlockImages(layers.Layer):
    """Keras layer that regroups an image tensor into patches with einops.

    `call` rearranges `n (gh fh) (gw fw) c` into `n (gh gw) (fh fw) c`, where
    `(fh, fw)` is `patch_size` and `(gh, gw)` is the static height/width
    floor-divided by it.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, patch_size):
        # Static shape of the input; only height and width are used.
        _, height, width, _ = K.int_shape(x)
        patch_h, patch_w = patch_size[0], patch_size[1]

        return einops.rearrange(
            x,
            "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
            gh=height // patch_h,
            gw=width // patch_w,
            fh=patch_h,
            fw=patch_w,
        )

    def get_config(self):
        # No extra constructor arguments; return a copy of the base config.
        return super().get_config().copy()
|
37 |
+
|
38 |
+
|
39 |
+
@tf.keras.utils.register_keras_serializable("maxim")
class UnblockImages(layers.Layer):
    """Keras layer that merges patches back into an image tensor with einops.

    `call` rearranges `n (gh gw) (fh fw) c` into `n (gh fh) (gw fw) c`, where
    `(gh, gw)` is `grid_size` and `(fh, fw)` is `patch_size`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, grid_size, patch_size):
        axis_sizes = {
            "gh": grid_size[0],
            "gw": grid_size[1],
            "fh": patch_size[0],
            "fw": patch_size[1],
        }
        return einops.rearrange(
            x, "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c", **axis_sizes
        )

    def get_config(self):
        # No extra constructor arguments; return a copy of the base config.
        return super().get_config().copy()
|
59 |
+
|
60 |
+
|
61 |
+
@tf.keras.utils.register_keras_serializable("maxim")
class SwapAxes(layers.Layer):
    """Keras layer wrapping `tnp.swapaxes` so it can be serialized with a model."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x, axis_one, axis_two):
        swapped = tnp.swapaxes(x, axis_one, axis_two)
        return swapped

    def get_config(self):
        # No extra constructor arguments; return a copy of the base config.
        return super().get_config().copy()
|
72 |
+
|
73 |
+
|
74 |
+
@tf.keras.utils.register_keras_serializable("maxim")
class Resizing(layers.Layer):
    """Keras layer wrapping `tf.image.resize` with a fixed target size.

    Args:
        height: target height passed to `tf.image.resize`.
        width: target width passed to `tf.image.resize`.
        antialias: forwarded as the `antialias` argument.
        method: forwarded as the `method` argument.
        **kwargs: forwarded to the base `Layer` constructor.
    """

    def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
        super().__init__(**kwargs)
        self.height, self.width = height, width
        self.antialias, self.method = antialias, method

    def call(self, x):
        target_size = (self.height, self.width)
        return tf.image.resize(
            x, size=target_size, antialias=self.antialias, method=self.method
        )

    def get_config(self):
        # Base config followed by this layer's constructor arguments.
        return {
            **super().get_config(),
            "height": self.height,
            "width": self.width,
            "antialias": self.antialias,
            "method": self.method,
        }
|
maxim/maxim.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras import backend as K
|
5 |
+
from tensorflow.keras import layers
|
6 |
+
|
7 |
+
from .blocks.attentions import SAM
|
8 |
+
from .blocks.bottleneck import BottleneckBlock
|
9 |
+
from .blocks.misc_gating import CrossGatingBlock
|
10 |
+
from .blocks.others import UpSampleRatio
|
11 |
+
from .blocks.unet import UNetDecoderBlock, UNetEncoderBlock
|
12 |
+
from .layers import Resizing
|
13 |
+
|
14 |
+
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
|
15 |
+
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
|
16 |
+
ConvT_up = functools.partial(
|
17 |
+
layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
|
18 |
+
)
|
19 |
+
Conv_down = functools.partial(
|
20 |
+
layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
def MAXIM(
    features: int = 64,
    depth: int = 3,
    num_stages: int = 2,
    num_groups: int = 1,
    use_bias: bool = True,
    num_supervision_scales: int = 1,
    lrelu_slope: float = 0.2,
    use_global_mlp: bool = True,
    use_cross_gating: bool = True,
    high_res_stages: int = 2,
    block_size_hr=(16, 16),
    block_size_lr=(8, 8),
    grid_size_hr=(16, 16),
    grid_size_lr=(8, 8),
    num_bottleneck_blocks: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    num_outputs: int = 3,
    dropout_rate: float = 0.0,
):
    """The MAXIM model function with multi-stage and multi-scale supervision.

    For more model details, please check the CVPR paper:
    MAXIM: Multi-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)

    Args:
      features: initial hidden dimension for the input resolution.
      depth: the number of downsampling depth for the model.
      num_stages: how many stages to use. It also affects the output list.
      num_groups: how many blocks each stage contains.
      use_bias: whether to use bias in all the conv/mlp layers.
      num_supervision_scales: the number of desired supervision scales.
      lrelu_slope: the negative slope parameter in leaky_relu layers.
      use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in each
        layer.
      use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
        skip connections and multi-stage feature fusion layers.
      high_res_stages: how many stages are specified as high-res stages. The
        rest (depth - high_res_stages) are called low_res_stages.
      block_size_hr: the block_size parameter for high-res stages.
      block_size_lr: the block_size parameter for low-res stages. It is also
        used as the grid size for low-res stages and the bottleneck.
      grid_size_hr: the grid_size parameter for high-res stages.
      grid_size_lr: the grid_size parameter for low-res stages. Currently not
        read by this function (see the note in the body).
      num_bottleneck_blocks: how many bottleneck blocks.
      block_gmlp_factor: the input projection factor for block_gMLP layers.
      grid_gmlp_factor: the input projection factor for grid_gMLP layers.
      input_proj_factor: the input projection factor for the MAB block.
      channels_reduction: the channel reduction factor for SE layer.
      num_outputs: the output channels.
      dropout_rate: Dropout rate.

    Returns:
      A function that maps an input image tensor to a list of lists of tensors
      holding the multi-stage multi-scale outputs. For example, if
      num_stages = num_supervision_scales = 3 (the model used in the paper),
      the output specs are: outputs =
      [[output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],
       [output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],
       [output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],]
      The final output can be retrieved by outputs[-1][-1].
    """

    def _window_sizes(level):
        """Return (block_size, grid_size) for encoder/decoder level `level`."""
        # use larger blocksize at high-res stages, vice versa.
        if level < high_res_stages:
            return block_size_hr, grid_size_hr
        # NOTE(review): the low-res grid size is taken from `block_size_lr`, so
        # `grid_size_lr` is never used. Kept as-is to preserve behaviour;
        # presumably this mirrors the reference implementation -- confirm
        # before changing, since layer shapes may depend on it.
        return block_size_lr, block_size_lr

    def apply(x):
        _, h, w, _ = K.int_shape(x)  # static input image shape

        shortcuts = [x]

        # Get multi-scale input images
        for i in range(1, num_supervision_scales):
            resizing_layer = Resizing(
                height=h // (2 ** i),
                width=w // (2 ** i),
                method="nearest",
                antialias=True,  # Following `jax.image.resize()`.
                name=f"initial_resizing_{K.get_uid('Resizing')}",
            )
            shortcuts.append(resizing_layer(x))

        # store outputs from all stages and all scales
        # Eg, [[(64, 64, 3), (128, 128, 3), (256, 256, 3)],   # Stage-1 outputs
        #      [(64, 64, 3), (128, 128, 3), (256, 256, 3)],]  # Stage-2 outputs
        outputs_all = []
        sam_features, encs_prev, decs_prev = [], [], []

        for idx_stage in range(num_stages):
            # Input convolution, get multi-scale input features
            x_scales = []
            for i in range(num_supervision_scales):
                x_scale = Conv3x3(
                    filters=(2 ** i) * features,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_input_conv_{i}",
                )(shortcuts[i])

                # If later stages, fuse input features with SAM features from prev stage
                if idx_stage > 0:
                    if use_cross_gating:
                        block_size, grid_size = _window_sizes(i)
                        x_scale, _ = CrossGatingBlock(
                            features=(2 ** i) * features,
                            block_size=block_size,
                            grid_size=grid_size,
                            dropout_rate=dropout_rate,
                            input_proj_factor=input_proj_factor,
                            upsample_y=False,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_fuse_sam_{i}",
                        )(x_scale, sam_features.pop())
                    else:
                        x_scale = Conv1x1(
                            filters=(2 ** i) * features,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_catconv_{i}",
                        )(tf.concat([x_scale, sam_features.pop()], axis=-1))

                x_scales.append(x_scale)

            # start encoder blocks
            encs = []
            x = x_scales[0]  # First full-scale input feature

            for i in range(depth):  # 0, 1, 2
                block_size, grid_size = _window_sizes(i)
                use_cross_gating_layer = idx_stage > 0

                # Multi-scale input if multi-scale supervision
                x_scale = x_scales[i] if i < num_supervision_scales else None

                # UNet Encoder block
                enc_prev = encs_prev.pop() if idx_stage > 0 else None
                dec_prev = decs_prev.pop() if idx_stage > 0 else None

                x, bridge = UNetEncoderBlock(
                    num_channels=(2 ** i) * features,
                    num_groups=num_groups,
                    downsample=True,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    use_cross_gating=use_cross_gating_layer,
                    name=f"stage_{idx_stage}_encoder_block_{i}",
                )(x, skip=x_scale, enc=enc_prev, dec=dec_prev)

                # Cache skip signals
                encs.append(bridge)

            # Global MLP bottleneck blocks
            for i in range(num_bottleneck_blocks):
                x = BottleneckBlock(
                    block_size=block_size_lr,
                    grid_size=block_size_lr,
                    features=(2 ** (depth - 1)) * features,
                    num_groups=num_groups,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    channels_reduction=channels_reduction,
                    name=f"stage_{idx_stage}_global_block_{i}",
                )(x)
            # cache global feature for cross-gating
            global_feature = x

            # start cross gating. Use multi-scale feature fusion
            skip_features = []
            for i in reversed(range(depth)):  # 2, 1, 0
                block_size, grid_size = _window_sizes(i)

                # get additional multi-scale signals
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2 ** i) * features,
                            ratio=2 ** (j - i),
                            use_bias=use_bias,
                            name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
                        )(enc)
                        for j, enc in enumerate(encs)
                    ],
                    axis=-1,
                )

                # Use cross-gating to cross modulate features
                if use_cross_gating:
                    skips, global_feature = CrossGatingBlock(
                        features=(2 ** i) * features,
                        block_size=block_size,
                        grid_size=grid_size,
                        input_proj_factor=input_proj_factor,
                        dropout_rate=dropout_rate,
                        upsample_y=True,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_cross_gating_block_{i}",
                    )(signal, global_feature)
                else:
                    # Names include stage and level: the former fixed names
                    # ("Conv_0"/"Conv_1") were repeated on every iteration,
                    # giving several layers the same name.
                    skips = Conv1x1(
                        filters=(2 ** i) * features,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_skip_conv1x1_{i}",
                    )(signal)
                    skips = Conv3x3(
                        filters=(2 ** i) * features,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_skip_conv3x3_{i}",
                    )(skips)

                skip_features.append(skips)

            # start decoder. Multi-scale feature fusion of cross-gated features
            outputs, decs, sam_features = [], [], []
            for i in reversed(range(depth)):
                block_size, grid_size = _window_sizes(i)

                # get multi-scale skip signals from cross-gating block
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2 ** i) * features,
                            ratio=2 ** (depth - j - 1 - i),
                            use_bias=use_bias,
                            name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
                        )(skip)
                        for j, skip in enumerate(skip_features)
                    ],
                    axis=-1,
                )

                # Decoder block
                x = UNetDecoderBlock(
                    num_channels=(2 ** i) * features,
                    num_groups=num_groups,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_decoder_block_{i}",
                )(x, bridge=signal)

                # Cache decoder features for later-stage's usage
                decs.append(x)

                # output conv, if not final stage, use supervised-attention-block.
                if i < num_supervision_scales:
                    if idx_stage < num_stages - 1:  # not last stage, apply SAM
                        sam, output = SAM(
                            num_channels=(2 ** i) * features,
                            output_channels=num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_supervised_attention_module_{i}",
                        )(x, shortcuts[i])
                        outputs.append(output)
                        sam_features.append(sam)
                    else:  # Last stage, apply output convolutions
                        output = Conv3x3(
                            num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_output_conv_{i}",
                        )(x)
                        output = output + shortcuts[i]
                        outputs.append(output)

            # Cache encoder and decoder features for later-stage's usage.
            # Both lists are consumed with `.pop()`, so the full-resolution
            # entry must sit at the end of each.
            encs_prev = encs[::-1]
            decs_prev = decs

            # Store outputs
            outputs_all.append(outputs)
        return outputs_all

    return apply
|