AnnaMats's picture
Second Push
05c9ac2
from mlagents.torch_utils import torch
from typing import List
import math
from mlagents.trainers.torch_entities.layers import (
linear_layer,
Swish,
Initialization,
LayerNorm,
)
class ConditionalEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        goal_size: int,
        hidden_size: int,
        num_layers: int,
        num_conditional_layers: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
        hyper_num_layers: int = 2,
    ):
        """
        ConditionalEncoder module. A fully connected network of which some of the
        weights are generated by a goal conditioning. Uses the HyperNetwork module to
        generate the weights of the network. Only the weights of the last
        "num_conditional_layers" layers will be generated by HyperNetworks, the others
        will use regular parameters. Every layer (regular or conditional) is followed
        by a Swish activation.
        :param input_size: The size of the input of the encoder
        :param goal_size: The size of the goal tensor that will condition the encoder
        :param hidden_size: The number of hidden units in the encoder. It is also used
        as the hidden layer size of every HyperNetwork.
        :param num_layers: The total number of layers of the encoder (both regular and
        generated by HyperNetwork)
        :param num_conditional_layers: The number of layers generated with
        hypernetworks. A value >= num_layers makes every layer conditional; a value
        <= 0 makes no layer conditional.
        :param kernel_init: The Initialization to use for the weights of the regular
        (non-conditional) layers. Conditional layers use HyperNetwork's own init.
        :param kernel_gain: The multiplier for the weights of the kernel of the
        regular (non-conditional) layers.
        :param hyper_num_layers: The number of hidden layers in each HyperNetwork.
        Defaults to 2, the value this encoder has always used.
        """
        super().__init__()
        layers: List[torch.nn.Module] = []
        # The conditional layers are the last num_conditional_layers ones, so every
        # layer index at or past this threshold is generated by a HyperNetwork.
        first_conditional = num_layers - num_conditional_layers
        prev_size = input_size
        for i in range(num_layers):
            if i >= first_conditional:
                layers.append(
                    HyperNetwork(
                        prev_size, hidden_size, goal_size, hidden_size, hyper_num_layers
                    )
                )
            else:
                layers.append(
                    linear_layer(
                        prev_size,
                        hidden_size,
                        kernel_init=kernel_init,
                        kernel_gain=kernel_gain,
                    )
                )
            layers.append(Swish())
            prev_size = hidden_size
        self.layers = torch.nn.ModuleList(layers)

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:  # type: ignore
        """
        Run input_tensor through the encoder layers in order. HyperNetwork layers
        also receive goal_tensor, from which they generate their weights; all other
        layers only receive the running activation.
        :param input_tensor: The encoder input with input_size features. When any
        layer is conditional it must be 2D (batch, input_size), because
        HyperNetwork.forward feeds it to torch.bmm after unsqueeze(1).
        :param goal_tensor: The goal conditioning passed to every HyperNetwork layer.
        :return: The activation after the last layer, with hidden_size features.
        If num_layers is 0, input_tensor is returned unchanged.
        """
        activation = input_tensor
        for layer in self.layers:
            if isinstance(layer, HyperNetwork):
                activation = layer(activation, goal_tensor)
            else:
                activation = layer(activation)
        return activation
class HyperNetwork(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hyper_input_size: int,
        layer_size: int,
        num_layers: int,
    ):
        """
        Hyper Network module. This module will use the hyper_input tensor to generate
        the weights of the main network. The main network is a single fully connected
        layer whose bias is a regular learned parameter (the bias is not generated).
        :param input_size: The size of the input of the main network
        :param output_size: The size of the output of the main network
        :param hyper_input_size: The size of the input of the hypernetwork that will
        generate the main network.
        :param layer_size: The number of hidden units in the layers of the hypernetwork
        :param num_layers: The number of hidden layers of the hypernetwork. With 0,
        the final flat output layer is applied directly after LayerNorm on the
        hyper input.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        layer_in_size = hyper_input_size
        layers: List[torch.nn.Module] = []
        for _ in range(num_layers):
            layers.append(
                linear_layer(
                    layer_in_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
            layer_in_size = layer_size
        # layer_in_size equals layer_size when num_layers >= 1 and hyper_input_size
        # when num_layers == 0; using it avoids a shape mismatch in the latter case.
        flat_output = linear_layer(
            layer_in_size,
            input_size * output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )
        # Re-initializing the weights of the last layer of the hypernetwork with
        # U(-bound, bound). This overwrites the KaimingHeNormal/kernel_gain weights
        # set just above; only the zero bias from bias_init is kept.
        bound = math.sqrt(1 / (layer_in_size * self.input_size))
        torch.nn.init.uniform_(flat_output.weight, -bound, bound)
        self.hypernet = torch.nn.Sequential(*layers, LayerNorm(), flat_output)
        # The hypernetwork will not generate the bias of the main network layer
        self.bias = torch.nn.Parameter(torch.zeros(output_size))

    def forward(
        self, input_activation: torch.Tensor, hyper_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply the main fully connected layer using weights generated from hyper_input.
        :param input_activation: A 2D tensor (batch, input_size). It must be 2D
        because it is unsqueezed to 3D and fed to torch.bmm.
        :param hyper_input: The conditioning tensor fed to the hypernetwork. The
        hypernetwork output is viewed as (-1, input_size, output_size), so the
        number of generated matrices must equal the batch size of input_activation.
        :return: A (batch, output_size) tensor equal to input_activation @ W + bias,
        where W is the weight matrix generated for each batch element.
        """
        # One flat (input_size * output_size) vector per element, reshaped to a matrix.
        output_weights = self.hypernet(hyper_input).view(
            -1, self.input_size, self.output_size
        )
        # (batch, 1, in) @ (batch, in, out) -> (batch, 1, out) -> (batch, out)
        result = (
            torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1)
            + self.bias
        )
        return result