File size: 2,116 Bytes
a80d6bb
 
 
 
c74a070
 
a80d6bb
 
 
 
 
 
 
c74a070
a80d6bb
 
 
 
 
c74a070
 
 
 
 
 
 
a80d6bb
c74a070
 
 
a80d6bb
 
 
 
 
c74a070
a80d6bb
c74a070
 
 
 
a80d6bb
 
c74a070
 
 
 
 
 
 
a80d6bb
 
c74a070
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
a80d6bb
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch.nn as nn


class PixelShuffleDecoder(nn.Module):
    """Pixel shuffle decoder.

    Upsamples a feature map by a factor of ``2 ** num_upsample``. It uses a
    stack of Conv-BN-ReLU blocks, each followed by a 2x ``nn.PixelShuffle``,
    then projects the result to ``output_channel`` channels with a final 1x1
    convolution.

    Args:
        input_feat_dim: Number of channels of the input feature map.
        num_upsample: Number of 2x pixel-shuffle upsampling stages (1 to 4).
        output_channel: Number of channels produced by the output 1x1 conv.
    """

    def __init__(self, input_feat_dim=128, num_upsample=2, output_channel=2):
        super().__init__()
        # Channel count entering each stage, e.g. [256, 64, 16] for
        # num_upsample == 2.
        self.channel_conf = self.get_channel_conf(num_upsample)

        # One parameter-free 2x pixel shuffle, reused after every conv block.
        self.pixshuffle = nn.PixelShuffle(2)

        # Collect the blocks locally, then register them once as a ModuleList
        # so their parameters/buffers are tracked by the module.
        blocks = [
            # Input block: project the input features to the first stage width.
            self._conv_bn_relu(input_feat_dim, self.channel_conf[0])
        ]
        # Intermediate blocks: each operates at the width left by the
        # preceding pixel shuffle.
        blocks.extend(
            self._conv_bn_relu(channel, channel)
            for channel in self.channel_conf[1:-1]
        )
        # Output block: plain 1x1 projection (no norm / activation).
        blocks.append(
            nn.Conv2d(
                self.channel_conf[-1],
                output_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )
        self.conv_block_lst = nn.ModuleList(blocks)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return a 3x3 Conv2d -> BatchNorm2d -> ReLU block (spatial size kept)."""
        # Layout (indices 0, 1, 2) matters: it defines the state_dict keys.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    # Get num of channels based on number of upsampling.
    def get_channel_conf(self, num_upsample):
        """Return the per-stage channel counts for ``num_upsample`` stages.

        The result has ``num_upsample + 1`` entries: the width after the input
        block, followed by the width after each 2x pixel shuffle. Each shuffle
        with factor 2 divides the channel count by 4, which gives
        [256, 64, 16] for 2 stages and [256, 64, 16, 4] for 3 stages.

        Raises:
            ValueError: If ``num_upsample`` is not one of 1, 2, 3, 4. Starting
                from 256 channels, a fifth shuffle would leave zero channels,
                and 0 stages does not match ``forward``, which always shuffles
                after the input block.
        """
        base_channels = 256
        if num_upsample not in (1, 2, 3, 4):
            raise ValueError(
                f"num_upsample must be one of 1, 2, 3, 4; got {num_upsample!r}"
            )
        return [base_channels // 4**i for i in range(int(num_upsample) + 1)]

    def forward(self, input_features):
        """Decode ``input_features`` into an upsampled output map.

        Args:
            input_features: Tensor with ``input_feat_dim`` channels, expected
                as (N, C, H, W).

        Returns:
            Tensor of shape (N, output_channel, H * 2**num_upsample,
            W * 2**num_upsample).
        """
        out = input_features
        # Every block except the output block is followed by a 2x shuffle.
        for block in self.conv_block_lst[:-1]:
            out = block(out)
            out = self.pixshuffle(out)

        # Output layer (no upsampling).
        return self.conv_block_lst[-1](out)