Seokju Cho
initial commit
f1586f7
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pytorch neural network definitions."""
from typing import Sequence, Union
import torch
from torch import nn
import torch.nn.functional as F
from models.utils import Conv2dSamePadding
class ExtraConvBlock(nn.Module):
"""Additional convolution block."""
def __init__(
self,
channel_dim,
channel_multiplier,
):
super().__init__()
self.channel_dim = channel_dim
self.channel_multiplier = channel_multiplier
self.layer_norm = nn.LayerNorm(
normalized_shape=channel_dim, elementwise_affine=True, bias=True
)
self.conv = nn.Conv2d(
self.channel_dim,
self.channel_dim * self.channel_multiplier,
kernel_size=3,
stride=1,
padding=1,
)
self.conv_1 = nn.Conv2d(
self.channel_dim * self.channel_multiplier,
self.channel_dim,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, x):
x = self.layer_norm(x)
x = x.permute(0, 3, 1, 2)
res = self.conv(x)
res = F.gelu(res, approximate='tanh')
x = x + self.conv_1(res)
x = x.permute(0, 2, 3, 1)
return x
class ExtraConvs(nn.Module):
    """Sequential stack of identically configured ``ExtraConvBlock`` modules.

    Args:
        num_layers: How many blocks to chain. Defaults to 5.
        channel_dim: Channel count passed to every block. Defaults to 256.
        channel_multiplier: Widening factor passed to every block.
            Defaults to 4.
    """

    def __init__(
        self,
        num_layers=5,
        channel_dim=256,
        channel_multiplier=4,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.channel_dim = channel_dim
        self.channel_multiplier = channel_multiplier
        self.blocks = nn.ModuleList([
            ExtraConvBlock(self.channel_dim, self.channel_multiplier)
            for _ in range(self.num_layers)
        ])

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        out = x
        for conv_block in self.blocks:
            out = conv_block(out)
        return out
class ConvChannelsMixer(nn.Module):
    """Channel-mixing MLP used inside the PIPs MLP Mixer blocks.

    Two linear layers act on the last axis of the input: the first expands it
    to four times ``in_channels``, the second projects it back.

    Args:
        in_channels: Size of the last axis of the input and of the output.
    """

    def __init__(self, in_channels):
        super().__init__()
        expanded = in_channels * 4
        self.mlp2_up = nn.Linear(in_channels, expanded)
        self.mlp2_down = nn.Linear(expanded, in_channels)

    def forward(self, x):
        """Return ``mlp2_down(gelu(mlp2_up(x)))`` using the tanh GELU approximation."""
        widened = F.gelu(self.mlp2_up(x), approximate='tanh')
        return self.mlp2_down(widened)
class PIPsConvBlock(nn.Module):
    """Convolutional block for PIPs's MLP Mixer.

    The block consists of two residual stages:

    1. layer norm, two grouped 1-D convolutions along the second-to-last axis
       of the input (with a tanh-GELU in between), and a sum over each group
       of four consecutive channels to return to ``in_channels``;
    2. layer norm followed by a ``ConvChannelsMixer`` acting on the last axis.

    Args:
        in_channels: Size of the last (channel) axis of the input.
        kernel_shape: Kernel size of both 1-D convolutions. Defaults to 3.
        use_causal_conv: If True, the convolutions use no built-in padding and
            the input is instead left-padded with ``kernel_shape - 1`` zeros
            in ``forward``. If False, symmetric padding of
            ``(kernel_shape - 1) // 2`` is used, which preserves the sequence
            length for odd kernel sizes.
        block_idx: Index used to build this block's keys in the causal-context
            dictionaries (``block_{block_idx}_causal_1`` and
            ``block_{block_idx}_causal_2``).
    """

    def __init__(
        self, in_channels, kernel_shape=3, use_causal_conv=False, block_idx=None
    ):
        super().__init__()
        self.use_causal_conv = use_causal_conv
        self.block_name = f'block_{block_idx}'
        self.kernel_shape = kernel_shape

        # Causal mode pads on the left inside forward(); otherwise pad
        # symmetrically so the sequence length is preserved. The previous
        # hard-coded value of 1 was only correct for kernel_shape == 3.
        conv_padding = 0 if self.use_causal_conv else (kernel_shape - 1) // 2

        self.layer_norm = nn.LayerNorm(
            normalized_shape=in_channels, elementwise_affine=True, bias=False
        )
        # groups=in_channels: each input channel produces 4 output channels.
        self.mlp1_up = nn.Conv1d(
            in_channels,
            in_channels * 4,
            kernel_shape,
            stride=1,
            padding=conv_padding,
            groups=in_channels,
        )
        # groups equals the channel count: one filter per channel.
        self.mlp1_up_1 = nn.Conv1d(
            in_channels * 4,
            in_channels * 4,
            kernel_shape,
            stride=1,
            padding=conv_padding,
            groups=in_channels * 4,
        )
        self.layer_norm_1 = nn.LayerNorm(
            normalized_shape=in_channels, elementwise_affine=True, bias=False
        )
        self.conv_channels_mixer = ConvChannelsMixer(in_channels)

    def forward(self, x, causal_context=None, get_causal_context=False):
        """Apply the block.

        Args:
            x: 3-D tensor whose last axis has size ``in_channels``; the
                convolutions run along the second-to-last axis.
            causal_context: Optional dict holding, under this block's two
                keys, tensors that are prepended along the second-to-last
                axis before the first and second convolution respectively.
            get_causal_context: Accepted for interface compatibility; this
                block does not read it.

        Returns:
            A tuple ``(output, new_causal_context)``. ``new_causal_context``
            is empty unless ``causal_context`` was given, in which case it
            holds the trailing ``kernel_shape - 1`` positions of the input to
            each convolution under this block's two keys.
        """
        to_skip = x
        x = self.layer_norm(x)
        new_causal_context = {}
        num_extra = 0
        # Number of leading zeros needed so that an unpadded convolution with
        # this kernel size keeps the sequence length and never looks ahead.
        # This used to be a literal 2, which is only right for kernel_shape 3.
        left_pad = self.kernel_shape - 1

        if causal_context is not None:
            name1 = self.block_name + '_causal_1'
            x = torch.cat([causal_context[name1], x], dim=-2)
            num_extra = causal_context[name1].shape[-2]
            new_causal_context[name1] = x[..., -(self.kernel_shape - 1):, :]

        # nn.Conv1d convolves over the last axis, so put channels second.
        x = x.permute(0, 2, 1)
        if self.use_causal_conv:
            x = F.pad(x, (left_pad, 0))
        x = self.mlp1_up(x)
        x = F.gelu(x, approximate='tanh')

        if causal_context is not None:
            x = x.permute(0, 2, 1)
            name2 = self.block_name + '_causal_2'
            num_extra = causal_context[name2].shape[-2]
            # Replace the leading positions by the stored second-stage context.
            x = torch.cat(
                [causal_context[name2], x[..., num_extra:, :]], dim=-2
            )
            new_causal_context[name2] = x[..., -(self.kernel_shape - 1):, :]
            x = x.permute(0, 2, 1)

        if self.use_causal_conv:
            x = F.pad(x, (left_pad, 0))
        x = self.mlp1_up_1(x)
        x = x.permute(0, 2, 1)

        if causal_context is not None:
            # Drop the positions that correspond to the prepended context.
            x = x[..., num_extra:, :]

        # Sum every group of 4 consecutive channels: 4 * in_channels -> in_channels.
        x = x[..., 0::4] + x[..., 1::4] + x[..., 2::4] + x[..., 3::4]

        x = x + to_skip
        to_skip = x
        x = self.layer_norm_1(x)
        x = self.conv_channels_mixer(x)
        x = x + to_skip
        return x, new_causal_context
class PIPSMLPMixer(nn.Module):
    """Depthwise-conv version of PIPs's MLP Mixer."""

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        hidden_dim: int = 512,
        num_blocks: int = 12,
        kernel_shape: int = 3,
        use_causal_conv: bool = False,
    ):
        """Builds the mixer.

        The module is an input projection, a stack of ``PIPsConvBlock``
        modules, a layer norm and an output projection.

        Args:
            input_channels (int): Size of the last axis of the input.
            output_channels (int): Size of the last axis of the output.
            hidden_dim (int, optional): Width used by every block. Defaults
                to 512.
            num_blocks (int, optional): Number of convolution blocks.
                Defaults to 12.
            kernel_shape (int, optional): Kernel size used by the blocks'
                convolutions. Defaults to 3.
            use_causal_conv (bool, optional): Whether the blocks use causal
                convolutions. Defaults to False.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_blocks = num_blocks
        self.use_causal_conv = use_causal_conv

        self.linear = nn.Linear(input_channels, self.hidden_dim)
        self.layer_norm = nn.LayerNorm(
            normalized_shape=hidden_dim, elementwise_affine=True, bias=False
        )
        self.linear_1 = nn.Linear(hidden_dim, output_channels)

        # Each block receives its position so its causal-context keys are unique.
        mixer_blocks = []
        for idx in range(num_blocks):
            mixer_blocks.append(
                PIPsConvBlock(
                    hidden_dim,
                    kernel_shape,
                    self.use_causal_conv,
                    block_idx=idx,
                )
            )
        self.blocks = nn.ModuleList(mixer_blocks)

    def forward(self, x, causal_context=None, get_causal_context=False):
        """Runs the mixer.

        Args:
            x: Input tensor whose last axis has size ``input_channels``.
            causal_context: Optional dict forwarded unchanged to every block.
            get_causal_context: If True, the dicts returned by the blocks are
                merged into the returned context; otherwise the returned
                context stays empty.

        Returns:
            A tuple ``(output, all_causal_context)``.
        """
        hidden = self.linear(x)
        collected_context = {}
        for mixer_block in self.blocks:
            hidden, block_context = mixer_block(
                hidden, causal_context, get_causal_context
            )
            if not get_causal_context:
                continue
            collected_context.update(block_context)
        output = self.linear_1(self.layer_norm(hidden))
        return output, collected_context
class BlockV2(nn.Module):
    """ResNet V2 block.

    Normalization and ReLU are applied before each convolution, and instance
    normalization (``nn.InstanceNorm2d``) is used as the normalization layer.

    Args:
        channels_in: Number of channels of the input.
        channels_out: Number of channels produced by the block.
        stride: Stride of the first convolution and of the projection
            shortcut. Only the values 1 and 2 are accepted.
        use_projection: If True, the shortcut is a strided 1x1 convolution of
            the normalized, activated input; otherwise it is the raw input.

    Raises:
        ValueError: If ``stride`` is neither 1 nor 2.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        stride: Union[int, Sequence[int]],
        use_projection: bool,
    ):
        super().__init__()

        # Handle asymmetric padding created by padding="SAME" in JAX/LAX.
        # NOTE: self.padding is stored but not read anywhere in this class.
        if stride == 1:
            self.padding = (1, 1, 1, 1)
        elif stride == 2:
            self.padding = (0, 2, 0, 2)
        else:
            raise ValueError(
                'Check correct padding using padtype_to_pads in jax._src.lax.lax'
            )

        self.use_projection = use_projection
        if self.use_projection:
            self.proj_conv = Conv2dSamePadding(
                in_channels=channels_in,
                out_channels=channels_out,
                kernel_size=1,
                stride=stride,
                padding=0,
                bias=False,
            )

        self.bn_0 = nn.InstanceNorm2d(
            num_features=channels_in,
            eps=1e-05,
            momentum=0.1,
            affine=True,
            track_running_stats=False,
        )
        self.conv_0 = Conv2dSamePadding(
            in_channels=channels_in,
            out_channels=channels_out,
            kernel_size=3,
            stride=stride,
            padding=0,
            bias=False,
        )
        self.conv_1 = Conv2dSamePadding(
            in_channels=channels_out,
            out_channels=channels_out,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn_1 = nn.InstanceNorm2d(
            num_features=channels_out,
            eps=1e-05,
            momentum=0.1,
            affine=True,
            track_running_stats=False,
        )

    def forward(self, inputs):
        """Apply the block and return the sum of the residual branch and shortcut."""
        x = shortcut = inputs

        x = self.bn_0(x)
        x = torch.relu(x)
        if self.use_projection:
            # The projection sees the normalized, activated input.
            shortcut = self.proj_conv(x)

        x = self.conv_0(x)

        x = self.bn_1(x)
        x = torch.relu(x)
        # no issues with padding here as this layer always has stride 1
        x = self.conv_1(x)

        return x + shortcut
class BlockGroup(nn.Module):
    """A run of ``BlockV2`` modules that forms one stage of the ResNet.

    Only the first block receives ``channels_in``, the requested ``stride``
    and (optionally) a projection shortcut; every later block maps
    ``channels_out`` to ``channels_out`` with stride 1 and no projection.

    Args:
        channels_in: Number of channels entering the group.
        channels_out: Number of channels produced by every block.
        num_blocks: How many blocks the group contains.
        stride: Stride handed to the first block.
        use_projection: Whether the first block uses a projection shortcut.
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        num_blocks: int,
        stride: Union[int, Sequence[int]],
        use_projection: bool,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([
            BlockV2(
                channels_in=channels_out if idx != 0 else channels_in,
                channels_out=channels_out,
                stride=(stride if not idx else 1),
                use_projection=(idx == 0 and use_projection),
            )
            for idx in range(num_blocks)
        ])

    def forward(self, inputs):
        """Pass ``inputs`` through the blocks in order and return the result."""
        features = inputs
        for res_block in self.blocks:
            features = res_block(features)
        return features
class ResNet(nn.Module):
    """ResNet model."""

    def __init__(
        self,
        blocks_per_group: Sequence[int],
        channels_per_group: Sequence[int] = (64, 128, 256, 512),
        use_projection: Sequence[bool] = (True, True, True, True),
        strides: Sequence[int] = (1, 2, 2, 2),
    ):
        """Initializes a ResNet model with customizable layers and configurations.

        The network is a strided 7x7 stem convolution on a 3-channel input
        followed by one ``BlockGroup`` per entry of ``strides``. Every
        sequence argument is indexed by group position, so each must have at
        least as many entries as ``strides``.

        Args:
            blocks_per_group: Number of residual blocks in each group.
            channels_per_group: Number of output channels of each group. The
                first entry is also the output width of the stem convolution
                and therefore the input width of the first group. Defaults to
                (64, 128, 256, 512).
            use_projection: For each group, whether its first block uses a
                projection shortcut (True) or an identity shortcut (False).
                Defaults to (True, True, True, True).
            strides: Stride of the first block of each group. Defaults to
                (1, 2, 2, 2).
        """
        super().__init__()

        stem_channels = channels_per_group[0]
        self.initial_conv = Conv2dSamePadding(
            in_channels=3,
            out_channels=stem_channels,
            kernel_size=(7, 7),
            stride=2,
            padding=0,
            bias=False,
        )

        block_groups = []
        for i, stride in enumerate(strides):
            # The first group consumes the stem output. This used to be a
            # literal 64, which only matched the stem when
            # channels_per_group[0] was 64.
            group_in = channels_per_group[i - 1] if i > 0 else stem_channels
            block_groups.append(
                BlockGroup(
                    channels_in=group_in,
                    channels_out=channels_per_group[i],
                    num_blocks=blocks_per_group[i],
                    stride=stride,
                    use_projection=use_projection[i],
                )
            )
        self.block_groups = nn.ModuleList(block_groups)

    def forward(self, inputs):
        """Runs the network and returns every intermediate feature map.

        Args:
            inputs: Input tensor passed to the stem convolution.

        Returns:
            A dict with the stem output under ``'initial_conv'`` and the
            output of the i-th group under ``'resnet_unit_{i}'``.
        """
        result = {}
        out = inputs
        out = self.initial_conv(out)
        result['initial_conv'] = out
        for block_id, block_group in enumerate(self.block_groups):
            out = block_group(out)
            result[f'resnet_unit_{block_id}'] = out
        return result