File size: 6,160 Bytes
a64b7d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import flow_warp
class BasicModule(nn.Module):
    """Per-level flow estimator of SPyNet (TOFlow variant).

    Five 7x7 convolutions applied in sequence. Each of the first four is
    followed by batch normalization and an in-place ReLU; the last one maps
    to a 2-channel flow field. In contrast to the basic module in
    spynet_arch.py, this one contains batch normalization.
    """

    def __init__(self):
        super().__init__()

        # Channel widths seen at the input of consecutive convolutions.
        # The final convolution (16 -> 2) is added separately because it
        # keeps its bias and has no BatchNorm / ReLU after it.
        widths = (8, 32, 64, 32, 16)

        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            # bias is disabled here; the following BatchNorm supplies the shift.
            layers.append(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=7, stride=1, padding=3, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=widths[-1], out_channels=2, kernel_size=7, stride=1, padding=3))

        self.basic_module = nn.Sequential(*layers)

    def forward(self, tensor_input):
        """Estimate a flow field from the stacked inputs.

        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].

        Returns:
            Tensor: Estimated flow with shape (b, 2, h, w)
        """
        estimated_flow = self.basic_module(tensor_input)
        return estimated_flow
class SPyNetTOF(nn.Module):
    """SPyNet architecture for TOF.

    Note that this implementation is specifically for TOFlow. Please use :file:`spynet_arch.py` for general use.
    They differ in the following aspects:

    1. The basic modules here contain BatchNorm.
    2. Normalization and denormalization are not done here, as they are done in TOFlow.

    ``Paper: Optical Flow Estimation using a Spatial Pyramid Network``

    Reference: https://github.com/Coldog2333/pytoflow

    Args:
        load_path (str): Path for pretrained SPyNet. Default: None.
    """

    def __init__(self, load_path=None):
        super().__init__()

        # One flow estimator per pyramid level (4 levels, coarse to fine).
        level_estimators = [BasicModule() for _ in range(4)]
        self.basic_module = nn.ModuleList(level_estimators)

        if load_path:
            # The checkpoint is mapped onto CPU storage and is expected to
            # hold the weights under the 'params' key.
            checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
            self.load_state_dict(checkpoint['params'])

    def forward(self, ref, supp):
        """Estimate the flow between two frames, coarse to fine.

        The initial flow is allocated at (h // 16, w // 16) and doubled in
        size at every level, so h and w should be multiples of 16 for the
        pyramid sizes to line up.

        Args:
            ref (Tensor): Reference image with shape of (b, 3, h, w).
            supp: The supporting image to be warped: (b, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (b, 2, h, w).
        """
        batch, _, height, width = ref.size()

        # Build both image pyramids fine -> coarse, then flip them so that
        # index 0 holds the coarsest level.
        ref_pyramid = [ref]
        supp_pyramid = [supp]
        for _ in range(3):
            ref_pyramid.append(
                F.avg_pool2d(input=ref_pyramid[-1], kernel_size=2, stride=2, count_include_pad=False))
            supp_pyramid.append(
                F.avg_pool2d(input=supp_pyramid[-1], kernel_size=2, stride=2, count_include_pad=False))
        ref_pyramid.reverse()
        supp_pyramid.reverse()

        # Start from a zero flow one level below the coarsest pyramid level.
        flow = ref_pyramid[0].new_zeros(batch, 2, height // 16, width // 16)

        for estimator, ref_level, supp_level in zip(self.basic_module, ref_pyramid, supp_pyramid):
            # Doubling the resolution also doubles the displacement magnitudes.
            upsampled = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
            # flow_warp receives the flow channels-last: (b, h, w, 2).
            warped_supp = flow_warp(supp_level, upsampled.permute(0, 2, 3, 1))
            residual = estimator(torch.cat([ref_level, warped_supp, upsampled], 1))
            flow = upsampled + residual

        return flow
@ARCH_REGISTRY.register()
class TOFlow(nn.Module):
    """PyTorch implementation of TOFlow.

    In TOFlow, the LR frames are pre-upsampled and have the same size with the GT frames.

    ``Paper: Video Enhancement with Task-Oriented Flow``

    Reference: https://github.com/anchen1011/toflow
    Reference: https://github.com/Coldog2333/pytoflow

    Args:
        adapt_official_weights (bool): Whether to adapt the weights translated
            from the official implementation. Set to false if you want to
            train from scratch. Default: False
    """

    def __init__(self, adapt_official_weights=False):
        super(TOFlow, self).__init__()
        self.adapt_official_weights = adapt_official_weights
        # With the official weights the frames are reordered in forward() so
        # that the reference comes first; otherwise it is the middle frame.
        self.ref_idx = 0 if adapt_official_weights else 3

        # Per-channel statistics used by normalize() / denormalize().
        # Registered as buffers so they follow the module across devices.
        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        # flow estimation module
        self.spynet = SPyNetTOF()

        # reconstruction module (input: 7 aligned frames x 3 channels)
        self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
        self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4)
        self.conv_3 = nn.Conv2d(64, 64, 1)
        self.conv_4 = nn.Conv2d(64, 3, 1)

        # activation function
        self.relu = nn.ReLU(inplace=True)

    def normalize(self, img):
        """Subtract the channel mean and divide by the channel std.

        Args:
            img (Tensor): Image batch broadcastable against (1, 3, 1, 1).

        Returns:
            Tensor: Normalized images, same shape as ``img``.
        """
        return (img - self.mean) / self.std

    def denormalize(self, img):
        """Invert :meth:`normalize`: multiply by the std and add the mean.

        Args:
            img (Tensor): Normalized batch broadcastable against (1, 3, 1, 1).

        Returns:
            Tensor: Denormalized images, same shape as ``img``.
        """
        return img * self.std + self.mean

    def forward(self, lrs):
        """Align the neighboring frames to the reference and reconstruct it.

        Args:
            lrs: Input lr frames: (b, 7, 3, h, w).

        Returns:
            Tensor: SR frame: (b, 3, h, w).
        """
        # In the official implementation, the 0-th frame is the reference frame
        if self.adapt_official_weights:
            lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]

        num_batches, num_lrs, _, h, w = lrs.size()

        # reshape() rather than view(): view() raises on non-contiguous input
        # (e.g. a permuted or sliced tensor), while reshape() returns the same
        # values and only copies when it has to.
        lrs = self.normalize(lrs.reshape(-1, 3, h, w))
        lrs = lrs.reshape(num_batches, num_lrs, 3, h, w)

        lr_ref = lrs[:, self.ref_idx, :, :, :]
        lr_aligned = []
        for i in range(7):  # 7 frames
            if i == self.ref_idx:
                # The reference frame needs no alignment.
                lr_aligned.append(lr_ref)
                continue
            lr_supp = lrs[:, i, :, :, :]
            flow = self.spynet(lr_ref, lr_supp)
            # flow_warp receives the flow channels-last: (b, h, w, 2).
            lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))

        # reconstruction: fold the frame axis into channels -> (b, 7 * 3, h, w)
        hr = torch.stack(lr_aligned, dim=1)
        hr = hr.view(num_batches, -1, h, w)
        hr = self.relu(self.conv_1(hr))
        hr = self.relu(self.conv_2(hr))
        hr = self.relu(self.conv_3(hr))
        # The network predicts a residual on top of the reference frame.
        hr = self.conv_4(hr) + lr_ref

        return self.denormalize(hr)
|