# sayakpaul's picture
# sayakpaul HF staff
# add: files.
# 3f9d71f
# raw
# history blame
# 6.53 kB
import functools
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from ..layers import BlockImages, SwapAxes, UnblockImages
from .block_gating import BlockGmlpLayer
from .grid_gating import GridGmlpLayer
# Convolution shorthands. Each one is a functools.partial over a Keras layer
# class with kernel size, stride and padding pre-bound, so call sites only
# pass the remaining arguments (e.g. filters, use_bias, name).

# 1x1 convolution, "same" padding.
Conv1x1 = functools.partial(layers.Conv2D, kernel_size=(1, 1), padding="same")
# 3x3 convolution, "same" padding.
Conv3x3 = functools.partial(layers.Conv2D, kernel_size=(3, 3), padding="same")
# 2x2 transposed convolution with stride 2, "same" padding.
ConvT_up = functools.partial(
layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same"
)
# 4x4 convolution with stride 2, "same" padding.
Conv_down = functools.partial(
layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
)
def ResidualSplitHeadMultiAxisGmlpLayer(
    block_size,
    grid_size,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    use_bias: bool = True,
    dropout_rate: float = 0.0,
    name: str = "residual_split_head_maxim",
):
    """Build the residual multi-axis gated MLP block.

    The returned callable layer-normalizes its input, widens the channels
    with a dense projection followed by GELU, splits the result into two
    equal halves along the last axis, sends one half through
    ``GridGmlpLayer`` and the other through ``BlockGmlpLayer``, concatenates
    the two, projects back to the input channel count, applies dropout and
    finally adds the original input as a residual.

    Args:
        block_size: Forwarded unchanged to ``BlockGmlpLayer`` (presumably a
            ``(height, width)`` pair -- confirm against that layer).
        grid_size: Forwarded unchanged to ``GridGmlpLayer`` (presumably a
            ``(height, width)`` pair -- confirm against that layer).
        block_gmlp_factor: ``factor`` argument of ``BlockGmlpLayer``.
        grid_gmlp_factor: ``factor`` argument of ``GridGmlpLayer``.
        input_proj_factor: Channel multiplier of the input projection. The
            projected tensor is split into two equal parts, so
            ``channels * input_proj_factor`` has to be even.
        use_bias: Passed to both dense projections and both gMLP sub-layers.
        dropout_rate: Rate of the final dropout; also passed to both gMLP
            sub-layers.
        name: Prefix used for the names of all named layers created here.

    Returns:
        A function ``apply(x)`` that takes a tensor whose static shape has
        the channel count on axis 3 and returns the transformed tensor.
    """

    def apply(x):
        shortcut = x
        # Only the static channel count (axis 3) is needed by this block.
        num_channels = K.int_shape(x)[3]

        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            int(num_channels) * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x, approximate=True)

        # Split heads: `u` feeds the grid branch, `v` the block branch.
        u, v = tf.split(x, 2, axis=-1)

        # GridGMLPLayer
        u = GridGmlpLayer(
            grid_size=grid_size,
            factor=grid_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_GridGmlpLayer",
        )(u)

        # BlockGMLPLayer
        v = BlockGmlpLayer(
            block_size=block_size,
            factor=block_gmlp_factor,
            use_bias=use_bias,
            dropout_rate=dropout_rate,
            name=f"{name}_BlockGmlpLayer",
        )(v)

        # Merge both branches and project back to the input width so the
        # residual addition below is shape-compatible.
        x = tf.concat([u, v], axis=-1)
        x = layers.Dense(
            num_channels,
            use_bias=use_bias,
            name=f"{name}_out_project",
        )(x)
        x = layers.Dropout(dropout_rate)(x)
        return x + shortcut

    return apply
def GetSpatialGatingWeights(
    features: int,
    block_size,
    grid_size,
    input_proj_factor: int = 2,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "spatial_gating",
):
    """Build the callable that computes gating weights for cross-gating.

    The returned callable layer-normalizes its input, applies a widening
    dense projection followed by GELU and splits the result into two equal
    halves along the last axis. Each half is cut into patches with
    ``BlockImages``, mixed by a dense layer along one of the patch axes
    (axis -3 for the grid half, axis -2 for the block half) and reassembled
    with ``UnblockImages``. The halves are then concatenated, projected back
    to the input channel count and passed through dropout.

    Args:
        features: Accepted for interface compatibility; not referenced by
            this implementation (the channel count is read from the input).
        block_size: Two-element ``(height, width)`` patch size of the block
            branch.
        grid_size: Two-element ``(height, width)`` grid size of the grid
            branch.
        input_proj_factor: Channel multiplier of the input projection. The
            projected tensor is split into two equal parts, so
            ``channels * input_proj_factor`` has to be even.
        dropout_rate: Rate of the final dropout.
        use_bias: Passed to every dense layer created here.
        name: Prefix used for the names of all named layers created here.

    Returns:
        A function ``apply(x)``. The input needs statically known sizes on
        axes 1, 2 and 3, because patch sizes and layer widths are computed
        from them with Python integer arithmetic.
    """

    def apply(x):
        static_shape = K.int_shape(x)
        h, w, num_channels = static_shape[1], static_shape[2], static_shape[3]

        # Input projection.
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_in")(x)
        x = layers.Dense(
            num_channels * input_proj_factor,
            use_bias=use_bias,
            name=f"{name}_in_project",
        )(x)
        x = tf.nn.gelu(x, approximate=True)
        u, v = tf.split(x, 2, axis=-1)

        # Grid branch: the grid size is given, the patch size is derived.
        grid_h, grid_w = grid_size
        grid_patch = (h // grid_h, w // grid_w)
        u = BlockImages()(u, patch_size=grid_patch)
        dim_u = K.int_shape(u)[-3]
        # Dense acts on the last axis, so move axis -3 there and back again.
        u = SwapAxes()(u, -1, -3)
        u = layers.Dense(dim_u, use_bias=use_bias, name=f"{name}_Dense_0")(u)
        u = SwapAxes()(u, -1, -3)
        u = UnblockImages()(u, grid_size=(grid_h, grid_w), patch_size=grid_patch)

        # Block branch: the patch size is given, the grid size is derived.
        block_h, block_w = block_size
        block_grid = (h // block_h, w // block_w)
        v = BlockImages()(v, patch_size=(block_h, block_w))
        dim_v = K.int_shape(v)[-2]
        # Same trick as above, this time mixing along axis -2.
        v = SwapAxes()(v, -1, -2)
        v = layers.Dense(dim_v, use_bias=use_bias, name=f"{name}_Dense_1")(v)
        v = SwapAxes()(v, -1, -2)
        v = UnblockImages()(v, grid_size=block_grid, patch_size=(block_h, block_w))

        # Merge both branches and restore the input channel count.
        x = tf.concat([u, v], axis=-1)
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project")(x)
        x = layers.Dropout(dropout_rate)(x)
        return x

    return apply
def CrossGatingBlock(
    features: int,
    block_size,
    grid_size,
    dropout_rate: float = 0.0,
    input_proj_factor: int = 2,
    upsample_y: bool = True,
    use_bias: bool = True,
    name: str = "cross_gating",
):
    """Build the cross-gating MLP block.

    The returned callable takes two tensors ``x`` and ``y`` (``y`` being the
    gating signal). Both are brought to the same channel count, each is
    normalized, projected and passed through ``GetSpatialGatingWeights``,
    and the resulting weights are applied crosswise: ``y`` is multiplied by
    the weights derived from ``x`` and ``x`` by the weights derived from
    ``y``. Both outputs receive an output projection, dropout and a residual
    connection; the ``x`` output additionally adds the gated ``y`` output.

    Args:
        features: Number of filters of the convolutions applied to the
            inputs, and therefore the channel count used throughout.
        block_size: Forwarded to ``GetSpatialGatingWeights``.
        grid_size: Forwarded to ``GetSpatialGatingWeights``.
        dropout_rate: Rate of both dropouts created here; also forwarded to
            ``GetSpatialGatingWeights``.
        input_proj_factor: Accepted for interface compatibility; not
            referenced by this implementation, so
            ``GetSpatialGatingWeights`` always runs with its own default.
        upsample_y: When true, ``y`` first goes through the stride-2
            transposed convolution ``ConvT_up``.
        use_bias: Passed to every convolution and dense layer created here
            and forwarded to ``GetSpatialGatingWeights``.
        name: Prefix used for the names of all named layers created here.

    Returns:
        A function ``apply(x, y)`` returning the tuple ``(x, y)`` of the two
        gated tensors.
    """

    def apply(x, y):
        # Upscale Y signal, y is the gating signal.
        if upsample_y:
            y = ConvT_up(
                filters=features, use_bias=use_bias, name=f"{name}_ConvTranspose_0"
            )(y)

        x = Conv1x1(filters=features, use_bias=use_bias, name=f"{name}_Conv_0")(x)
        # Only the static channel count (axis 3) of the projected x is needed.
        num_channels = K.int_shape(x)[3]
        y = Conv1x1(filters=num_channels, use_bias=use_bias, name=f"{name}_Conv_1")(y)

        shortcut_x = x
        shortcut_y = y

        # Get gating weights from X.
        x = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_x")(x)
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_x")(x)
        x = tf.nn.gelu(x, approximate=True)
        gx = GetSpatialGatingWeights(
            features=num_channels,
            block_size=block_size,
            grid_size=grid_size,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_x",
        )(x)

        # Get gating weights from Y.
        y = layers.LayerNormalization(epsilon=1e-06, name=f"{name}_LayerNorm_y")(y)
        y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_in_project_y")(y)
        y = tf.nn.gelu(y, approximate=True)
        gy = GetSpatialGatingWeights(
            features=num_channels,
            block_size=block_size,
            grid_size=grid_size,
            dropout_rate=dropout_rate,
            use_bias=use_bias,
            name=f"{name}_SplitHeadMultiAxisGating_y",
        )(y)

        # Apply cross gating: X = X * GY, Y = Y * GX.
        y = y * gx
        y = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_y")(y)
        y = layers.Dropout(dropout_rate)(y)
        y = y + shortcut_y

        x = x * gy  # gating x using y
        x = layers.Dense(num_channels, use_bias=use_bias, name=f"{name}_out_project_x")(x)
        x = layers.Dropout(dropout_rate)(x)
        x = x + y + shortcut_x  # get all aggregated signals
        return x, y

    return apply