from mlagents.torch_utils import torch
from typing import List
import math

from mlagents.trainers.torch_entities.layers import (
    linear_layer,
    Swish,
    Initialization,
    LayerNorm,
)


class ConditionalEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        goal_size: int,
        hidden_size: int,
        num_layers: int,
        num_conditional_layers: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
    ):
        """
        ConditionalEncoder module. A fully connected network in which some of
        the weights are generated from a goal tensor by HyperNetwork modules.
        Only the weights of the last "num_conditional_layers" layers are
        generated this way; the remaining layers use regular parameters.
        :param input_size: The size of the input of the encoder
        :param goal_size: The size of the goal tensor that will condition the encoder
        :param hidden_size: The number of hidden units in the encoder
        :param num_layers: The total number of layers of the encoder (both regular
        and generated by HyperNetwork)
        :param num_conditional_layers: The number of layers generated with
        HyperNetworks
        :param kernel_init: The Initialization to use for the weights of the
        regular layers
        :param kernel_gain: The multiplier for the weights of the kernel.
        """
        super().__init__()
        layers: List[torch.nn.Module] = []
        prev_size = input_size
        for i in range(num_layers):
            if num_layers - i <= num_conditional_layers:
                # Layer i is one of the last num_conditional_layers, so its
                # weights are generated by a HyperNetwork from the goal tensor.
                layers.append(
                    HyperNetwork(prev_size, hidden_size, goal_size, hidden_size, 2)
                )
            else:
                layers.append(
                    linear_layer(
                        prev_size,
                        hidden_size,
                        kernel_init=kernel_init,
                        kernel_gain=kernel_gain,
                    )
                )
            layers.append(Swish())
            prev_size = hidden_size
        self.layers = torch.nn.ModuleList(layers)

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:
        activation = input_tensor
        for layer in self.layers:
            if isinstance(layer, HyperNetwork):
                # Conditional layers consume the goal tensor as well.
                activation = layer(activation, goal_tensor)
            else:
                activation = layer(activation)
        return activation
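
# Illustrative usage (a hedged sketch, not part of the trainer code; the sizes
# below are arbitrary assumptions):
#
#   encoder = ConditionalEncoder(
#       input_size=64, goal_size=8, hidden_size=128,
#       num_layers=3, num_conditional_layers=1,
#   )
#   obs = torch.ones(32, 64)    # (batch, input_size)
#   goal = torch.ones(32, 8)    # (batch, goal_size)
#   out = encoder(obs, goal)    # -> (32, 128), i.e. (batch, hidden_size)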


class HyperNetwork(torch.nn.Module):
    def __init__(
        self, input_size, output_size, hyper_input_size, layer_size, num_layers
    ):
""" |
|
Hyper Network module. This module will use the hyper_input tensor to generate |
|
the weights of the main network. The main network is a single fully connected |
|
layer. |
|
:param input_size: The size of the input of the main network |
|
:param output_size: The size of the output of the main network |
|
:param hyper_input_size: The size of the input of the hypernetwork that will |
|
generate the main network. |
|
:param layer_size: The number of hidden units in the layers of the hypernetwork |
|
:param num_layers: The number of layers of the hypernetwork |
|
""" |
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        layer_in_size = hyper_input_size
        layers = []
        for _ in range(num_layers):
            layers.append(
                linear_layer(
                    layer_in_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
            layer_in_size = layer_size
        flat_output = linear_layer(
            layer_size,
            input_size * output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )

        # Re-initialize the last layer with a tighter uniform bound so the
        # generated weights start small, reducing the variance of the main
        # network's output early in training.
        bound = math.sqrt(1 / (layer_size * self.input_size))
        flat_output.weight.data.uniform_(-bound, bound)

        self.hypernet = torch.nn.Sequential(*layers, LayerNorm(), flat_output)

        # Only the weight matrix of the main layer is generated; its bias is a
        # regular learned parameter shared across goals.
        self.bias = torch.nn.Parameter(torch.zeros(output_size))

    def forward(self, input_activation, hyper_input):
        # Generate one weight matrix per sample from the conditioning input.
        output_weights = self.hypernet(hyper_input)
        output_weights = output_weights.view(-1, self.input_size, self.output_size)

        # Batched vector-matrix product: (B, 1, in) x (B, in, out) -> (B, out).
        result = (
            torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1)
            + self.bias
        )
        return result
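
    # Shape walkthrough for forward (illustrative, batch size B assumed):
    #   hyper_input       (B, hyper_input_size)
    #   output_weights    (B, input_size * output_size)
    #                     -> view -> (B, input_size, output_size)
    #   input_activation  (B, input_size) -> unsqueeze -> (B, 1, input_size)
    #   torch.bmm(...)    (B, 1, output_size) -> squeeze -> (B, output_size)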
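

# Minimal smoke test (a sketch under assumed sizes, not part of the trainer
# API): run this module directly to check tensor shapes end to end.
if __name__ == "__main__":
    enc = ConditionalEncoder(
        input_size=10,
        goal_size=4,
        hidden_size=16,
        num_layers=3,
        num_conditional_layers=2,
    )
    obs = torch.zeros(5, 10)
    goal = torch.zeros(5, 4)
    print(enc(obs, goal).shape)  # expected: torch.Size([5, 16])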