"""
© Battelle Memorial Institute 2023

Made available under the GNU General Public License v 2.0

BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
REPAIR OR CORRECTION.
"""

import torch
import torch.nn as nn

from .positional_encoding import PositionalEncoding

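# Model overview: token indices are embedded (nn.Embedding), given positional
# information (PositionalEncoding), passed through stacked
# nn.TransformerEncoderLayer blocks, reduced to one vector per sequence (mean
# over non-special tokens, or the output at the cls position), and projected
# to num_out values by a linear prediction head.
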
class FupBERTModel(nn.Module):
    """
    A torch.nn.Module subclass that implements a custom Transformer encoder
    model to create a single embedding for Fup prediction.
    """

    def __init__(
        self,
        ntoken,
        ninp,
        nhead,
        nhid,
        nlayers,
        token_reduction,
        padding_idx,
        cls_idx,
        edge_idx,
        num_out,
        dropout=0.1,
    ):
        """
        Initializes a FupBERTModel object.

        Parameters
        ----------
        ntoken : int
            The maximum number of tokens the embedding layer should expect.
            This is the same as the size of the vocabulary.
        ninp : int
            The hidden dimension that should be used for embedding and input
            to the Transformer encoder.
        nhead : int
            The number of attention heads to use in the Transformer encoder.
        nhid : int
            The size of the feedforward hidden dimension to use in the
            Transformer encoder layers.
        nlayers : int
            The number of stacked encoder layers to use in the Transformer
            encoder.
        token_reduction : str
            The type of token reduction to use. This can be either 'mean' or
            'cls'.
        padding_idx : int
            The index used as padding for the input sequences.
        cls_idx : int
            The index used as the cls token for the input sequences.
        edge_idx : int
            The index used as the edge token for the input sequences.
        num_out : int
            The number of outputs to predict with the model.
        dropout : float, optional
            The fractional dropout to apply to the model. The default is 0.1.

        Returns
        -------
        None.

        """
        super(FupBERTModel, self).__init__()

        self.ntoken = ntoken
        self.ninp = ninp
        self.nhead = nhead
        self.nhid = nhid
        self.nlayers = nlayers
        self.token_reduction = token_reduction
        self.padding_idx = padding_idx
        self.cls_idx = cls_idx
        self.edge_idx = edge_idx
        self.num_out = num_out
        self.dropout = dropout

        self.model_type = "Transformer Encoder"
        self.embedding = nn.Embedding(
            self.ntoken, self.ninp, padding_idx=self.padding_idx
        )
        self.pos_encoder = PositionalEncoding(self.ninp, self.dropout)
        encoder_layers = nn.TransformerEncoderLayer(
            self.ninp,
            self.nhead,
            self.nhid,
            self.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, self.nlayers)
        self.pred_head = nn.Linear(self.ninp, self.num_out)

    def _generate_src_key_mask(self, src):
        """
        Build the boolean key padding mask for the Transformer encoder.
        Positions equal to padding_idx are True so attention ignores them.
        """
        mask = src == self.padding_idx
        mask = mask.type(torch.bool)

        return mask

    def forward(self, src):
        """
        Perform a forward pass of the module.

        Parameters
        ----------
        src : tensor
            The input tensor. The shape should be (batch size, sequence length).

        Returns
        -------
        output : tensor
            The output tensor. The shape will be (batch size, num_out).

        """
        src = self.get_embeddings(src)
        output = self.pred_head(src)

        return output

    def get_embeddings(self, src):
        """
        Perform a forward pass of the module excluding the prediction head.
        This will return the sequence embeddings from the encoder.

        Parameters
        ----------
        src : tensor
            The input tensor. The shape should be (batch size, sequence length).

        Returns
        -------
        embeds : tensor
            The output tensor of sequence embeddings. The shape will be
            (batch size, self.ninp).

        """
        src_mask = self._generate_src_key_mask(src)
        x = self.embedding(src)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x, src_key_padding_mask=src_mask)

        if self.token_reduction == "mean":
            # Average only over real tokens: mark padding, cls, and edge
            # positions as NaN so torch.nanmean excludes them.
            pad_mask = src == self.padding_idx
            cls_mask = src == self.cls_idx
            edge_mask = src == self.edge_idx
            mask = torch.logical_or(pad_mask, cls_mask)
            mask = torch.logical_or(mask, edge_mask)

            x[mask[:, : x.shape[1]]] = torch.nan

            embeds = torch.nanmean(x, dim=1)
        elif self.token_reduction == "cls":
            # Use the encoder output at the first position (the cls token).
            embeds = x[:, 0, :]
        else:
            raise ValueError(
                "Token reduction must be mean or cls. "
                "Received {}".format(self.token_reduction)
            )

        return embeds
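

if __name__ == "__main__":
    # Minimal usage sketch (a smoke test, not part of the published training
    # or evaluation pipeline). The vocabulary size, special-token indices, and
    # hyperparameters below are illustrative placeholders. Because this module
    # uses a relative import, run it with `python -m <package>.<module>` so
    # PositionalEncoding resolves.
    model = FupBERTModel(
        ntoken=100,
        ninp=64,
        nhead=4,
        nhid=128,
        nlayers=2,
        token_reduction="mean",
        padding_idx=0,
        cls_idx=1,
        edge_idx=2,
        num_out=1,
    )
    model.eval()

    # Batch of 8 random token sequences of length 16, with the cls index in
    # the first position and padding in the last two positions.
    dummy_src = torch.randint(3, 100, (8, 16))
    dummy_src[:, 0] = 1
    dummy_src[:, -2:] = 0

    with torch.no_grad():
        preds = model(dummy_src)  # shape: (8, 1)
        embeds = model.get_embeddings(dummy_src)  # shape: (8, 64)
    print(preds.shape, embeds.shape)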