# Provenance (scraped from the Hugging Face file viewer):
#   author: sayakpaul (HF staff)
#   commit: 3f9d71f ("add: files.")
#   size:   1.66 kB
import functools
from tensorflow.keras import layers
from .attentions import RDCAB
from .misc_gating import ResidualSplitHeadMultiAxisGmlpLayer
# Factory for pointwise (1x1) convolutions: calling ``Conv1x1(filters=..., ...)``
# builds a Keras ``Conv2D`` with a 1x1 kernel and "same" padding; any other
# keyword arguments (filters, use_bias, name, ...) are passed straight through.
Conv1x1 = functools.partial(
    layers.Conv2D,
    kernel_size=(1, 1),
    padding="same",
)
def BottleneckBlock(
    features: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "bottleneck_block",
):
    """Return a builder for the bottleneck block (multi-axis gMLP + RDCAB).

    The returned function projects its input with a 1x1 convolution, runs
    ``num_groups`` consecutive (multi-axis gMLP layer, RDCAB) pairs on the
    projection, and adds the projection back as a long skip connection.

    Args:
        features: Filter count of the 1x1 input projection; also used as
            ``num_channels`` for every RDCAB.
        block_size: Forwarded unchanged to each
            ``ResidualSplitHeadMultiAxisGmlpLayer`` (presumably a
            (height, width) pair -- confirm against that layer).
        grid_size: Forwarded unchanged to each
            ``ResidualSplitHeadMultiAxisGmlpLayer`` (presumably a
            (height, width) pair -- confirm against that layer).
        num_groups: How many (gMLP layer, RDCAB) pairs are chained.
        block_gmlp_factor: Forwarded to the gMLP layers.
        grid_gmlp_factor: Forwarded to the gMLP layers.
        input_proj_factor: Forwarded to the gMLP layers.
        channels_reduction: Passed as ``reduction`` to every RDCAB.
        dropout_rate: Forwarded to the gMLP layers.
        use_bias: Passed to the projection, the gMLP layers and the RDCABs.
        name: Prefix used for the names of all layers this block creates.

    Returns:
        A callable ``apply(x)`` that instantiates the layers and returns the
        block output. Layers are created anew on every call with names that
        depend only on ``name`` and the group index, so each builder should
        be applied once per model.
    """

    def apply(x):
        # NOTE(review): x is presumably a 4-D (batch, H, W, C) Keras tensor.
        # The projected tensor feeds the groups and is also the long-skip term.
        projected = Conv1x1(
            filters=features,
            use_bias=use_bias,
            name=f"{name}_input_proj",
        )(x)

        y = projected
        for group_idx in range(num_groups):
            gmlp_layer = ResidualSplitHeadMultiAxisGmlpLayer(
                block_size=block_size,
                grid_size=grid_size,
                block_gmlp_factor=block_gmlp_factor,
                grid_gmlp_factor=grid_gmlp_factor,
                input_proj_factor=input_proj_factor,
                use_bias=use_bias,
                dropout_rate=dropout_rate,
                name=f"{name}_SplitHeadMultiAxisGmlpLayer_{group_idx}",
            )
            y = gmlp_layer(y)

            # Channel mixing (RDCAB), which provides within-patch communication.
            channel_block = RDCAB(
                num_channels=features,
                reduction=channels_reduction,
                use_bias=use_bias,
                name=f"{name}_channel_attention_block_1_{group_idx}",
            )
            y = channel_block(y)

        # Long skip connection spanning all groups, applied once at the end.
        return y + projected

    return apply