File size: 5,134 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from mlagents.torch_utils import torch
from typing import List
import math
from mlagents.trainers.torch_entities.layers import (
linear_layer,
Swish,
Initialization,
LayerNorm,
)
class ConditionalEncoder(torch.nn.Module):
    """
    Fully connected encoder whose trailing layers are conditioned on a goal.

    The encoder is a stack of ``num_layers`` layers, each followed by a Swish
    activation. The last ``num_conditional_layers`` of them are HyperNetwork
    modules (their weights are generated from the goal tensor); the layers
    before those are regular ``linear_layer`` modules with their own parameters.
    """

    def __init__(
        self,
        input_size: int,
        goal_size: int,
        hidden_size: int,
        num_layers: int,
        num_conditional_layers: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
    ):
        """
        Build the stack of regular and goal-conditioned layers.
        :param input_size: The size of the input of the encoder
        :param goal_size: The size of the goal tensor that will condition the encoder
        :param hidden_size: The number of hidden units in the encoder
        :param num_layers: The total number of layers of the encoder (both regular and
        generated by HyperNetwork)
        :param num_conditional_layers: The number of layers generated with hypernetworks
        :param kernel_init: The Initialization to use for the weights of the regular
        (non-conditional) layers
        :param kernel_gain: The multiplier for the weights of the kernel of the
        regular (non-conditional) layers.
        """
        super().__init__()
        # The conditional layers are the last num_conditional_layers ones, so
        # every layer whose index is at or past this threshold is a HyperNetwork.
        first_conditional_index = num_layers - num_conditional_layers
        modules: List[torch.nn.Module] = []
        in_features = input_size
        for layer_index in range(num_layers):
            layer: torch.nn.Module
            if layer_index >= first_conditional_index:
                layer = HyperNetwork(
                    in_features, hidden_size, goal_size, hidden_size, 2
                )
            else:
                layer = linear_layer(
                    in_features,
                    hidden_size,
                    kernel_init=kernel_init,
                    kernel_gain=kernel_gain,
                )
            modules.extend((layer, Swish()))
            # Every layer after the first one consumes hidden_size features.
            in_features = hidden_size
        self.layers = torch.nn.ModuleList(modules)

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:  # type: ignore
        """
        Run the input through every layer in order.
        :param input_tensor: The tensor to encode
        :param goal_tensor: The goal tensor, passed only to the HyperNetwork layers
        :return: The output of the last layer of the encoder
        """
        hidden = input_tensor
        for module in self.layers:
            # Only HyperNetwork layers take the goal as a second argument.
            takes_goal = isinstance(module, HyperNetwork)
            hidden = module(hidden, goal_tensor) if takes_goal else module(hidden)
        return hidden
class HyperNetwork(torch.nn.Module):
    """
    Single fully connected layer whose weight matrix is generated, per sample,
    by a small network (the hypernetwork) fed with a separate hyper input.
    The bias of the generated layer is a regular learned parameter.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        hyper_input_size: int,
        layer_size: int,
        num_layers: int,
    ):
        """
        Hyper Network module. This module will use the hyper_input tensor to generate
        the weights of the main network. The main network is a single fully connected
        layer.
        :param input_size: The size of the input of the main network
        :param output_size: The size of the output of the main network
        :param hyper_input_size: The size of the input of the hypernetwork that will
        generate the main network.
        :param layer_size: The number of hidden units in the layers of the hypernetwork
        :param num_layers: The number of layers of the hypernetwork
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        # Width of the features entering the next layer to be built.
        layer_in_size = hyper_input_size
        layers: List[torch.nn.Module] = []
        for _ in range(num_layers):
            layers.append(
                linear_layer(
                    layer_in_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
            layer_in_size = layer_size

        # The final projection emits one flat vector holding all
        # input_size * output_size weights of the main layer. Its fan-in is
        # layer_in_size rather than layer_size: the two are equal whenever
        # num_layers >= 1, but with num_layers == 0 the projection reads the
        # hyper input directly and must be sized accordingly.
        flat_output = linear_layer(
            layer_in_size,
            input_size * output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )
        # Re-initializing the weights of the last layer of the hypernetwork
        # (this overrides the kernel init requested just above). The bound
        # shrinks with both the projection fan-in and the main layer fan-in.
        bound = math.sqrt(1 / (layer_in_size * self.input_size))
        flat_output.weight.data.uniform_(-bound, bound)

        self.hypernet = torch.nn.Sequential(*layers, LayerNorm(), flat_output)

        # The hypernetwork will not generate the bias of the main network layer
        self.bias = torch.nn.Parameter(torch.zeros(output_size))

    def forward(
        self, input_activation: torch.Tensor, hyper_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply the generated fully connected layer to the input.
        :param input_activation: The input of the main network, one row of
        input_size features per sample
        :param hyper_input: The input of the hypernetwork, one row per sample
        (presumably hyper_input_size features each; set by the caller)
        :return: A tensor with one row of output_size features per sample
        """
        # One flat weight vector per sample, reshaped into a per-sample
        # (input_size, output_size) matrix.
        output_weights = self.hypernet(hyper_input)
        output_weights = output_weights.view(-1, self.input_size, self.output_size)
        # Batched product: each sample is multiplied by its own weight matrix.
        # The input is lifted to (batch, 1, input_size) for bmm and the
        # singleton dimension is dropped again afterwards.
        projected = torch.bmm(input_activation.unsqueeze(1), output_weights)
        return projected.squeeze(1) + self.bias
|