File size: 1,245 Bytes
ec9a6bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Per-position MLP built from kernel-size-1 ``Conv1d`` layers.

    One layer, at index ``len(dims) // 2``, is a skip layer: its input is the
    running activation concatenated with the original network input along
    dim 1, so it is constructed with ``dims[l] + dims[0]`` input channels.

    Args:
        dims: Sequence of channel sizes. ``dims[0]`` is the input channel
            count and ``dims[-1]`` the output channel count; ``len(dims) - 1``
            convolution layers are created.
        last_op: Optional callable applied to the output of the final layer
            (skipped when falsy, e.g. ``None``).
    """

    def __init__(self, dims, last_op=None):
        super().__init__()

        self.dims = dims
        # Index of the layer that also receives the raw input (skip connection).
        self.skip_layer = [len(dims) // 2]
        self.last_op = last_op

        self.layers = []
        for l in range(len(dims) - 1):
            in_channels = dims[l]
            if l in self.skip_layer:
                # The skip layer sees the activation plus the original input.
                in_channels += dims[0]
            conv = nn.Conv1d(in_channels, dims[l + 1], 1)
            self.layers.append(conv)
            # `self.layers` is a plain list, so each conv is registered by name
            # to make its parameters visible to the module. The "conv%d" names
            # are the state-dict keys and must not change.
            self.add_module("conv%d" % l, conv)

    def forward(self, latet_code, return_all=False):
        """Run the network on ``latet_code``.

        Args:
            latet_code: Input tensor. Channels are concatenated along dim 1 at
                the skip layer, so a batched ``(N, dims[0], L)`` layout is
                assumed -- TODO confirm against callers.
            return_all: When true, return a list holding the final output
                instead of the bare tensor.

        Returns:
            The output tensor, or ``[output]`` when ``return_all`` is true.
        """
        y = latet_code
        num_layers = len(self.layers)
        for idx in range(num_layers):
            # Look the layer up through the registered submodules so that a
            # module swapped in under the same name is the one that runs.
            conv = self._modules['conv' + str(idx)]
            if idx in self.skip_layer:
                y = torch.cat([y, latet_code], 1)
            y = conv(y)
            # Leaky ReLU follows every layer except the last one.
            if idx != num_layers - 1:
                y = F.leaky_relu(y)

        if self.last_op:
            y = self.last_op(y)

        if return_all:
            # Previously the output was only collected when `last_op` was set,
            # so callers got an empty list otherwise; always include it.
            return [y]
        return y