from typing import Optional, List

import torch
from torch import Tensor
from torch.nn import ModuleList, Module, Upsample

from tha3.nn.common.conv_block_factory import ConvBlockFactory
from tha3.nn.nonlinearity_factory import ReLUFactory
from tha3.nn.normalization import InstanceNorm2dFactory
from tha3.nn.util import BlockArgs


class ResizeConvUNetArgs:
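    """
    Hyperparameters for ResizeConvUNet.

    - image_size: width and height of the (square) input feature map.
    - input_channels: number of channels of the input tensor.
    - start_channels: channels produced by the first, full-resolution block.
    - bottleneck_image_size: resolution at which downsampling stops.
    - num_bottleneck_blocks: number of ResNet blocks applied at the bottleneck.
    - max_channels: cap on the channel count as it doubles per downsampling step.
    - upsample_mode: interpolation mode passed to torch.nn.Upsample.
    - block_args: normalization/nonlinearity settings; defaults to
      InstanceNorm2d with a non-inplace ReLU when omitted.
    - use_separable_convolution: passed to ConvBlockFactory to request
      separable convolutions.
    """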
    def __init__(self,
                 image_size: int,
                 input_channels: int,
                 start_channels: int,
                 bottleneck_image_size: int,
                 num_bottleneck_blocks: int,
                 max_channels: int,
                 upsample_mode: str = 'bilinear',
                 block_args: Optional[BlockArgs] = None,
                 use_separable_convolution: bool = False):
        if block_args is None:
            block_args = BlockArgs(
                normalization_layer_factory=InstanceNorm2dFactory(),
                nonlinearity_factory=ReLUFactory(inplace=False))
        self.use_separable_convolution = use_separable_convolution
        self.block_args = block_args
        self.upsample_mode = upsample_mode
        self.max_channels = max_channels
        self.num_bottleneck_blocks = num_bottleneck_blocks
        self.bottleneck_image_size = bottleneck_image_size
        self.input_channels = input_channels
        self.start_channels = start_channels
        self.image_size = image_size


class ResizeConvUNet(Module):
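    """
    A U-Net that upsamples with resize-convolutions: each decoder step applies
    nn.Upsample followed by a 3x3 convolution block instead of a transposed
    convolution. The encoder halves the resolution (doubling channels, capped
    at max_channels) until bottleneck_image_size is reached, a stack of ResNet
    blocks processes the bottleneck, and the decoder doubles the resolution
    back up, concatenating the encoder feature of matching size before each
    convolution (the usual U-Net skip connection).

    output_image_sizes and output_num_channels record the resolution and
    channel count of each tensor returned by forward(), coarsest first.
    """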
    def __init__(self, args: ResizeConvUNetArgs):
        super().__init__()
        self.args = args
        conv_block_factory = ConvBlockFactory(args.block_args, args.use_separable_convolution)

        self.downsample_blocks = ModuleList()
        self.downsample_blocks.append(conv_block_factory.create_conv3_block(
            self.args.input_channels,
            self.args.start_channels))
        current_channels = self.args.start_channels
        current_size = self.args.image_size
        size_to_channel = {
            current_size: current_channels
        }
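
        # Encoder: halve the resolution and double the channel count (capped
        # at max_channels) until the bottleneck resolution is reached.
        # size_to_channel records the channel count at each scale so that the
        # decoder can size its skip-connection convolutions.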
        while current_size > self.args.bottleneck_image_size:
            next_size = current_size // 2
            next_channels = min(self.args.max_channels, current_channels * 2)
            self.downsample_blocks.append(conv_block_factory.create_downsample_block(
                current_channels,
                next_channels,
                is_output_1x1=False))
            current_size = next_size
            current_channels = next_channels
            size_to_channel[current_size] = current_channels
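
        # Bottleneck: a stack of ResNet blocks at the lowest resolution.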
        self.bottleneck_blocks = ModuleList()
        for i in range(self.args.num_bottleneck_blocks):
            self.bottleneck_blocks.append(
                conv_block_factory.create_resnet_block(current_channels, is_1x1=False))
        self.output_image_sizes = [current_size]
        self.output_num_channels = [current_channels]
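
        # Decoder: each step doubles the resolution and concatenates the
        # encoder feature of the same size, so each conv3 block maps
        # current_channels + skip channels down to the skip's channel count,
        # mirroring the encoder.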
        self.upsample_blocks = ModuleList()
        while current_size < self.args.image_size:
            next_size = current_size * 2
            next_channels = size_to_channel[next_size]
            self.upsample_blocks.append(conv_block_factory.create_conv3_block(
                current_channels + next_channels,
                next_channels))
            current_size = next_size
            current_channels = next_channels
            self.output_image_sizes.append(current_size)
            self.output_num_channels.append(current_channels)
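
        # torch.nn.Upsample accepts align_corners only for the interpolating
        # modes ('linear', 'bilinear', 'bicubic', 'trilinear'); for 'nearest'
        # it must stay None.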
        if args.upsample_mode == 'nearest':
            align_corners = None
        else:
            align_corners = False
        self.double_resolution = Upsample(
            scale_factor=2, mode=args.upsample_mode, align_corners=align_corners)

    def forward(self, feature: Tensor) -> List[Tensor]:
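        # Encode: run every downsampling block, keeping each intermediate
        # feature so the decoder can reuse it as a skip connection.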
        downsampled_features = []
        for block in self.downsample_blocks:
            feature = block(feature)
            downsampled_features.append(feature)
        for block in self.bottleneck_blocks:
            feature = block(feature)
        outputs = [feature]
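
        # Decode: downsampled_features[-1] is the bottleneck-resolution
        # feature, so [-i - 2] selects the encoder feature whose resolution
        # matches the freshly upsampled one. Outputs are collected coarsest
        # first.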
        for i in range(len(self.upsample_blocks)):
            feature = self.double_resolution(feature)
            feature = torch.cat([feature, downsampled_features[-i - 2]], dim=1)
            feature = self.upsample_blocks[i](feature)
            outputs.append(feature)
        return outputs


if __name__ == "__main__":
    device = torch.device('cuda')
    image_size = 512
    image_channels = 4
    num_pose_params = 6
    args = ResizeConvUNetArgs(
        image_size=image_size,
        input_channels=10,
        start_channels=32,
        bottleneck_image_size=32,
        num_bottleneck_blocks=6,
        max_channels=512,
        upsample_mode='nearest',
        use_separable_convolution=False,
        block_args=BlockArgs(
            initialization_method='he',
            use_spectral_norm=False,
            normalization_layer_factory=InstanceNorm2dFactory(),
            nonlinearity_factory=ReLUFactory(inplace=False)))
    module = ResizeConvUNet(args).to(device)

    image_count = 8
    input = torch.zeros(image_count, 10, image_size, image_size, device=device)
    outputs = module(input)
    for output in outputs:
        print(output.shape)
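
    # With the settings above, the loop should print five shapes, coarsest
    # first: (8, 512, 32, 32), (8, 256, 64, 64), (8, 128, 128, 128),
    # (8, 64, 256, 256), and (8, 32, 512, 512).

    # Rough GPU timing with CUDA events (elapsed_time is in milliseconds);
    # the first two iterations serve as warm-up and are excluded from the
    # average.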
    if True:
        repeat = 100
        acc = 0.0
        for i in range(repeat + 2):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            module(input)
            end.record()
            torch.cuda.synchronize()
            if i >= 2:
                elapsed_time = start.elapsed_time(end)
                print("%d:" % i, elapsed_time)
                acc += elapsed_time
        print("average:", acc / repeat)