import functools

import tensorflow as tf
from tensorflow.keras import layers

from .others import MlpBlock

# Partially applied 3x3 and 1x1 "same"-padded convolutions used throughout this module.
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")


def CALayer(
    num_channels: int,
    reduction: int = 4,
    use_bias: bool = True,
    name: str = "channel_attention",
):
    """Squeeze-and-excitation block for channel attention.

    ref: https://arxiv.org/abs/1709.01507
    """
    def apply(x):
        # Squeeze: global average pooling, keeping spatial dims so the attention
        # weights broadcast over (h, w).
        y = layers.GlobalAvgPool2D(keepdims=True)(x)
        # Excitation: channel bottleneck (1x1 convs) followed by a sigmoid gate.
        y = Conv1x1(
            filters=num_channels // reduction, use_bias=use_bias, name=f"{name}_Conv_0"
        )(y)
        y = tf.nn.relu(y)
        y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)
        y = tf.nn.sigmoid(y)
        # Rescale the input channel-wise by the learned attention weights.
        return x * y

    return apply
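# Note: RCAB below preserves spatial size and channel count; the incoming tensor is
# expected to already have `num_channels` channels so the residual addition is valid.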
def RCAB(
    num_channels: int,
    reduction: int = 4,
    lrelu_slope: float = 0.2,
    use_bias: bool = True,
    name: str = "residual_ca",
):
    """Residual channel attention block. Contains LN, Conv, LeakyReLU, Conv, SELayer."""

    def apply(x):
        shortcut = x
        # Pre-normalization followed by two 3x3 convolutions with a LeakyReLU in between.
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv1")(x)
        x = tf.nn.leaky_relu(x, alpha=lrelu_slope)
        x = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_conv2")(x)
        # Channel attention (squeeze-and-excitation) before the residual addition.
        x = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )(x)
        return x + shortcut

    return apply
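# Note: RDCAB relies on MlpBlock from .others for channel mixing; its output width
# is assumed to match `num_channels` so the residual addition stays shape-consistent.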
def RDCAB(
    num_channels: int,
    reduction: int = 16,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "rdcab",
):
    """Residual dense channel attention block. Used in bottlenecks."""

    def apply(x):
        # Pre-normalization, then channel mixing via an MLP block.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm")(x)
        y = MlpBlock(
            mlp_dim=num_channels,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_channel_mixing",
        )(y)
        # Channel attention on the mixed features.
        y = CALayer(
            num_channels=num_channels,
            reduction=reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention",
        )(y)
        # Residual connection back to the block input.
        x = x + y
        return x

    return apply
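# Note: in SAM below, `num_channels` is expected to match the channel count of the
# incoming feature map `x`, since the gated features are added back onto `x`.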
def SAM(
    num_channels: int,
    output_channels: int = 3,
    use_bias: bool = True,
    name: str = "sam",
):
    """Supervised attention module for multi-stage training.

    Introduced by MPRNet [CVPR 2021]: https://github.com/swz30/MPRNet
    """

    def apply(x, x_image):
        """Apply the SAM module to the decoder features and the input image.

        Args:
          x: the output features from the UNet decoder with shape (h, w, c)
          x_image: the input image with shape (h, w, 3)

        Returns:
          A tuple of tensors (x1, image) where x1 is the SAM features passed to the
          next stage, and image is the restored image at the current stage.
        """
        # Project the decoder features.
        x1 = Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_0")(x)

        # Predict the restored image; for RGB outputs, predict a residual over the
        # input image, otherwise predict the image directly.
        if output_channels == 3:
            image = (
                Conv3x3(
                    filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
                )(x)
                + x_image
            )
        else:
            image = Conv3x3(
                filters=output_channels, use_bias=use_bias, name=f"{name}_Conv_1"
            )(x)

        # An attention mask derived from the restored image gates the projected features.
        x2 = tf.nn.sigmoid(
            Conv3x3(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_2")(image)
        )

        x1 = x1 * x2

        # Residual connection back to the original decoder features.
        x1 = x1 + x
        return x1, image

    return apply
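# Usage sketch (illustrative shapes; assumes this module is imported as part of its
# package so that `.others.MlpBlock` resolves):
#
#   x = tf.zeros([1, 64, 64, 32])        # decoder features
#   x_image = tf.zeros([1, 64, 64, 3])   # stage input image
#   y = CALayer(num_channels=32)(x)                       # -> (1, 64, 64, 32)
#   y = RCAB(num_channels=32)(y)                          # -> (1, 64, 64, 32)
#   y = RDCAB(num_channels=32)(y)                         # -> (1, 64, 64, 32)
#   feats, restored = SAM(num_channels=32)(y, x_image)    # -> (1, 64, 64, 32), (1, 64, 64, 3)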