"""
This file implements several neural network modules for merging and transforming input data. These components
are useful when a model consumes data from multiple modalities, or when several intermediate embedded
representations need to be merged during the forward pass of a model.

The main classes defined in this file are:

- BilinearGeneral: Applies a bilinear transformation to a pair of inputs, as described in the paper
  "Multiplicative Interactions and Where to Find Them" (ICLR 2020, https://openreview.net/forum?id=rylnK6VtDH).
  The transformation combines two input features into an output feature and includes a bias term.

- TorchBilinearCustomized: A bilinear layer similar to the one provided by PyTorch (torch.nn.Bilinear), but with
  customized parameter initialization. It can be used as an alternative to the BilinearGeneral class.

- TorchBilinear: A simple alias of PyTorch's built-in nn.Bilinear module. It provides the same functionality as
  nn.Bilinear but within the structure of the current module.

- FiLM: A Feature-wise Linear Modulation (FiLM) layer. FiLM layers apply a feature-wise affine transformation to
  the input data, conditioned on additional context information.

- GatingType: An enumeration that defines the gating mechanisms available to the merging modules.

- SumMerge: Merges input streams by simple summation.

- VectorMerge: A more flexible merging mechanism for vector streams. The streams are first transformed using
  layer normalization, a ReLU activation, and a linear layer, and are then merged either by summation or through
  a gating mechanism.

These classes are implemented with PyTorch and NumPy and inherit from PyTorch's nn.Module, which makes them
compatible with the rest of PyTorch's neural network modules and utilities. They can serve as building blocks in
more complex deep learning architectures.
"""

import enum
import math
from collections import OrderedDict
from typing import List, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class BilinearGeneral(nn.Module):
    """
    Overview:
        Bilinear implementation as in: Multiplicative Interactions and Where to Find Them, ICLR 2020,
        https://openreview.net/forum?id=rylnK6VtDH.
    Interfaces:
        ``__init__``, ``forward``
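    Examples:
        A minimal usage sketch; the feature sizes and batch size below are arbitrary illustrations:

        >>> layer = BilinearGeneral(in1_features=4, in2_features=8, out_features=16)
        >>> x = torch.randn(2, 4)
        >>> z = torch.randn(2, 8)
        >>> out = layer(x, z)
        >>> out.shape
        torch.Size([2, 16])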
    """

    def __init__(self, in1_features: int, in2_features: int, out_features: int):
        """
        Overview:
            Initialize the Bilinear layer.
        Arguments:
            - in1_features (:obj:`int`): The size of each first input sample.
            - in2_features (:obj:`int`): The size of each second input sample.
            - out_features (:obj:`int`): The size of each output sample.
        """
        super(BilinearGeneral, self).__init__()
        # Bilinear weight W, the two linear weights U and V, and a bias term b.
        self.W = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
        self.U = nn.Parameter(torch.Tensor(out_features, in2_features))
        self.V = nn.Parameter(torch.Tensor(out_features, in1_features))
        self.b = nn.Parameter(torch.Tensor(out_features))
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.reset_parameters()

    def reset_parameters(self):
        """
        Overview:
            Initialize the parameters of the Bilinear layer with a uniform distribution.
        """
        stdv = 1. / np.sqrt(self.in1_features)
        self.W.data.uniform_(-stdv, stdv)
        self.U.data.uniform_(-stdv, stdv)
        self.V.data.uniform_(-stdv, stdv)
        self.b.data.uniform_(-stdv, stdv)

    def forward(self, x: torch.Tensor, z: torch.Tensor):
        """
        Overview:
            Compute the bilinear function.
        Arguments:
            - x (:obj:`torch.Tensor`): The first input tensor, shape (batch_size, in1_features).
            - z (:obj:`torch.Tensor`): The second input tensor, shape (batch_size, in2_features).
        Returns:
            - out (:obj:`torch.Tensor`): The output tensor, shape (batch_size, out_features).
        """
        # Bilinear term: for each output feature k, x^T W_k z.
        out_W = torch.einsum('bi,kij,bj->bk', x, self.W, z)
        # Linear term on the second input: U z.
        out_U = z.matmul(self.U.t())
        # Linear term on the first input: V x.
        out_V = x.matmul(self.V.t())
        # Sum of the three terms plus the bias.
        out = out_W + out_U + out_V + self.b
        return out


class TorchBilinearCustomized(nn.Module):
    """
    Overview:
        Customized bilinear layer, similar to ``torch.nn.Bilinear`` but with its own parameter initialization.
    Interfaces:
        ``__init__``, ``forward``
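    Examples:
        A minimal usage sketch; the feature sizes and batch size below are arbitrary illustrations:

        >>> layer = TorchBilinearCustomized(in1_features=4, in2_features=8, out_features=16)
        >>> x = torch.randn(2, 4)
        >>> z = torch.randn(2, 8)
        >>> out = layer(x, z)
        >>> out.shape
        torch.Size([2, 16])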
    """

    def __init__(self, in1_features: int, in2_features: int, out_features: int):
        """
        Overview:
            Initialize the Bilinear layer.
        Arguments:
            - in1_features (:obj:`int`): The size of each first input sample.
            - in2_features (:obj:`int`): The size of each second input sample.
            - out_features (:obj:`int`): The size of each output sample.
        """
        super(TorchBilinearCustomized, self).__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Overview:
            Initialize the parameters of the Bilinear layer with a uniform distribution.
        """
        bound = 1 / math.sqrt(self.in1_features)
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, z):
        """
        Overview:
            Compute the bilinear function.
        Arguments:
            - x (:obj:`torch.Tensor`): The first input tensor, shape (batch_size, in1_features).
            - z (:obj:`torch.Tensor`): The second input tensor, shape (batch_size, in2_features).
        Returns:
            - out (:obj:`torch.Tensor`): The output tensor, shape (batch_size, out_features), with the trailing \
                dimension squeezed away when ``out_features == 1``.
        """
        # Bilinear term plus bias; ``o`` indexes output features.
        out = torch.einsum('bi,oij,bj->bo', x, self.weight, z) + self.bias
        # ``squeeze(-1)`` only takes effect when ``out_features == 1``.
        return out.squeeze(-1)


"""
Overview:
    Implementation of the Bilinear layer as in PyTorch:
    https://pytorch.org/docs/stable/generated/torch.nn.Bilinear.html#torch.nn.Bilinear
Arguments:
    - in1_features (:obj:`int`): The size of each first input sample.
    - in2_features (:obj:`int`): The size of each second input sample.
    - out_features (:obj:`int`): The size of each output sample.
    - bias (:obj:`bool`): If set to ``False``, the layer will not learn an additive bias. Default: ``True``.
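Examples:
    A minimal usage sketch; the feature sizes and batch size below are arbitrary illustrations:

    >>> layer = TorchBilinear(in1_features=4, in2_features=8, out_features=16)
    >>> x = torch.randn(2, 4)
    >>> z = torch.randn(2, 8)
    >>> out = layer(x, z)
    >>> out.shape
    torch.Size([2, 16])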
"""
TorchBilinear = nn.Bilinear


class FiLM(nn.Module):
    """
    Overview:
        Feature-wise Linear Modulation (FiLM) layer.
        This layer applies a feature-wise affine transformation to its input, conditioned on a context vector.
    Interfaces:
        ``__init__``, ``forward``
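    Examples:
        A minimal usage sketch; the feature and context sizes below are arbitrary illustrations:

        >>> film = FiLM(feature_dim=16, context_dim=8)
        >>> feature = torch.randn(2, 16)
        >>> context = torch.randn(2, 8)
        >>> out = film(feature, context)
        >>> out.shape
        torch.Size([2, 16])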
    """

    def __init__(self, feature_dim: int, context_dim: int):
        """
        Overview:
            Initialize the FiLM layer.
        Arguments:
            - feature_dim (:obj:`int`): The dimension of the input feature vector.
            - context_dim (:obj:`int`): The dimension of the input context vector.
        """
        super(FiLM, self).__init__()
        # The context is projected to a scale (gamma) and a shift (beta) for every feature dimension.
        self.context_layer = nn.Linear(context_dim, 2 * feature_dim)

    def forward(self, feature: torch.Tensor, context: torch.Tensor):
        """
        Overview:
            Forward propagation.
        Arguments:
            - feature (:obj:`torch.Tensor`): The input feature, shape (batch_size, feature_dim).
            - context (:obj:`torch.Tensor`): The input context, shape (batch_size, context_dim).
        Returns:
            - conditioned_feature (:obj:`torch.Tensor`): The output feature after FiLM, shape \
                (batch_size, feature_dim).
        """
        # Project the context to 2 * feature_dim values.
        out = self.context_layer(context)
        # Split the projection into the scale (gamma) and shift (beta) parameters.
        gamma, beta = torch.split(out, out.shape[1] // 2, dim=1)
        # Apply the feature-wise affine transformation.
        conditioned_feature = gamma * feature + beta
        return conditioned_feature


class GatingType(enum.Enum):
    """
    Overview:
        Enum class defining the gating mechanisms used when aggregating tensors in the merging modules:
        ``NONE`` disables gating (streams are simply summed), ``GLOBAL`` learns a single scalar gate per stream,
        and ``POINTWISE`` learns a separate gate for every output feature of each stream.
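    Examples:
        A sketch of how the choice affects ``VectorMerge`` (the sizes below are arbitrary illustrations):

        >>> m = VectorMerge({'a': 4, 'b': 8}, output_size=16, gating_type=GatingType.GLOBAL)
        >>> m.gate_size
        1
        >>> m = VectorMerge({'a': 4, 'b': 8}, output_size=16, gating_type=GatingType.POINTWISE)
        >>> m.gate_size
        16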
    """
    NONE = 'none'
    GLOBAL = 'global'
    POINTWISE = 'pointwise'


class SumMerge(nn.Module):
    """
    Overview:
        A PyTorch module that merges a list of tensors by computing their sum. All input tensors must have the
        same size. This module can work with any type of tensor (vector, units or visual).
    Interfaces:
        ``__init__``, ``forward``
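    Examples:
        A minimal usage sketch; the tensor shapes below are arbitrary illustrations:

        >>> merge = SumMerge()
        >>> tensors = [torch.randn(2, 16) for _ in range(3)]
        >>> merge(tensors).shape
        torch.Size([2, 16])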
    """

    def forward(self, tensors: List[Tensor]) -> Tensor:
        """
        Overview:
            Forward pass of the SumMerge module, which sums the input tensors.
        Arguments:
            - tensors (:obj:`List[Tensor]`): List of input tensors to be summed. All tensors must have the same size.
        Returns:
            - summed (:obj:`Tensor`): Tensor resulting from the sum of all input tensors.
        """
        # Stack the tensors along a new leading dimension, then reduce that dimension by summation.
        stacked = torch.stack(tensors, dim=0)
        summed = torch.sum(stacked, dim=0)
        return summed


class VectorMerge(nn.Module):
    """
    Overview:
        Merges multiple vector streams. Each stream is first transformed through layer normalization, a ReLU
        activation, and a linear layer, and the transformed streams are then summed. The streams do not need to
        have the same size. Gating can also be used before the sum.
    Interfaces:
        ``__init__``, ``encode``, ``_compute_gate``, ``forward``

    .. note::
        For more details about the gating types, please refer to the GatingType enum class.
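
    Examples:
        A minimal usage sketch; the stream names and sizes below are arbitrary illustrations:

        >>> merge = VectorMerge({'scalar': 4, 'entity': 8}, output_size=16, gating_type=GatingType.POINTWISE)
        >>> inputs = {'scalar': torch.randn(2, 4), 'entity': torch.randn(2, 8)}
        >>> merge(inputs).shape
        torch.Size([2, 16])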
    """

    def __init__(
        self,
        input_sizes: Dict[str, int],
        output_size: int,
        gating_type: GatingType = GatingType.NONE,
        use_layer_norm: bool = True,
    ):
        """
        Overview:
            Initialize the `VectorMerge` module.
        Arguments:
            - input_sizes (:obj:`Dict[str, int]`): A dictionary mapping input names to their sizes. \
                The size is a single integer for 1D inputs, or `None` for 0D inputs. \
                If an input size is `None`, we assume it's `()`.
            - output_size (:obj:`int`): The size of the output vector.
            - gating_type (:obj:`GatingType`): The type of gating mechanism to use. Default is `GatingType.NONE`.
            - use_layer_norm (:obj:`bool`): Whether to use layer normalization. Default is `True`.
        """
        super().__init__()
        self._input_sizes = OrderedDict(input_sizes)
        self._output_size = output_size
        self._gating_type = gating_type
        self._use_layer_norm = use_layer_norm

        if self._use_layer_norm:
            self._layer_norms = nn.ModuleDict()
        else:
            self._layer_norms = None

        # One (layer norm +) linear projection per input stream; 0D inputs are treated as 1D inputs of size 1.
        self._linears = nn.ModuleDict()
        for name, size in self._input_sizes.items():
            linear_input_size = size if size is not None and size > 0 else 1
            if self._use_layer_norm:
                self._layer_norms[name] = nn.LayerNorm(linear_input_size)
            self._linears[name] = nn.Linear(linear_input_size, self._output_size)

        # The gate size is 1 for global gating (a single scalar per stream) and ``output_size`` for
        # pointwise gating (one gate value per output feature).
        self._gating_linears = nn.ModuleDict()
        if self._gating_type is GatingType.GLOBAL:
            self.gate_size = 1
        elif self._gating_type is GatingType.POINTWISE:
            self.gate_size = self._output_size
        elif self._gating_type is GatingType.NONE:
            self._gating_linears = None
        else:
            raise ValueError(f'Gating type {self._gating_type} is not supported')

        if self._gating_linears is not None:
            if len(self._input_sizes) == 2:
                # With exactly two streams, a single sigmoid gate (and its complement) is enough.
                for name, size in self._input_sizes.items():
                    gate_input_size = size if size is not None and size > 0 else 1
                    gating_layer = nn.Linear(gate_input_size, self.gate_size)
                    torch.nn.init.normal_(gating_layer.weight, std=0.005)
                    torch.nn.init.constant_(gating_layer.bias, 0.0)
                    self._gating_linears[name] = gating_layer
            else:
                # With more than two streams, the gates are normalized by a softmax over the streams,
                # so each gating layer predicts one gate per stream.
                for name, size in self._input_sizes.items():
                    gate_input_size = size if size is not None and size > 0 else 1
                    gating_layer = nn.Linear(gate_input_size, len(self._input_sizes) * self.gate_size)
                    torch.nn.init.normal_(gating_layer.weight, std=0.005)
                    torch.nn.init.constant_(gating_layer.bias, 0.0)
                    self._gating_linears[name] = gating_layer

    def encode(self, inputs: Dict[str, Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Overview:
            Encode the input tensors using layer normalization, ReLU, and linear transformations.
        Arguments:
            - inputs (:obj:`Dict[str, Tensor]`): The input tensors.
        Returns:
            - gates (:obj:`List[Tensor]`): The gate inputs, i.e. the features after layer norm and ReLU.
            - outputs (:obj:`List[Tensor]`): The output tensors after the linear transformations.
        """
        gates, outputs = [], []
        for name, size in self._input_sizes.items():
            feature = inputs[name]
            # Treat 0D inputs (size ``None`` or 0) as 1D inputs of size 1.
            if (size is None or size <= 0) and feature.dim() == 1:
                feature = feature.unsqueeze(-1)
            feature = feature.to(torch.float32)
            if self._use_layer_norm and name in self._layer_norms:
                feature = self._layer_norms[name](feature)
            feature = F.relu(feature)
            gates.append(feature)
            outputs.append(self._linears[name](feature))
        return gates, outputs

    def _compute_gate(
        self,
        init_gate: List[Tensor],
    ) -> List[Tensor]:
        """
        Overview:
            Compute the gate values based on the initial gate values.
        Arguments:
            - init_gate (:obj:`List[Tensor]`): The initial gate values, one per input stream.
        Returns:
            - gate (:obj:`List[Tensor]`): The computed gate values, one per input stream.
        """
        if len(self._input_sizes) == 2:
            # Two streams: a single sigmoid gate and its complement.
            gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)]
            gate = sum(gate)
            sigmoid = torch.sigmoid(gate)
            gate = [sigmoid, 1.0 - sigmoid]
        else:
            # More than two streams: softmax over the stream dimension.
            gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)]
            gate = sum(gate)
            gate = gate.reshape([-1, len(self._input_sizes), self.gate_size])
            gate = F.softmax(gate, dim=1)
            assert gate.shape[1] == len(self._input_sizes)
            gate = [gate[:, i] for i in range(len(self._input_sizes))]
        return gate

    def forward(self, inputs: Dict[str, Tensor]) -> Tensor:
        """
        Overview:
            Forward pass through the VectorMerge module.
        Arguments:
            - inputs (:obj:`Dict[str, Tensor]`): The input tensors.
        Returns:
            - output (:obj:`Tensor`): The output tensor after passing through the module.
        """
        gates, outputs = self.encode(inputs)
        if len(outputs) == 1:
            # A single stream needs no merging.
            output = outputs[0]
        elif self._gating_type is GatingType.NONE:
            output = sum(outputs)
        else:
            # Weight each stream by its gate before summing.
            gate = self._compute_gate(gates)
            data = [g * d for g, d in zip(gate, outputs)]
            output = sum(data)
        return output