sayakpaul HF staff commited on
Commit
3f9d71f
·
1 Parent(s): 399866a

add: files.

Browse files
1.png ADDED
111.png ADDED
748.png ADDED
a4541-DSC_0040-2.png ADDED
app.py CHANGED
@@ -1,13 +1,16 @@
 
 
 
 
 
 
 
1
  from huggingface_hub.keras_mixin import from_pretrained_keras
2
-
3
  from PIL import Image
4
 
5
- import numpy as np
6
-
7
  from create_maxim_model import Model
8
  from maxim.configs import MAXIM_CONFIGS
9
 
10
-
11
  _MODEL = from_pretrained_keras("sayakpaul/S-2_enhancement_lol")
12
 
13
 
@@ -23,63 +26,6 @@ def mod_padding_symmetric(image, factor=64):
23
  image, [(padh // 2, padh // 2), (padw // 2, padw // 2), (0, 0)], mode="REFLECT"
24
  )
25
  return image
26
-
27
- def _convert_input_type_range(img):
28
- """Convert the type and range of the input image.
29
-
30
- It converts the input image to np.float32 type and range of [0, 1].
31
- It is mainly used for pre-processing the input image in colorspace
32
- convertion functions such as rgb2ycbcr and ycbcr2rgb.
33
- Args:
34
- img (ndarray): The input image. It accepts:
35
- 1. np.uint8 type with range [0, 255];
36
- 2. np.float32 type with range [0, 1].
37
- Returns:
38
- (ndarray): The converted image with type of np.float32 and range of
39
- [0, 1].
40
- """
41
- img_type = img.dtype
42
- img = img.astype(np.float32)
43
- if img_type == np.float32:
44
- pass
45
- elif img_type == np.uint8:
46
- img /= 255.0
47
- else:
48
- raise TypeError(
49
- "The img type should be np.float32 or np.uint8, " f"but got {img_type}"
50
- )
51
- return img
52
-
53
-
54
- def _convert_output_type_range(img, dst_type):
55
- """Convert the type and range of the image according to dst_type.
56
-
57
- It converts the image to desired type and range. If `dst_type` is np.uint8,
58
- images will be converted to np.uint8 type with range [0, 255]. If
59
- `dst_type` is np.float32, it converts the image to np.float32 type with
60
- range [0, 1].
61
- It is mainly used for post-processing images in colorspace convertion
62
- functions such as rgb2ycbcr and ycbcr2rgb.
63
- Args:
64
- img (ndarray): The image to be converted with np.float32 type and
65
- range [0, 255].
66
- dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
67
- converts the image to np.uint8 type with range [0, 255]. If
68
- dst_type is np.float32, it converts the image to np.float32 type
69
- with range [0, 1].
70
- Returns:
71
- (ndarray): The converted image with desired type and range.
72
- """
73
- if dst_type not in (np.uint8, np.float32):
74
- raise TypeError(
75
- "The dst_type should be np.float32 or np.uint8, " f"but got {dst_type}"
76
- )
77
- if dst_type == np.uint8:
78
- img = img.round()
79
- else:
80
- img /= 255.0
81
-
82
- return img.astype(dst_type)
83
 
84
 
85
  def make_shape_even(image):
@@ -89,8 +35,8 @@ def make_shape_even(image):
89
  padw = 1 if width % 2 != 0 else 0
90
  image = tf.pad(image, [(0, padh), (0, padw), (0, 0)], mode="REFLECT")
91
  return image
92
-
93
-
94
  def process_image(image: Image):
95
  input_img = np.asarray(image) / 255.0
96
  height, width = input_img.shape[0], input_img.shape[1]
@@ -102,9 +48,9 @@ def process_image(image: Image):
102
  # padding images to be multiplies of 64
103
  input_img = mod_padding_symmetric(input_img, factor=64)
104
  input_img = tf.expand_dims(input_img, axis=0)
105
- return input_img, height_even, width_even
106
-
107
-
108
  def init_new_model(input_img):
109
  configs = MAXIM_CONFIGS.get("S-2")
110
  configs.update(
@@ -120,25 +66,40 @@ def init_new_model(input_img):
120
  new_model = Model(**configs)
121
  new_model.set_weights(_MODEL.get_weights())
122
  return new_model
123
-
124
-
125
  def infer(image):
126
- preprocessed_image, height_even, width_even = process_image(image)
127
  new_model = init_new_model(preprocessed_image)
128
-
129
  preds = new_model.predict(preprocessed_image)
130
  if isinstance(preds, list):
131
  preds = preds[-1]
132
  if isinstance(preds, list):
133
  preds = preds[-1]
134
-
135
  preds = np.array(preds[0], np.float32)
136
-
137
  new_height, new_width = preds.shape[0], preds.shape[1]
138
  h_start = new_height // 2 - height_even // 2
139
  h_end = h_start + height
140
  w_start = new_width // 2 - width_even // 2
141
  w_end = w_start + width
142
  preds = preds[h_start:h_end, w_start:w_end, :]
143
-
144
- return Image.fromarray(np.array((np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Some preprocessing utilities have been taken from:
3
+ https://github.com/google-research/maxim/blob/main/maxim/run_eval.py
4
+ """
5
+ import gradio as gr
6
+ import numpy as np
7
+ import tensorflow as tf
8
  from huggingface_hub.keras_mixin import from_pretrained_keras
 
9
  from PIL import Image
10
 
 
 
11
  from create_maxim_model import Model
12
  from maxim.configs import MAXIM_CONFIGS
13
 
 
14
  _MODEL = from_pretrained_keras("sayakpaul/S-2_enhancement_lol")
15
 
16
 
 
26
  image, [(padh // 2, padh // 2), (padw // 2, padw // 2), (0, 0)], mode="REFLECT"
27
  )
28
  return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  def make_shape_even(image):
 
35
  padw = 1 if width % 2 != 0 else 0
36
  image = tf.pad(image, [(0, padh), (0, padw), (0, 0)], mode="REFLECT")
37
  return image
38
+
39
+
40
  def process_image(image: Image):
41
  input_img = np.asarray(image) / 255.0
42
  height, width = input_img.shape[0], input_img.shape[1]
 
48
  # padding images to be multiplies of 64
49
  input_img = mod_padding_symmetric(input_img, factor=64)
50
  input_img = tf.expand_dims(input_img, axis=0)
51
+ return input_img, height, width, height_even, width_even
52
+
53
+
54
  def init_new_model(input_img):
55
  configs = MAXIM_CONFIGS.get("S-2")
56
  configs.update(
 
66
  new_model = Model(**configs)
67
  new_model.set_weights(_MODEL.get_weights())
68
  return new_model
69
+
70
+
71
  def infer(image):
72
+ preprocessed_image, height, width, height_even, width_even = process_image(image)
73
  new_model = init_new_model(preprocessed_image)
74
+
75
  preds = new_model.predict(preprocessed_image)
76
  if isinstance(preds, list):
77
  preds = preds[-1]
78
  if isinstance(preds, list):
79
  preds = preds[-1]
80
+
81
  preds = np.array(preds[0], np.float32)
82
+
83
  new_height, new_width = preds.shape[0], preds.shape[1]
84
  h_start = new_height // 2 - height_even // 2
85
  h_end = h_start + height
86
  w_start = new_width // 2 - width_even // 2
87
  w_end = w_start + width
88
  preds = preds[h_start:h_end, w_start:w_end, :]
89
+
90
+ return Image.fromarray(np.array((np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)))
91
+
92
+
93
+ title = "Enhance low-light images."
94
+ article = "Model based on [this](https://huggingface.co/sayakpaul/S-2_enhancement_lol)."
95
+
96
+ iface = gr.Interface(
97
+ infer,
98
+ inputs="image",
99
+ outputs="image",
100
+ title=title,
101
+ article=article,
102
+ allow_flagging="never",
103
+ examples=[["1.png"], ["111.png"], ["748.png"], ["a4541-DSC_0040-2.png"]],
104
+ )
105
+ iface.launch(debug=True)
create_maxim_model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tensorflow import keras
2
+
3
+ from maxim import maxim
4
+ from maxim.configs import MAXIM_CONFIGS
5
+
6
+
7
+ def Model(variant=None, input_resolution=(256, 256), **kw) -> keras.Model:
8
+ """Factory function to easily create a Model variant like "S".
9
+
10
+ Args:
11
+ variant: UNet model variants. Options: 'S-1' | 'S-2' | 'S-3'
12
+ | 'M-1' | 'M-2' | 'M-3'
13
+ input_resolution: Size of the input images.
14
+ **kw: Other UNet config dicts.
15
+
16
+ Returns:
17
+ The MAXIM model.
18
+ """
19
+
20
+ if variant is not None:
21
+ config = MAXIM_CONFIGS[variant]
22
+ for k, v in config.items():
23
+ kw.setdefault(k, v)
24
+
25
+ if "variant" in kw:
26
+ _ = kw.pop("variant")
27
+ if "input_resolution" in kw:
28
+ _ = kw.pop("input_resolution")
29
+ model_name = kw.pop("name")
30
+
31
+ maxim_model = maxim.MAXIM(**kw)
32
+
33
+ inputs = keras.Input((*input_resolution, 3))
34
+ outputs = maxim_model(inputs)
35
+ final_model = keras.Model(inputs, outputs, name=f"{model_name}_model")
36
+
37
+ return final_model
maxim/__init__.py ADDED
File without changes
maxim/blocks/__init__.py ADDED
File without changes
maxim/blocks/attentions.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import layers
5
+
6
+ from .others import MlpBlock
7
+
8
+ Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
9
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
10
+
11
+
12
+ def CALayer(
13
+ num_channels: int,
14
+ reduction: int = 4,
15
+ use_bias: bool = True,
16
+ name: str = "channel_attention",
17
+ ):
18
+ """Squeeze-and-excitation block for channel attention.
19
+
20
+ ref: https://arxiv.org/abs/1709.01507
21
+ """
22
+
23
+ def apply(x):
24
+ # 2D global average pooling
25
+ y = layers.GlobalAvgPool2D(keepdims=True)(x)
26
+ # Squeeze (in Squeeze-Excitation)
27
+ y = Conv1x1(
28
+ filters=num_channels // reduction, use_bias=use_bias, name=f"{name}_Conv_0"
29
+ )(y)
30
+ y = tf.nn.relu(y)
31
+ # Excitation (in Squeeze-Excitation)
32
+ y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
33
+ y = tf.nn.sigmoid(y)
34
+ return x * y
35
+
36
+ return apply
37
+
38
+
39
+ def RCAB(
40
+ num_channels: int,
41
+ reduction: int = 4,
42
+ lrelu_slope: float = 0.2,
43
+ use_bias: bool = True,
44
+ name: str = "residual_ca",
45
+ ):
46
+ """Residual channel attention block. Contains LN,Conv,lRelu,Conv,SELayer."""
47
+
48
+ def apply(x):
49
+ shortcut = x
50
+ x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
51
+ x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv1")(x)
52
+ x = tf.nn.leaky_relu(x, alpha=lrelu_slope)
53
+ x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv2")(x)
54
+ x = CALayer(
55
+ num_channels=num_channels,
56
+ reduction=reduction,
57
+ use_bias=use_bias,
58
+ name=f"{name}_channel_attention",
59
+ )(x)
60
+ return x + shortcut
61
+
62
+ return apply
63
+
64
+
65
+ def RDCAB(
66
+ num_channels: int,
67
+ reduction: int = 16,
68
+ use_bias: bool = True,
69
+ dropout_rate: float = 0.0,
70
+ name: str = "rdcab",
71
+ ):
72
+ """Residual dense channel attention block. Used in Bottlenecks."""
73
+
74
+ def apply(x):
75
+ y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
76
+ y = MlpBlock(
77
+ mlp_dim=num_channels,
78
+ dropout_rate=dropout_rate,
79
+ use_bias=use_bias,
80
+ name=f"{name}_channel_mixing",
81
+ )(y)
82
+ y = CALayer(
83
+ num_channels=num_channels,
84
+ reduction=reduction,
85
+ use_bias=use_bias,
86
+ name=f"{name}_channel_attention",
87
+ )(y)
88
+ x = x + y
89
+ return x
90
+
91
+ return apply
92
+
93
+
94
+ def SAM(
95
+ num_channels: int,
96
+ output_channels: int = 3,
97
+ use_bias: bool = True,
98
+ name: str = "sam",
99
+ ):
100
+
101
+ """Supervised attention module for multi-stage training.
102
+
103
+ Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet
104
+ """
105
+
106
+ def apply(x, x_image):
107
+ """Apply the SAM module to the input and num_channels.
108
+ Args:
109
+ x: the output num_channels from UNet decoder with shape (h, w, c)
110
+ x_image: the input image with shape (h, w, 3)
111
+ Returns:
112
+ A tuple of tensors (x1, image) where (x1) is the sam num_channels used for the
113
+ next stage, and (image) is the output restored image at current stage.
114
+ """
115
+ # Get num_channels
116
+ x1 = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
117
+
118
+ # Output restored image X_s
119
+ if output_channels == 3:
120
+ image = (
121
+ Conv3x3(
122
+ filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
123
+ )(x)
124
+ + x_image
125
+ )
126
+ else:
127
+ image = Conv3x3(
128
+ filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
129
+ )(x)
130
+
131
+ # Get attention maps for num_channels
132
+ x2 = tf.nn.sigmoid(
133
+ Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(image)
134
+ )
135
+
136
+ # Get attended feature maps
137
+ x1 = x1 * x2
138
+
139
+ # Residual connection
140
+ x1 = x1 + x
141
+ return x1, image
142
+
143
+ return apply
maxim/blocks/block_gating.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import backend as K
3
+ from tensorflow.keras import layers
4
+
5
+ from ..layers import BlockImages, SwapAxes, UnblockImages
6
+
7
+
8
+ def BlockGatingUnit(use_bias: bool = True, name: str = "block_gating_unit"):
9
+ """A SpatialGatingUnit as defined in the gMLP paper.
10
+
11
+ The 'spatial' dim is defined as the **second last**.
12
+ If applied on other dims, you should swapaxes first.
13
+ """
14
+
15
+ def apply(x):
16
+ u, v = tf.split(x, 2, axis=-1)
17
+ v = layers.LayerNormalization(
18
+ epsilon=1e-06, name=f"{name}_intermediate_layernorm"
19
+ )(v)
20
+ n = K.int_shape(x)[-2] # get spatial dim
21
+ v = SwapAxes()(v, -1, -2)
22
+ v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v)
23
+ v = SwapAxes()(v, -1, -2)
24
+ return u * (v + 1.0)
25
+
26
+ return apply
27
+
28
+
29
+ def BlockGmlpLayer(
30
+ block_size,
31
+ use_bias: bool = True,
32
+ factor: int = 2,
33
+ dropout_rate: float = 0.0,
34
+ name: str = "block_gmlp",
35
+ ):
36
+ """Block gMLP layer that performs local mixing of tokens."""
37
+
38
+ def apply(x):
39
+ n, h, w, num_channels = (
40
+ K.int_shape(x)[0],
41
+ K.int_shape(x)[1],
42
+ K.int_shape(x)[2],
43
+ K.int_shape(x)[3],
44
+ )
45
+ fh, fw = block_size
46
+ gh, gw = h // fh, w // fw
47
+ x = BlockImages()(x, patch_size=(fh, fw))
48
+ # MLP2: Local (block) mixing part, provides within-block communication.
49
+ y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
50
+ y = layers.Dense(
51
+ num_channels * factor,
52
+ use_bias=use_bias,
53
+ name=f"{name}_in_project",
54
+ )(y)
55
+ y = tf.nn.gelu(y, approximate=True)
56
+ y = BlockGatingUnit(use_bias=use_bias, name=f"{name}_BlockGatingUnit")(y)
57
+ y = layers.Dense(
58
+ num_channels,
59
+ use_bias=use_bias,
60
+ name=f"{name}_out_project",
61
+ )(y)
62
+ y = layers.Dropout(dropout_rate)(y)
63
+ x = x + y
64
+ x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
65
+ return x
66
+
67
+ return apply
maxim/blocks/bottleneck.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ from tensorflow.keras import layers
4
+
5
+ from .attentions import RDCAB
6
+ from .misc_gating import ResidualSplitHeadMultiAxisGmlpLayer
7
+
8
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
9
+
10
+
11
+ def BottleneckBlock(
12
+ features: int,
13
+ block_size,
14
+ grid_size,
15
+ num_groups: int = 1,
16
+ block_gmlp_factor: int = 2,
17
+ grid_gmlp_factor: int = 2,
18
+ input_proj_factor: int = 2,
19
+ channels_reduction: int = 4,
20
+ dropout_rate: float = 0.0,
21
+ use_bias: bool = True,
22
+ name: str = "bottleneck_block",
23
+ ):
24
+ """The bottleneck block consisting of multi-axis gMLP block and RDCAB."""
25
+
26
+ def apply(x):
27
+ # input projection
28
+ x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_input_proj")(x)
29
+ shortcut_long = x
30
+
31
+ for i in range(num_groups):
32
+ x = ResidualSplitHeadMultiAxisGmlpLayer(
33
+ grid_size=grid_size,
34
+ block_size=block_size,
35
+ grid_gmlp_factor=grid_gmlp_factor,
36
+ block_gmlp_factor=block_gmlp_factor,
37
+ input_proj_factor=input_proj_factor,
38
+ use_bias=use_bias,
39
+ dropout_rate=dropout_rate,
40
+ name=f"{name}_SplitHeadMultiAxisGmlpLayer_{i}",
41
+ )(x)
42
+ # Channel-mixing part, which provides within-patch communication.
43
+ x = RDCAB(
44
+ num_channels=features,
45
+ reduction=channels_reduction,
46
+ use_bias=use_bias,
47
+ name=f"{name}_channel_attention_block_1_{i}",
48
+ )(x)
49
+
50
+ # long skip-connect
51
+ x = x + shortcut_long
52
+ return x
53
+
54
+ return apply
maxim/blocks/grid_gating.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras import backend as K
3
+ from tensorflow.keras import layers
4
+
5
+ from ..layers import BlockImages, SwapAxes, UnblockImages
6
+
7
+
8
+ def GridGatingUnit(use_bias: bool = True, name: str = "grid_gating_unit"):
9
+ """A SpatialGatingUnit as defined in the gMLP paper.
10
+
11
+ The 'spatial' dim is defined as the second last.
12
+ If applied on other dims, you should swapaxes first.
13
+ """
14
+
15
+ def apply(x):
16
+ u, v = tf.split(x, 2, axis=-1)
17
+ v = layers.LayerNormalization(
18
+ epsilon=1e-06, name=f"{name}_intermediate_layernorm"
19
+ )(v)
20
+ n = K.int_shape(x)[-3] # get spatial dim
21
+ v = SwapAxes()(v, -1, -3)
22
+ v = layers.Dense(n, use_bias=use_bias, name=f"{name}_Dense_0")(v)
23
+ v = SwapAxes()(v, -1, -3)
24
+ return u * (v + 1.0)
25
+
26
+ return apply
27
+
28
+
29
+ def GridGmlpLayer(
30
+ grid_size,
31
+ use_bias: bool = True,
32
+ factor: int = 2,
33
+ dropout_rate: float = 0.0,
34
+ name: str = "grid_gmlp",
35
+ ):
36
+ """Grid gMLP layer that performs global mixing of tokens."""
37
+
38
+ def apply(x):
39
+ n, h, w, num_channels = (
40
+ K.int_shape(x)[0],
41
+ K.int_shape(x)[1],
42
+ K.int_shape(x)[2],
43
+ K.int_shape(x)[3],
44
+ )
45
+ gh, gw = grid_size
46
+ fh, fw = h // gh, w // gw
47
+
48
+ x = BlockImages()(x, patch_size=(fh, fw))
49
+ # gMLP1: Global (grid) mixing part, provides global grid communication.
50
+ y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
51
+ y = layers.Dense(
52
+ num_channels * factor,
53
+ use_bias=use_bias,
54
+ name=f"{name}_in_project",
55
+ )(y)
56
+ y = tf.nn.gelu(y, approximate=True)
57
+ y = GridGatingUnit(use_bias=use_bias, name=f"{name}_GridGatingUnit")(y)
58
+ y = layers.Dense(
59
+ num_channels,
60
+ use_bias=use_bias,
61
+ name=f"{name}_out_project",
62
+ )(y)
63
+ y = layers.Dropout(dropout_rate)(y)
64
+ x = x + y
65
+ x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
66
+ return x
67
+
68
+ return apply
maxim/blocks/misc_gating.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import backend as K
5
+ from tensorflow.keras import layers
6
+
7
+ from ..layers import BlockImages, SwapAxes, UnblockImages
8
+ from .block_gating import BlockGmlpLayer
9
+ from .grid_gating import GridGmlpLayer
10
+
11
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
12
+ Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
13
+ ConvT_up = functools.partial(
14
+ layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
15
+ )
16
+ Conv_down = functools.partial(
17
+ layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
18
+ )
19
+
20
+
21
+ def ResidualSplitHeadMultiAxisGmlpLayer(
22
+ block_size,
23
+ grid_size,
24
+ block_gmlp_factor: int = 2,
25
+ grid_gmlp_factor: int = 2,
26
+ input_proj_factor: int = 2,
27
+ use_bias: bool = True,
28
+ dropout_rate: float = 0.0,
29
+ name: str = "residual_split_head_maxim",
30
+ ):
31
+ """The multi-axis gated MLP block."""
32
+
33
+ def apply(x):
34
+ shortcut = x
35
+ n, h, w, num_channels = (
36
+ K.int_shape(x)[0],
37
+ K.int_shape(x)[1],
38
+ K.int_shape(x)[2],
39
+ K.int_shape(x)[3],
40
+ )
41
+ x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
42
+
43
+ x = layers.Dense(
44
+ int(num_channels) * input_proj_factor,
45
+ use_bias=use_bias,
46
+ name=f"{name}_in_project",
47
+ )(x)
48
+ x = tf.nn.gelu(x, approximate=True)
49
+
50
+ u, v = tf.split(x, 2, axis=-1)
51
+
52
+ # GridGMLPLayer
53
+ u = GridGmlpLayer(
54
+ grid_size=grid_size,
55
+ factor=grid_gmlp_factor,
56
+ use_bias=use_bias,
57
+ dropout_rate=dropout_rate,
58
+ name=f"{name}_GridGmlpLayer",
59
+ )(u)
60
+
61
+ # BlockGMLPLayer
62
+ v = BlockGmlpLayer(
63
+ block_size=block_size,
64
+ factor=block_gmlp_factor,
65
+ use_bias=use_bias,
66
+ dropout_rate=dropout_rate,
67
+ name=f"{name}_BlockGmlpLayer",
68
+ )(v)
69
+
70
+ x = tf.concat([u, v], axis=-1)
71
+
72
+ x = layers.Dense(
73
+ num_channels,
74
+ use_bias=use_bias,
75
+ name=f"{name}_out_project",
76
+ )(x)
77
+ x = layers.Dropout(dropout_rate)(x)
78
+ x = x + shortcut
79
+ return x
80
+
81
+ return apply
82
+
83
+
84
+ def GetSpatialGatingWeights(
85
+ features: int,
86
+ block_size,
87
+ grid_size,
88
+ input_proj_factor: int = 2,
89
+ dropout_rate: float = 0.0,
90
+ use_bias: bool = True,
91
+ name: str = "spatial_gating",
92
+ ):
93
+
94
+ """Get gating weights for cross-gating MLP block."""
95
+
96
+ def apply(x):
97
+ n, h, w, num_channels = (
98
+ K.int_shape(x)[0],
99
+ K.int_shape(x)[1],
100
+ K.int_shape(x)[2],
101
+ K.int_shape(x)[3],
102
+ )
103
+
104
+ # input projection
105
+ x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
106
+ x = layers.Dense(
107
+ num_channels * input_proj_factor,
108
+ use_bias=use_bias,
109
+ name=f"{name}_in_project",
110
+ )(x)
111
+ x = tf.nn.gelu(x, approximate=True)
112
+ u, v = tf.split(x, 2, axis=-1)
113
+
114
+ # Get grid MLP weights
115
+ gh, gw = grid_size
116
+ fh, fw = h // gh, w // gw
117
+ u = BlockImages()(u, patch_size=(fh, fw))
118
+ dim_u = K.int_shape(u)[-3]
119
+ u = SwapAxes()(u, -1, -3)
120
+ u = layers.Dense(dim_u, use_bias=use_bias, name=f"{name}_Dense_0")(u)
121
+ u = SwapAxes()(u, -1, -3)
122
+ u = UnblockImages()(u, grid_size=(gh, gw), patch_size=(fh, fw))
123
+
124
+ # Get Block MLP weights
125
+ fh, fw = block_size
126
+ gh, gw = h // fh, w // fw
127
+ v = BlockImages()(v, patch_size=(fh, fw))
128
+ dim_v = K.int_shape(v)[-2]
129
+ v = SwapAxes()(v, -1, -2)
130
+ v = layers.Dense(dim_v, use_bias=use_bias, name=f"{name}_Dense_1")(v)
131
+ v = SwapAxes()(v, -1, -2)
132
+ v = UnblockImages()(v, grid_size=(gh, gw), patch_size=(fh, fw))
133
+
134
+ x = tf.concat([u, v], axis=-1)
135
+ x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(x)
136
+ x = layers.Dropout(dropout_rate)(x)
137
+ return x
138
+
139
+ return apply
140
+
141
+
142
+ def CrossGatingBlock(
143
+ features: int,
144
+ block_size,
145
+ grid_size,
146
+ dropout_rate: float = 0.0,
147
+ input_proj_factor: int = 2,
148
+ upsample_y: bool = True,
149
+ use_bias: bool = True,
150
+ name: str = "cross_gating",
151
+ ):
152
+
153
+ """Cross-gating MLP block."""
154
+
155
+ def apply(x, y):
156
+ # Upscale Y signal, y is the gating signal.
157
+ if upsample_y:
158
+ y = ConvT_up(
159
+ filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
160
+ )(y)
161
+
162
+ x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_Conv_0")(x)
163
+ n, h, w, num_channels = (
164
+ K.int_shape(x)[0],
165
+ K.int_shape(x)[1],
166
+ K.int_shape(x)[2],
167
+ K.int_shape(x)[3],
168
+ )
169
+
170
+ y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
171
+
172
+ shortcut_x = x
173
+ shortcut_y = y
174
+
175
+ # Get gating weights from X
176
+ x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_x")(x)
177
+ x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(x)
178
+ x = tf.nn.gelu(x, approximate=True)
179
+ gx = GetSpatialGatingWeights(
180
+ features=num_channels,
181
+ block_size=block_size,
182
+ grid_size=grid_size,
183
+ dropout_rate=dropout_rate,
184
+ use_bias=use_bias,
185
+ name=f"{name}_SplitHeadMultiAxisGating_x",
186
+ )(x)
187
+
188
+ # Get gating weights from Y
189
+ y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_y")(y)
190
+ y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(y)
191
+ y = tf.nn.gelu(y, approximate=True)
192
+ gy = GetSpatialGatingWeights(
193
+ features=num_channels,
194
+ block_size=block_size,
195
+ grid_size=grid_size,
196
+ dropout_rate=dropout_rate,
197
+ use_bias=use_bias,
198
+ name=f"{name}_SplitHeadMultiAxisGating_y",
199
+ )(y)
200
+
201
+ # Apply cross gating: X = X * GY, Y = Y * GX
202
+ y = y * gx
203
+ y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(y)
204
+ y = layers.Dropout(dropout_rate)(y)
205
+ y = y + shortcut_y
206
+
207
+ x = x * gy # gating x using y
208
+ x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(x)
209
+ x = layers.Dropout(dropout_rate)(x)
210
+ x = x + y + shortcut_x # get all aggregated signals
211
+ return x, y
212
+
213
+ return apply
maxim/blocks/others.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import backend as K
5
+ from tensorflow.keras import layers
6
+
7
+ from ..layers import Resizing
8
+
9
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
10
+
11
+
12
+ def MlpBlock(
13
+ mlp_dim: int,
14
+ dropout_rate: float = 0.0,
15
+ use_bias: bool = True,
16
+ name: str = "mlp_block",
17
+ ):
18
+ """A 1-hidden-layer MLP block, applied over the last dimension."""
19
+
20
+ def apply(x):
21
+ d = K.int_shape(x)[-1]
22
+ x = layers.Dense(mlp_dim, use_bias=use_bias, name=f"{name}_Dense_0")(x)
23
+ x = tf.nn.gelu(x, approximate=True)
24
+ x = layers.Dropout(dropout_rate)(x)
25
+ x = layers.Dense(d, use_bias=use_bias, name=f"{name}_Dense_1")(x)
26
+ return x
27
+
28
+ return apply
29
+
30
+
31
+ def UpSampleRatio(
32
+ num_channels: int, ratio: float, use_bias: bool = True, name: str = "upsample"
33
+ ):
34
+ """Upsample features given a ratio > 0."""
35
+
36
+ def apply(x):
37
+ n, h, w, c = (
38
+ K.int_shape(x)[0],
39
+ K.int_shape(x)[1],
40
+ K.int_shape(x)[2],
41
+ K.int_shape(x)[3],
42
+ )
43
+
44
+ # Following `jax.image.resize()`
45
+ x = Resizing(
46
+ height=int(h * ratio),
47
+ width=int(w * ratio),
48
+ method="bilinear",
49
+ antialias=True,
50
+ name=f"{name}_resizing_{K.get_uid('Resizing')}",
51
+ )(x)
52
+
53
+ x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
54
+ return x
55
+
56
+ return apply
maxim/blocks/unet.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import layers
5
+
6
+ from .attentions import RCAB
7
+ from .misc_gating import CrossGatingBlock, ResidualSplitHeadMultiAxisGmlpLayer
8
+
9
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
10
+ Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
11
+ ConvT_up = functools.partial(
12
+ layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
13
+ )
14
+ Conv_down = functools.partial(
15
+ layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
16
+ )
17
+
18
+
19
+ def UNetEncoderBlock(
20
+ num_channels: int,
21
+ block_size,
22
+ grid_size,
23
+ num_groups: int = 1,
24
+ lrelu_slope: float = 0.2,
25
+ block_gmlp_factor: int = 2,
26
+ grid_gmlp_factor: int = 2,
27
+ input_proj_factor: int = 2,
28
+ channels_reduction: int = 4,
29
+ dropout_rate: float = 0.0,
30
+ downsample: bool = True,
31
+ use_global_mlp: bool = True,
32
+ use_bias: bool = True,
33
+ use_cross_gating: bool = False,
34
+ name: str = "unet_encoder",
35
+ ):
36
+ """Encoder block in MAXIM."""
37
+
38
+ def apply(x, skip=None, enc=None, dec=None):
39
+ if skip is not None:
40
+ x = tf.concat([x, skip], axis=-1)
41
+
42
+ # convolution-in
43
+ x = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)
44
+ shortcut_long = x
45
+
46
+ for i in range(num_groups):
47
+ if use_global_mlp:
48
+ x = ResidualSplitHeadMultiAxisGmlpLayer(
49
+ grid_size=grid_size,
50
+ block_size=block_size,
51
+ grid_gmlp_factor=grid_gmlp_factor,
52
+ block_gmlp_factor=block_gmlp_factor,
53
+ input_proj_factor=input_proj_factor,
54
+ use_bias=use_bias,
55
+ dropout_rate=dropout_rate,
56
+ name=f"{name}_SplitHeadMultiAxisGmlpLayer_{i}",
57
+ )(x)
58
+ x = RCAB(
59
+ num_channels=num_channels,
60
+ reduction=channels_reduction,
61
+ lrelu_slope=lrelu_slope,
62
+ use_bias=use_bias,
63
+ name=f"{name}_channel_attention_block_1{i}",
64
+ )(x)
65
+
66
+ x = x + shortcut_long
67
+
68
+ if enc is not None and dec is not None:
69
+ assert use_cross_gating
70
+ x, _ = CrossGatingBlock(
71
+ features=num_channels,
72
+ block_size=block_size,
73
+ grid_size=grid_size,
74
+ dropout_rate=dropout_rate,
75
+ input_proj_factor=input_proj_factor,
76
+ upsample_y=False,
77
+ use_bias=use_bias,
78
+ name=f"{name}_cross_gating_block",
79
+ )(x, enc + dec)
80
+
81
+ if downsample:
82
+ x_down = Conv_down(
83
+ filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1"
84
+ )(x)
85
+ return x_down, x
86
+ else:
87
+ return x
88
+
89
+ return apply
90
+
91
+
92
+ def UNetDecoderBlock(
93
+ num_channels: int,
94
+ block_size,
95
+ grid_size,
96
+ num_groups: int = 1,
97
+ lrelu_slope: float = 0.2,
98
+ block_gmlp_factor: int = 2,
99
+ grid_gmlp_factor: int = 2,
100
+ input_proj_factor: int = 2,
101
+ channels_reduction: int = 4,
102
+ dropout_rate: float = 0.0,
103
+ downsample: bool = True,
104
+ use_global_mlp: bool = True,
105
+ use_bias: bool = True,
106
+ name: str = "unet_decoder",
107
+ ):
108
+
109
+ """Decoder block in MAXIM."""
110
+
111
+ def apply(x, bridge=None):
112
+ x = ConvT_up(
113
+ filters=num_channels, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
114
+ )(x)
115
+ x = UNetEncoderBlock(
116
+ num_channels=num_channels,
117
+ num_groups=num_groups,
118
+ lrelu_slope=lrelu_slope,
119
+ block_size=block_size,
120
+ grid_size=grid_size,
121
+ block_gmlp_factor=block_gmlp_factor,
122
+ grid_gmlp_factor=grid_gmlp_factor,
123
+ channels_reduction=channels_reduction,
124
+ use_global_mlp=use_global_mlp,
125
+ dropout_rate=dropout_rate,
126
+ downsample=False,
127
+ use_bias=use_bias,
128
+ name=f"{name}_UNetEncoderBlock_0",
129
+ )(x, skip=bridge)
130
+
131
+ return x
132
+
133
+ return apply
maxim/configs.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MAXIM_CONFIGS = {
2
+ # params: 6.108515000000001 M, GFLOPS: 93.163716608
3
+ "S-1": {
4
+ "features": 32,
5
+ "depth": 3,
6
+ "num_stages": 1,
7
+ "num_groups": 2,
8
+ "num_bottleneck_blocks": 2,
9
+ "block_gmlp_factor": 2,
10
+ "grid_gmlp_factor": 2,
11
+ "input_proj_factor": 2,
12
+ "channels_reduction": 4,
13
+ "name": "s1",
14
+ },
15
+ # params: 13.35383 M, GFLOPS: 206.743273472
16
+ "S-2": {
17
+ "features": 32,
18
+ "depth": 3,
19
+ "num_stages": 2,
20
+ "num_groups": 2,
21
+ "num_bottleneck_blocks": 2,
22
+ "block_gmlp_factor": 2,
23
+ "grid_gmlp_factor": 2,
24
+ "input_proj_factor": 2,
25
+ "channels_reduction": 4,
26
+ "name": "s2",
27
+ },
28
+ # params: 20.599145 M, GFLOPS: 320.32194560000005
29
+ "S-3": {
30
+ "features": 32,
31
+ "depth": 3,
32
+ "num_stages": 3,
33
+ "num_groups": 2,
34
+ "num_bottleneck_blocks": 2,
35
+ "block_gmlp_factor": 2,
36
+ "grid_gmlp_factor": 2,
37
+ "input_proj_factor": 2,
38
+ "channels_reduction": 4,
39
+ "name": "s3",
40
+ },
41
+ # params: 19.361219000000002 M, 308.495712256 GFLOPs
42
+ "M-1": {
43
+ "features": 64,
44
+ "depth": 3,
45
+ "num_stages": 1,
46
+ "num_groups": 2,
47
+ "num_bottleneck_blocks": 2,
48
+ "block_gmlp_factor": 2,
49
+ "grid_gmlp_factor": 2,
50
+ "input_proj_factor": 2,
51
+ "channels_reduction": 4,
52
+ "name": "m1",
53
+ },
54
+ # params: 40.83911 M, 675.25541888 GFLOPs
55
+ "M-2": {
56
+ "features": 64,
57
+ "depth": 3,
58
+ "num_stages": 2,
59
+ "num_groups": 2,
60
+ "num_bottleneck_blocks": 2,
61
+ "block_gmlp_factor": 2,
62
+ "grid_gmlp_factor": 2,
63
+ "input_proj_factor": 2,
64
+ "channels_reduction": 4,
65
+ "name": "m2",
66
+ },
67
+ # params: 62.317001 M, 1042.014666752 GFLOPs
68
+ "M-3": {
69
+ "features": 64,
70
+ "depth": 3,
71
+ "num_stages": 3,
72
+ "num_groups": 2,
73
+ "num_bottleneck_blocks": 2,
74
+ "block_gmlp_factor": 2,
75
+ "grid_gmlp_factor": 2,
76
+ "input_proj_factor": 2,
77
+ "channels_reduction": 4,
78
+ "name": "m3",
79
+ },
80
+ }
maxim/layers.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import einops
2
+ import tensorflow as tf
3
+ from tensorflow.experimental import numpy as tnp
4
+ from tensorflow.keras import backend as K
5
+ from tensorflow.keras import layers
6
+
7
+
8
+ @tf.keras.utils.register_keras_serializable("maxim")
9
+ class BlockImages(layers.Layer):
10
+ def __init__(self, **kwargs):
11
+ super().__init__(**kwargs)
12
+
13
+ def call(self, x, patch_size):
14
+ bs, h, w, num_channels = (
15
+ K.int_shape(x)[0],
16
+ K.int_shape(x)[1],
17
+ K.int_shape(x)[2],
18
+ K.int_shape(x)[3],
19
+ )
20
+
21
+ grid_height, grid_width = h // patch_size[0], w // patch_size[1]
22
+
23
+ x = einops.rearrange(
24
+ x,
25
+ "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
26
+ gh=grid_height,
27
+ gw=grid_width,
28
+ fh=patch_size[0],
29
+ fw=patch_size[1],
30
+ )
31
+
32
+ return x
33
+
34
+ def get_config(self):
35
+ config = super().get_config().copy()
36
+ return config
37
+
38
+
39
+ @tf.keras.utils.register_keras_serializable("maxim")
40
+ class UnblockImages(layers.Layer):
41
+ def __init__(self, **kwargs):
42
+ super().__init__(**kwargs)
43
+
44
+ def call(self, x, grid_size, patch_size):
45
+ x = einops.rearrange(
46
+ x,
47
+ "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c",
48
+ gh=grid_size[0],
49
+ gw=grid_size[1],
50
+ fh=patch_size[0],
51
+ fw=patch_size[1],
52
+ )
53
+
54
+ return x
55
+
56
+ def get_config(self):
57
+ config = super().get_config().copy()
58
+ return config
59
+
60
+
61
+ @tf.keras.utils.register_keras_serializable("maxim")
62
+ class SwapAxes(layers.Layer):
63
+ def __init__(self, **kwargs):
64
+ super().__init__(**kwargs)
65
+
66
+ def call(self, x, axis_one, axis_two):
67
+ return tnp.swapaxes(x, axis_one, axis_two)
68
+
69
+ def get_config(self):
70
+ config = super().get_config().copy()
71
+ return config
72
+
73
+
74
+ @tf.keras.utils.register_keras_serializable("maxim")
75
+ class Resizing(layers.Layer):
76
+ def __init__(self, height, width, antialias=True, method="bilinear", **kwargs):
77
+ super().__init__(**kwargs)
78
+ self.height = height
79
+ self.width = width
80
+ self.antialias = antialias
81
+ self.method = method
82
+
83
+ def call(self, x):
84
+ return tf.image.resize(
85
+ x,
86
+ size=(self.height, self.width),
87
+ antialias=self.antialias,
88
+ method=self.method,
89
+ )
90
+
91
+ def get_config(self):
92
+ config = super().get_config().copy()
93
+ config.update(
94
+ {
95
+ "height": self.height,
96
+ "width": self.width,
97
+ "antialias": self.antialias,
98
+ "method": self.method,
99
+ }
100
+ )
101
+ return config
maxim/maxim.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import tensorflow as tf
4
+ from tensorflow.keras import backend as K
5
+ from tensorflow.keras import layers
6
+
7
+ from .blocks.attentions import SAM
8
+ from .blocks.bottleneck import BottleneckBlock
9
+ from .blocks.misc_gating import CrossGatingBlock
10
+ from .blocks.others import UpSampleRatio
11
+ from .blocks.unet import UNetDecoderBlock, UNetEncoderBlock
12
+ from .layers import Resizing
13
+
14
+ Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
15
+ Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
16
+ ConvT_up = functools.partial(
17
+ layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
18
+ )
19
+ Conv_down = functools.partial(
20
+ layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
21
+ )
22
+
23
+
24
+ def MAXIM(
25
+ features: int = 64,
26
+ depth: int = 3,
27
+ num_stages: int = 2,
28
+ num_groups: int = 1,
29
+ use_bias: bool = True,
30
+ num_supervision_scales: int = 1,
31
+ lrelu_slope: float = 0.2,
32
+ use_global_mlp: bool = True,
33
+ use_cross_gating: bool = True,
34
+ high_res_stages: int = 2,
35
+ block_size_hr=(16, 16),
36
+ block_size_lr=(8, 8),
37
+ grid_size_hr=(16, 16),
38
+ grid_size_lr=(8, 8),
39
+ num_bottleneck_blocks: int = 1,
40
+ block_gmlp_factor: int = 2,
41
+ grid_gmlp_factor: int = 2,
42
+ input_proj_factor: int = 2,
43
+ channels_reduction: int = 4,
44
+ num_outputs: int = 3,
45
+ dropout_rate: float = 0.0,
46
+ ):
47
+ """The MAXIM model function with multi-stage and multi-scale supervision.
48
+
49
+ For more model details, please check the CVPR paper:
50
+ MAXIM: MUlti-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)
51
+
52
+ Attributes:
53
+ features: initial hidden dimension for the input resolution.
54
+ depth: the number of downsampling depth for the model.
55
+ num_stages: how many stages to use. It will also affects the output list.
56
+ num_groups: how many blocks each stage contains.
57
+ use_bias: whether to use bias in all the conv/mlp layers.
58
+ num_supervision_scales: the number of desired supervision scales.
59
+ lrelu_slope: the negative slope parameter in leaky_relu layers.
60
+ use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in each
61
+ layer.
62
+ use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
63
+ skip connections and multi-stage feature fusion layers.
64
+ high_res_stages: how many stages are specificied as high-res stages. The
65
+ rest (depth - high_res_stages) are called low_res_stages.
66
+ block_size_hr: the block_size parameter for high-res stages.
67
+ block_size_lr: the block_size parameter for low-res stages.
68
+ grid_size_hr: the grid_size parameter for high-res stages.
69
+ grid_size_lr: the grid_size parameter for low-res stages.
70
+ num_bottleneck_blocks: how many bottleneck blocks.
71
+ block_gmlp_factor: the input projection factor for block_gMLP layers.
72
+ grid_gmlp_factor: the input projection factor for grid_gMLP layers.
73
+ input_proj_factor: the input projection factor for the MAB block.
74
+ channels_reduction: the channel reduction factor for SE layer.
75
+ num_outputs: the output channels.
76
+ dropout_rate: Dropout rate.
77
+
78
+ Returns:
79
+ The output contains a list of arrays consisting of multi-stage multi-scale
80
+ outputs. For example, if num_stages = num_supervision_scales = 3 (the
81
+ model used in the paper), the output specs are: outputs =
82
+ [[output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],
83
+ [output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],
84
+ [output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],]
85
+ The final output can be retrieved by outputs[-1][-1].
86
+ """
87
+
88
+ def apply(x):
89
+ n, h, w, c = (
90
+ K.int_shape(x)[0],
91
+ K.int_shape(x)[1],
92
+ K.int_shape(x)[2],
93
+ K.int_shape(x)[3],
94
+ ) # input image shape
95
+
96
+ shortcuts = []
97
+ shortcuts.append(x)
98
+
99
+ # Get multi-scale input images
100
+ for i in range(1, num_supervision_scales):
101
+ resizing_layer = Resizing(
102
+ height=h // (2 ** i),
103
+ width=w // (2 ** i),
104
+ method="nearest",
105
+ antialias=True, # Following `jax.image.resize()`.
106
+ name=f"initial_resizing_{K.get_uid('Resizing')}",
107
+ )
108
+ shortcuts.append(resizing_layer(x))
109
+
110
+ # store outputs from all stages and all scales
111
+ # Eg, [[(64, 64, 3), (128, 128, 3), (256, 256, 3)], # Stage-1 outputs
112
+ # [(64, 64, 3), (128, 128, 3), (256, 256, 3)],] # Stage-2 outputs
113
+ outputs_all = []
114
+ sam_features, encs_prev, decs_prev = [], [], []
115
+
116
+ for idx_stage in range(num_stages):
117
+ # Input convolution, get multi-scale input features
118
+ x_scales = []
119
+ for i in range(num_supervision_scales):
120
+ x_scale = Conv3x3(
121
+ filters=(2 ** i) * features,
122
+ use_bias=use_bias,
123
+ name=f"stage_{idx_stage}_input_conv_{i}",
124
+ )(shortcuts[i])
125
+
126
+ # If later stages, fuse input features with SAM features from prev stage
127
+ if idx_stage > 0:
128
+ # use larger blocksize at high-res stages
129
+ if use_cross_gating:
130
+ block_size = (
131
+ block_size_hr if i < high_res_stages else block_size_lr
132
+ )
133
+ grid_size = grid_size_hr if i < high_res_stages else block_size_lr
134
+ x_scale, _ = CrossGatingBlock(
135
+ features=(2 ** i) * features,
136
+ block_size=block_size,
137
+ grid_size=grid_size,
138
+ dropout_rate=dropout_rate,
139
+ input_proj_factor=input_proj_factor,
140
+ upsample_y=False,
141
+ use_bias=use_bias,
142
+ name=f"stage_{idx_stage}_input_fuse_sam_{i}",
143
+ )(x_scale, sam_features.pop())
144
+ else:
145
+ x_scale = Conv1x1(
146
+ filters=(2 ** i) * features,
147
+ use_bias=use_bias,
148
+ name=f"stage_{idx_stage}_input_catconv_{i}",
149
+ )(tf.concat([x_scale, sam_features.pop()], axis=-1))
150
+
151
+ x_scales.append(x_scale)
152
+
153
+ # start encoder blocks
154
+ encs = []
155
+ x = x_scales[0] # First full-scale input feature
156
+
157
+ for i in range(depth): # 0, 1, 2
158
+ # use larger blocksize at high-res stages, vice versa.
159
+ block_size = block_size_hr if i < high_res_stages else block_size_lr
160
+ grid_size = grid_size_hr if i < high_res_stages else block_size_lr
161
+ use_cross_gating_layer = True if idx_stage > 0 else False
162
+
163
+ # Multi-scale input if multi-scale supervision
164
+ x_scale = x_scales[i] if i < num_supervision_scales else None
165
+
166
+ # UNet Encoder block
167
+ enc_prev = encs_prev.pop() if idx_stage > 0 else None
168
+ dec_prev = decs_prev.pop() if idx_stage > 0 else None
169
+
170
+ x, bridge = UNetEncoderBlock(
171
+ num_channels=(2 ** i) * features,
172
+ num_groups=num_groups,
173
+ downsample=True,
174
+ lrelu_slope=lrelu_slope,
175
+ block_size=block_size,
176
+ grid_size=grid_size,
177
+ block_gmlp_factor=block_gmlp_factor,
178
+ grid_gmlp_factor=grid_gmlp_factor,
179
+ input_proj_factor=input_proj_factor,
180
+ channels_reduction=channels_reduction,
181
+ use_global_mlp=use_global_mlp,
182
+ dropout_rate=dropout_rate,
183
+ use_bias=use_bias,
184
+ use_cross_gating=use_cross_gating_layer,
185
+ name=f"stage_{idx_stage}_encoder_block_{i}",
186
+ )(x, skip=x_scale, enc=enc_prev, dec=dec_prev)
187
+
188
+ # Cache skip signals
189
+ encs.append(bridge)
190
+
191
+ # Global MLP bottleneck blocks
192
+ for i in range(num_bottleneck_blocks):
193
+ x = BottleneckBlock(
194
+ block_size=block_size_lr,
195
+ grid_size=block_size_lr,
196
+ features=(2 ** (depth - 1)) * features,
197
+ num_groups=num_groups,
198
+ block_gmlp_factor=block_gmlp_factor,
199
+ grid_gmlp_factor=grid_gmlp_factor,
200
+ input_proj_factor=input_proj_factor,
201
+ dropout_rate=dropout_rate,
202
+ use_bias=use_bias,
203
+ channels_reduction=channels_reduction,
204
+ name=f"stage_{idx_stage}_global_block_{i}",
205
+ )(x)
206
+ # cache global feature for cross-gating
207
+ global_feature = x
208
+
209
+ # start cross gating. Use multi-scale feature fusion
210
+ skip_features = []
211
+ for i in reversed(range(depth)): # 2, 1, 0
212
+ # use larger blocksize at high-res stages
213
+ block_size = block_size_hr if i < high_res_stages else block_size_lr
214
+ grid_size = grid_size_hr if i < high_res_stages else block_size_lr
215
+
216
+ # get additional multi-scale signals
217
+ signal = tf.concat(
218
+ [
219
+ UpSampleRatio(
220
+ num_channels=(2 ** i) * features,
221
+ ratio=2 ** (j - i),
222
+ use_bias=use_bias,
223
+ name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
224
+ )(enc)
225
+ for j, enc in enumerate(encs)
226
+ ],
227
+ axis=-1,
228
+ )
229
+
230
+ # Use cross-gating to cross modulate features
231
+ if use_cross_gating:
232
+ skips, global_feature = CrossGatingBlock(
233
+ features=(2 ** i) * features,
234
+ block_size=block_size,
235
+ grid_size=grid_size,
236
+ input_proj_factor=input_proj_factor,
237
+ dropout_rate=dropout_rate,
238
+ upsample_y=True,
239
+ use_bias=use_bias,
240
+ name=f"stage_{idx_stage}_cross_gating_block_{i}",
241
+ )(signal, global_feature)
242
+ else:
243
+ skips = Conv1x1(
244
+ filters=(2 ** i) * features, use_bias=use_bias, name="Conv_0"
245
+ )(signal)
246
+ skips = Conv3x3(
247
+ filters=(2 ** i) * features, use_bias=use_bias, name="Conv_1"
248
+ )(skips)
249
+
250
+ skip_features.append(skips)
251
+
252
+ # start decoder. Multi-scale feature fusion of cross-gated features
253
+ outputs, decs, sam_features = [], [], []
254
+ for i in reversed(range(depth)):
255
+ # use larger blocksize at high-res stages
256
+ block_size = block_size_hr if i < high_res_stages else block_size_lr
257
+ grid_size = grid_size_hr if i < high_res_stages else block_size_lr
258
+
259
+ # get multi-scale skip signals from cross-gating block
260
+ signal = tf.concat(
261
+ [
262
+ UpSampleRatio(
263
+ num_channels=(2 ** i) * features,
264
+ ratio=2 ** (depth - j - 1 - i),
265
+ use_bias=use_bias,
266
+ name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
267
+ )(skip)
268
+ for j, skip in enumerate(skip_features)
269
+ ],
270
+ axis=-1,
271
+ )
272
+
273
+ # Decoder block
274
+ x = UNetDecoderBlock(
275
+ num_channels=(2 ** i) * features,
276
+ num_groups=num_groups,
277
+ lrelu_slope=lrelu_slope,
278
+ block_size=block_size,
279
+ grid_size=grid_size,
280
+ block_gmlp_factor=block_gmlp_factor,
281
+ grid_gmlp_factor=grid_gmlp_factor,
282
+ input_proj_factor=input_proj_factor,
283
+ channels_reduction=channels_reduction,
284
+ use_global_mlp=use_global_mlp,
285
+ dropout_rate=dropout_rate,
286
+ use_bias=use_bias,
287
+ name=f"stage_{idx_stage}_decoder_block_{i}",
288
+ )(x, bridge=signal)
289
+
290
+ # Cache decoder features for later-stage's usage
291
+ decs.append(x)
292
+
293
+ # output conv, if not final stage, use supervised-attention-block.
294
+ if i < num_supervision_scales:
295
+ if idx_stage < num_stages - 1: # not last stage, apply SAM
296
+ sam, output = SAM(
297
+ num_channels=(2 ** i) * features,
298
+ output_channels=num_outputs,
299
+ use_bias=use_bias,
300
+ name=f"stage_{idx_stage}_supervised_attention_module_{i}",
301
+ )(x, shortcuts[i])
302
+ outputs.append(output)
303
+ sam_features.append(sam)
304
+ else: # Last stage, apply output convolutions
305
+ output = Conv3x3(
306
+ num_outputs,
307
+ use_bias=use_bias,
308
+ name=f"stage_{idx_stage}_output_conv_{i}",
309
+ )(x)
310
+ output = output + shortcuts[i]
311
+ outputs.append(output)
312
+ # Cache encoder and decoder features for later-stage's usage
313
+ encs_prev = encs[::-1]
314
+ decs_prev = decs
315
+
316
+ # Store outputs
317
+ outputs_all.append(outputs)
318
+ return outputs_all
319
+
320
+ return apply