File size: 5,134 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from mlagents.torch_utils import torch
from typing import List
import math

from mlagents.trainers.torch_entities.layers import (
    linear_layer,
    Swish,
    Initialization,
    LayerNorm,
)


class ConditionalEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        goal_size: int,
        hidden_size: int,
        num_layers: int,
        num_conditional_layers: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
    ):
        """
        Fully connected encoder whose trailing layers are conditioned on a goal.
        The stack has `num_layers` dense layers, each followed by a Swish
        activation. The final `num_conditional_layers` of them are HyperNetwork
        layers whose weights are produced from the goal tensor; every earlier
        layer is an ordinary linear layer with its own parameters.
        :param input_size: The size of the input of the encoder
        :param goal_size: The size of the goal tensor that will condition the encoder
        :param hidden_size: The number of hidden units in the encoder
        :param num_layers: The total number of layers of the encoder (both regular and
        generated by HyperNetwork)
        :param num_conditional_layers: The number of layers generated with hypernetworks
        :param kernel_init: The Initialization to use for the weights of the regular
        (non-conditional) layers
        :param kernel_gain: The multiplier for the weights of the regular layers' kernels.
        """
        super().__init__()
        # Layers with an index at or past this point are the conditional ones,
        # since the conditional layers sit at the end of the stack.
        first_conditional = num_layers - num_conditional_layers
        modules: List[torch.nn.Module] = []
        in_features = input_size
        for idx in range(num_layers):
            if idx >= first_conditional:
                # The hypernetwork itself uses hidden_size units and 2 layers.
                dense: torch.nn.Module = HyperNetwork(
                    in_features, hidden_size, goal_size, hidden_size, 2
                )
            else:
                dense = linear_layer(
                    in_features,
                    hidden_size,
                    kernel_init=kernel_init,
                    kernel_gain=kernel_gain,
                )
            # Each dense layer is immediately followed by its activation.
            modules.extend([dense, Swish()])
            in_features = hidden_size
        self.layers = torch.nn.ModuleList(modules)

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:  # type: ignore
        """
        Run input_tensor through the layer stack.
        :param input_tensor: The input to encode.
        :param goal_tensor: The conditioning tensor; it is passed only to the
        HyperNetwork layers.
        :return: The output of the last activation (or input_tensor unchanged
        when the stack is empty).
        """
        hidden = input_tensor
        for module in self.layers:
            # Only HyperNetwork layers take the goal as a second argument.
            needs_goal = isinstance(module, HyperNetwork)
            hidden = module(hidden, goal_tensor) if needs_goal else module(hidden)
        return hidden


class HyperNetwork(torch.nn.Module):
    def __init__(
        self, input_size, output_size, hyper_input_size, layer_size, num_layers
    ):
        """
        Hyper Network module. This module uses the hyper_input tensor to generate
        the weights of the main network. The main network is a single fully
        connected layer; only its weight matrix is generated, while its bias is a
        regular learned parameter.
        :param input_size: The size of the input of the main network
        :param output_size: The size of the output of the main network
        :param hyper_input_size: The size of the input of the hypernetwork that will
        generate the main network.
        :param layer_size: The number of hidden units in the layers of the hypernetwork
        :param num_layers: The number of hidden layers of the hypernetwork. May be 0,
        in which case the hyper input goes straight to the LayerNorm and the output
        projection.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        layer_in_size = hyper_input_size
        layers = []
        for _ in range(num_layers):
            layers.append(
                linear_layer(
                    layer_in_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
            layer_in_size = layer_size
        # After the loop, layer_in_size is the width actually feeding the output
        # projection: layer_size if there is at least one hidden layer, otherwise
        # hyper_input_size. Using layer_size unconditionally would break
        # num_layers == 0.
        flat_output = linear_layer(
            layer_in_size,
            input_size * output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )

        # Re-initialize the weights of the hypernetwork's last layer uniformly,
        # scaled by its fan-in times the main layer's fan-in.
        bound = math.sqrt(1 / (layer_in_size * self.input_size))
        torch.nn.init.uniform_(flat_output.weight, -bound, bound)

        self.hypernet = torch.nn.Sequential(*layers, LayerNorm(), flat_output)

        # The hypernetwork will not generate the bias of the main network layer
        self.bias = torch.nn.Parameter(torch.zeros(output_size))

    def forward(self, input_activation, hyper_input):
        """
        Apply the generated fully connected layer to input_activation.
        :param input_activation: Input of the main layer; expected to be 2D,
        (batch, input_size).
        :param hyper_input: Conditioning input of the hypernetwork; presumably its
        last dimension is hyper_input_size. Its leading dimensions are flattened
        into a batch of generated weight matrices.
        :return: Tensor of shape (batch, output_size).
        """
        # One (input_size x output_size) weight matrix per hyper_input row.
        output_weights = self.hypernet(hyper_input).view(
            -1, self.input_size, self.output_size
        )
        # (batch, 1, input_size) @ (batch', input_size, output_size).
        # matmul matches bmm when batch == batch', and also broadcasts a single
        # generated matrix (batch' == 1) across every activation row.
        result = torch.matmul(input_activation.unsqueeze(1), output_weights).squeeze(
            1
        )
        return result + self.bias