# FupBERT: fup_bert_model.py
"""
© Battelle Memorial Institute 2023
Made available under the GNU General Public License v 2.0
BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
REPAIR OR CORRECTION.
"""
import torch
import torch.nn as nn
from .positional_encoding import PositionalEncoding


class FupBERTModel(nn.Module):
"""
A class that extends torch.nn.Module that implements a custom Transformer
encoder model to create a single embedding for Fup prediction.
"""
def __init__(
self,
ntoken,
ninp,
nhead,
nhid,
nlayers,
token_reduction,
padding_idx,
cls_idx,
edge_idx,
num_out,
dropout=0.1,
):
"""
Initializes a FubBERT object.
Parameters
----------
ntoken : int
The maximum number of tokens the embedding layer should expect. This
is the same as the size of the vocabulary.
ninp : int
The hidden dimension that should be used for embedding and input
to the Transformer encoder.
nhead : int
The number of heads to use in the Transformer encoder.
nhid : int
The size of the hidden dimension to use throughout the Transformer
encoder.
nlayers : int
The number of layers to use in a single head of the Transformer
encoder.
token_reduction : str
The type of token reduction to use. This can be either 'mean' or
'cls'.
padding_idx : int
The index used as padding for the input sequences.
cls_idx : int
The index used as the cls token for the input sequences.
edge_idx : int
The index used as the edge token for the input sequences.
num_out : int
The number of outputs to predict with the model.
dropout : float, optional
            The fractional dropout to apply to the model. The default is 0.1.

        Returns
-------
None.
"""
        super().__init__()
# Store the input parameters
self.ntoken = ntoken
self.ninp = ninp
self.nhead = nhead
self.nhid = nhid
self.nlayers = nlayers
self.token_reduction = token_reduction
self.padding_idx = padding_idx
self.cls_idx = cls_idx
self.edge_idx = edge_idx
self.num_out = num_out
self.dropout = dropout
# Set the model parameters
self.model_type = "Transformer Encoder"
self.embedding = nn.Embedding(
self.ntoken, self.ninp, padding_idx=self.padding_idx
)
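        # Inject token-position information into the embeddings before they
        # reach the encoder stack.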
self.pos_encoder = PositionalEncoding(self.ninp, self.dropout)
encoder_layers = nn.TransformerEncoderLayer(
self.ninp,
self.nhead,
self.nhid,
self.dropout,
activation="gelu",
batch_first=True,
)
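        # Stack nlayers identical layers to form the full encoder.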
self.transformer_encoder = nn.TransformerEncoder(encoder_layers, self.nlayers)
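        # Linear head mapping the pooled sequence embedding to num_out values.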
self.pred_head = nn.Linear(self.ninp, self.num_out)

    def _generate_src_key_mask(self, src):
        """
        Create a boolean key padding mask for src that is True at padding
        positions, following the src_key_padding_mask convention of
        torch.nn.TransformerEncoder (True positions are ignored by attention).
        """
        mask = src == self.padding_idx
mask = mask.type(torch.bool)
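        # Example (hypothetical values): with padding_idx = 0 and
        # src = [[5, 7, 0, 0]], the mask is [[False, False, True, True]].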
return mask

    def forward(self, src):
"""
        Perform a forward pass of the module.

        Parameters
        ----------
        src : tensor
            The input tensor. The shape should be (batch size, sequence length).

        Returns
        -------
        output : tensor
            The output tensor. The shape will be (batch size, num_out).
"""
src = self.get_embeddings(src)
output = self.pred_head(src)
return output

    def get_embeddings(self, src):
"""
        Perform a forward pass of the module excluding the prediction head.
        This returns the sequence embeddings from the encoder.

        Parameters
        ----------
        src : tensor
            The input tensor. The shape should be (batch size, sequence length).

        Returns
        -------
        embeds : tensor
            The output tensor of sequence embeddings. The shape will be
            (batch size, self.ninp).
"""
src_mask = self._generate_src_key_mask(src)
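        # Embed tokens, add positional information, and encode the sequence
        # with padding positions excluded from attention.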
x = self.embedding(src)
x = self.pos_encoder(x)
x = self.transformer_encoder(x, src_key_padding_mask=src_mask)
# Mask the data based on the token reduction strategy
if self.token_reduction == "mean":
pad_mask = src == self.padding_idx
cls_mask = src == self.cls_idx
edge_mask = src == self.edge_idx
mask = torch.logical_or(pad_mask, cls_mask)
mask = torch.logical_or(mask, edge_mask)
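            # Masked positions are set to NaN so that nanmean below averages
            # only real sequence tokens, excluding padding, cls, and edge
            # tokens from the pooled embedding.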
# Apply the mask
x[mask[:, : x.shape[1]]] = torch.nan
# Take the mean of the embeddings
embeds = torch.nanmean(x, dim=1)
elif self.token_reduction == "cls":
embeds = x[:, 0, :]
else:
raise ValueError(
"Token reduction must be mean or cls. "
"Recieved {}".format(self.token_reduction)
)
return embeds
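

# A minimal usage sketch (illustrative only). The vocabulary size, special
# token indices, and hyperparameters below are hypothetical and do not
# correspond to any trained FupBERT checkpoint. Because of the relative
# import above, this module must be run as part of its package
# (e.g. with python -m) for this block to execute.
if __name__ == "__main__":
    model = FupBERTModel(
        ntoken=100,  # hypothetical vocabulary size
        ninp=64,
        nhead=4,
        nhid=128,
        nlayers=2,
        token_reduction="mean",
        padding_idx=0,  # hypothetical special-token indices
        cls_idx=1,
        edge_idx=2,
        num_out=1,
    )
    model.eval()
    # A batch of two token sequences, padded to length 6 with padding_idx.
    src = torch.tensor([[1, 5, 7, 2, 0, 0], [1, 9, 2, 0, 0, 0]])
    with torch.no_grad():
        preds = model(src)  # shape: (2, 1)
        embeds = model.get_embeddings(src)  # shape: (2, 64)
    print(preds.shape, embeds.shape)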