|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from models import utils |
|
|
|
|
|
class CMDTop(nn.Module):
    """Shared conv head applied to two channel-folded views of a 5-D tensor.

    ``forward`` turns a ``(b, h, w, i, j)`` input into two 4-D tensors, one
    with the ``(i, j)`` axes merged into the channel axis and one with the
    ``(h, w)`` axes merged into the channel axis.  Both are passed through
    the *same* stack of conv -> GroupNorm -> ReLU layers, averaged over their
    two trailing axes and concatenated.

    Because the layers are shared, the first conv receives ``i * j`` channels
    on one branch and ``h * w`` channels on the other; both therefore have to
    equal ``in_channel``.

    Args:
        in_channel: Number of input channels of the first conv layer.
        out_channels: Sequence with the output channel count of each layer;
            its length determines the number of layers.
        kernel_shapes: Per-layer kernel size, indexed like ``out_channels``.
        strides: Per-layer stride, indexed like ``out_channels``.

    Note:
        ``utils.Conv2dSamePadding`` lives in ``models.utils`` and is not
        visible here; it is presumed to be a ``Conv2d`` variant producing
        ``out_channels[i]`` channels (the following GroupNorm requires that).
    """

    def __init__(self, in_channel, out_channels, kernel_shapes, strides):
        super().__init__()
        # Layer k consumes the output of layer k - 1; the first one consumes
        # ``in_channel``.
        self.in_channels = [in_channel] + list(out_channels[:-1])
        self.out_channels = out_channels
        self.kernel_shapes = kernel_shapes
        self.strides = strides

        layers = []
        for idx, width in enumerate(out_channels):
            layers.append(
                nn.Sequential(
                    utils.Conv2dSamePadding(
                        in_channels=self.in_channels[idx],
                        out_channels=width,
                        kernel_size=self.kernel_shapes[idx],
                        stride=self.strides[idx],
                    ),
                    # NOTE: GroupNorm still requires ``width`` to be divisible
                    # by the group count; that constraint is unchanged.
                    nn.GroupNorm(self._num_groups(width), width),
                    nn.ReLU(),
                )
            )
        self.conv = nn.ModuleList(layers)

    @staticmethod
    def _num_groups(num_channels, channels_per_group=16):
        """Return the GroupNorm group count for a layer of ``num_channels``.

        Uses one group per ``channels_per_group`` channels, but never fewer
        than one: a plain ``num_channels // 16`` is 0 for layers narrower
        than 16 channels, which ``nn.GroupNorm`` cannot be built with.
        """
        return max(1, num_channels // channels_per_group)

    def _run_stack(self, feats):
        """Apply every layer of ``self.conv`` to ``feats`` in order."""
        for layer in self.conv:
            feats = layer(feats)
        return feats

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode both channel-folded views of ``x`` and concatenate them.

        Args:
            x: Tensor of shape ``(b, h, w, i, j)``.

        Returns:
            Tensor of shape ``(b, 2 * out_channels[-1])``: the spatial mean
            of each branch, concatenated along the last axis.
        """
        # Per the pattern strings, einshape is expected to move/merge axes:
        # (i, j) -> channels for the first view, (h, w) -> channels for the
        # second.
        out1 = utils.einshape('bhwij->b(ij)hw', x)
        out2 = utils.einshape('bhwij->b(hw)ij', x)

        # Same weights for both views; the first view is fully processed
        # before the second, as in the original implementation.
        out1 = self._run_stack(out1)
        out2 = self._run_stack(out2)

        # Global average pool over the two trailing (spatial) axes.
        out1 = torch.mean(out1, dim=(2, 3))
        out2 = torch.mean(out2, dim=(2, 3))

        return torch.cat([out1, out2], dim=-1)
|
|
|
|