Spaces:
Runtime error
Runtime error
File size: 1,662 Bytes
3f9d71f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import functools
from tensorflow.keras import layers
from .attentions import RDCAB
from .misc_gating import ResidualSplitHeadMultiAxisGmlpLayer
# Factory for pointwise (1x1) convolutions: `layers.Conv2D` with `padding`
# and `kernel_size` pre-bound. Callers still pass `filters`, `use_bias`,
# `name`, etc.; any keyword given at call time overrides the bound ones
# (standard `functools.partial` semantics).
Conv1x1 = functools.partial(layers.Conv2D, padding="same", kernel_size=(1, 1))
def BottleneckBlock(
    features: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "bottleneck_block",
):
    """Build a bottleneck block of multi-axis gMLP layers and RDCAB blocks.

    The returned callable applies three stages to its input:

    1. A 1x1 convolution (``Conv1x1``) projecting to ``features`` channels.
    2. ``num_groups`` repetitions of a ``ResidualSplitHeadMultiAxisGmlpLayer``
       followed by an ``RDCAB`` channel-attention block.
    3. A long skip connection that adds the projected input (from step 1)
       to the output of the last group. With ``num_groups == 0`` the result
       is therefore the projection added to itself.

    Every layer receives an explicit name prefixed with ``name``.

    Args:
        features: Output channels of the input projection. It is also passed
            as ``num_channels`` to each RDCAB.
        block_size: Forwarded unchanged to each gMLP layer as ``block_size``.
        grid_size: Forwarded unchanged to each gMLP layer as ``grid_size``.
        num_groups: Number of (gMLP layer, RDCAB) pairs to stack.
        block_gmlp_factor: Forwarded to each gMLP layer.
        grid_gmlp_factor: Forwarded to each gMLP layer.
        input_proj_factor: Forwarded to each gMLP layer.
        channels_reduction: Forwarded to each RDCAB as ``reduction``.
        dropout_rate: Forwarded to each gMLP layer.
        use_bias: Forwarded to the projection, every gMLP layer and every
            RDCAB.
        name: Prefix for all layer names created by this block.

    Returns:
        A function ``apply(x)`` that builds the layers on ``x`` (presumably a
        4-D channels-last Keras tensor, given the ``Conv2D`` projection;
        TODO confirm) and returns the block output.
    """
    # Settings shared by every gMLP layer. Only the layer name varies per group.
    gmlp_settings = dict(
        grid_size=grid_size,
        block_size=block_size,
        grid_gmlp_factor=grid_gmlp_factor,
        block_gmlp_factor=block_gmlp_factor,
        input_proj_factor=input_proj_factor,
        use_bias=use_bias,
        dropout_rate=dropout_rate,
    )

    def _mix_group(tensor, group_idx):
        # Spatial mixing via the multi-axis gMLP layer ...
        spatially_mixed = ResidualSplitHeadMultiAxisGmlpLayer(
            name=f"{name}_SplitHeadMultiAxisGmlpLayer_{group_idx}",
            **gmlp_settings,
        )(tensor)
        # ... then channel mixing, which provides within-patch communication.
        channel_block = RDCAB(
            num_channels=features,
            reduction=channels_reduction,
            use_bias=use_bias,
            name=f"{name}_channel_attention_block_1_{group_idx}",
        )
        return channel_block(spatially_mixed)

    def apply(x):
        projected = Conv1x1(
            filters=features, use_bias=use_bias, name=f"{name}_input_proj"
        )(x)
        hidden = projected
        for group_idx in range(num_groups):
            hidden = _mix_group(hidden, group_idx)
        # Long skip connection back to the projected input.
        return hidden + projected

    return apply
|