import torch
import torch.nn as nn
import torch.nn.functional as F
from models import utils


class CMDTop(nn.Module):
    """Head that pools a 5-D correlation volume into a fixed-size feature.

    Two complementary 4-D views of the input volume are each passed through
    the SAME stack of (same-padded conv -> GroupNorm -> ReLU) layers, then
    globally average-pooled and concatenated.
    """

    def __init__(self, in_channel, out_channels, kernel_shapes, strides):
        """
        Args:
            in_channel: channel count fed to the first conv layer. Because the
                forward pass reshapes (b, h, w, i, j) into b(ij)hw and b(hw)ij
                views, this presumably must equal both i*j and h*w — TODO
                confirm against callers.
            out_channels: sequence of per-layer output channel counts.
            kernel_shapes: per-layer kernel sizes (one per conv layer).
            strides: per-layer strides (one per conv layer).
        """
        super(CMDTop, self).__init__()
        # Each layer's input channels are the previous layer's outputs.
        self.in_channels = [in_channel] + list(out_channels[:-1])
        self.out_channels = out_channels
        self.kernel_shapes = kernel_shapes
        self.strides = strides

        self.conv = nn.ModuleList([
            nn.Sequential(
                utils.Conv2dSamePadding(
                    in_channels=self.in_channels[i],
                    out_channels=self.out_channels[i],
                    kernel_size=self.kernel_shapes[i],
                    stride=self.strides[i],
                ),
                # max(1, ...) guards the case out_channels[i] < 16, where
                # out_channels[i] // 16 == 0 would make GroupNorm raise at
                # construction. Unchanged for all channel counts >= 16.
                nn.GroupNorm(max(1, out_channels[i] // 16), out_channels[i]),
                nn.ReLU()
            ) for i in range(len(out_channels))
        ])

    def forward(self, x):
        """
        Args:
            x: correlation volume of shape (b, h, w, i, j).

        Returns:
            Tensor of shape (b, 2 * out_channels[-1]) — the pooled features of
            the two views, concatenated along the channel dimension.
        """
        # Two 4-D views of the same volume: spatial-major and match-major.
        out1 = utils.einshape('bhwij->b(ij)hw', x)
        out2 = utils.einshape('bhwij->b(hw)ij', x)

        # Single fused loop replaces two identical passes over self.conv;
        # both views intentionally share the same conv weights (as in the
        # original code) — TODO confirm sharing is deliberate upstream.
        for layer in self.conv:
            out1 = layer(out1)
            out2 = layer(out2)

        # Global average pool over the two spatial dims.
        out1 = torch.mean(out1, dim=(2, 3))  # (b, out_channels[-1])
        out2 = torch.mean(out2, dim=(2, 3))  # (b, out_channels[-1])

        return torch.cat([out1, out2], dim=-1)  # (b, 2*out_channels[-1])