from typing import Optional, List
import torch
from torch import Tensor
from torch.nn import ModuleList, Module, Upsample
from tha3.nn.common.conv_block_factory import ConvBlockFactory
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs
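# Hyperparameters for ResizeConvUNet: image and bottleneck sizes, channel counts,
# the interpolation mode used for upsampling, and the conv-block construction options.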
class ResizeConvUNetArgs:
def __init__(self,
image_size: int,
input_channels: int,
start_channels: int,
bottleneck_image_size: int,
num_bottleneck_blocks: int,
max_channels: int,
upsample_mode: str = 'bilinear',
block_args: Optional[BlockArgs] = None,
use_separable_convolution: bool = False):
if block_args is None:
block_args = BlockArgs(
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=False))
self.use_separable_convolution = use_separable_convolution
self.block_args = block_args
self.upsample_mode = upsample_mode
self.max_channels = max_channels
self.num_bottleneck_blocks = num_bottleneck_blocks
self.bottleneck_image_size = bottleneck_image_size
self.input_channels = input_channels
self.start_channels = start_channels
self.image_size = image_size
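# A U-Net whose decoder upsamples with interpolation followed by a 3x3 convolution
# (the "resize convolution" pattern) rather than a transposed convolution.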
class ResizeConvUNet(Module):
def __init__(self, args: ResizeConvUNetArgs):
super().__init__()
self.args = args
conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)
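        # Encoder: an initial 3x3 conv block, then downsample blocks that halve the spatial
        # size (and grow channels, capped at max_channels) until the bottleneck size is reached.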
self.downsample_blocks = ModuleList()
self.downsample_blocks.append(conv_block_factory.create_conv3_block(
self.args.input_channels,
self.args.start_channels))
current_channels = self.args.start_channels
current_size = self.args.image_size
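        # Record the channel count at each resolution so the decoder can size its skip connections.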
size_to_channel = {
current_size: current_channels
}
while current_size > self.args.bottleneck_image_size:
next_size = current_size // 2
next_channels = min(self.args.max_channels, current_channels * 2)
self.downsample_blocks.append(conv_block_factory.create_downsample_block(
current_channels,
next_channels,
is_output_1x1=False))
current_size = next_size
current_channels = next_channels
size_to_channel[current_size] = current_channels
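        # Bottleneck: a stack of ResNet blocks at the lowest resolution.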
self.bottleneck_blocks = ModuleList()
        for _ in range(self.args.num_bottleneck_blocks):
self.bottleneck_blocks.append(conv_block_factory.create_resnet_block(current_channels, is_1x1=False))
self.output_image_sizes = [current_size]
self.output_num_channels = [current_channels]
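        # Decoder: at each step, double the resolution by interpolation, concatenate the
        # encoder feature of the same size, and fuse them with a 3x3 conv block.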
self.upsample_blocks = ModuleList()
while current_size < self.args.image_size:
next_size = current_size * 2
next_channels = size_to_channel[next_size]
self.upsample_blocks.append(conv_block_factory.create_conv3_block(
current_channels + next_channels,
next_channels))
current_size = next_size
current_channels = next_channels
self.output_image_sizes.append(current_size)
self.output_num_channels.append(current_channels)
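        # align_corners is only meaningful for interpolating modes such as 'bilinear';
        # 'nearest' requires it to be None.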
if args.upsample_mode == 'nearest':
align_corners = None
else:
align_corners = False
self.double_resolution = Upsample(scale_factor=2, mode=args.upsample_mode, align_corners=align_corners)
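    # Returns the feature maps at every decoder resolution, ordered from the bottleneck
    # up to the full image size.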
def forward(self, feature: Tensor) -> List[Tensor]:
downsampled_features = []
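        # Encoder pass: keep every intermediate feature for the skip connections.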
for block in self.downsample_blocks:
feature = block(feature)
downsampled_features.append(feature)
for block in self.bottleneck_blocks:
feature = block(feature)
outputs = [feature]
        for i in range(len(self.upsample_blocks)):
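            # downsampled_features[-1] feeds the bottleneck, so upsample step i pairs with
            # the encoder feature at index -i - 2 (the one matching the doubled resolution).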
feature = self.double_resolution(feature)
feature = torch.cat([feature, downsampled_features[-i - 2]], dim=1)
feature = self.upsample_blocks[i](feature)
outputs.append(feature)
return outputs
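# Smoke test: build the network on CUDA, run a forward pass, print the output shapes,
# and then benchmark the forward pass.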
if __name__ == "__main__":
device = torch.device('cuda')
image_size = 512
image_channels = 4
num_pose_params = 6
args = ResizeConvUNetArgs(
image_size=512,
input_channels=10,
start_channels=32,
bottleneck_image_size=32,
num_bottleneck_blocks=6,
max_channels=512,
upsample_mode='nearest',
use_separable_convolution=False,
block_args=BlockArgs(
initialization_method='he',
use_spectral_norm=False,
normalization_layer_factory=InstanceNorm2dFactory(),
nonlinearity_factory=ReLUFactory(inplace=False)))
module = ResizeConvUNet(args).to(device)
image_count = 8
input = torch.zeros(image_count, 10, 512, 512, device=device)
outputs = module.forward(input)
for output in outputs:
print(output.shape)
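    # Benchmark: time the forward pass with CUDA events; the first two iterations are
    # treated as warm-up and excluded from the average.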
    repeat = 100
    acc = 0.0
    for i in range(repeat + 2):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        module.forward(input)
        end.record()
        torch.cuda.synchronize()
        if i >= 2:
            elapsed_time = start.elapsed_time(end)
            print("%d:" % i, elapsed_time)
            acc = acc + elapsed_time
    print("average:", acc / repeat)