"""
© Battelle Memorial Institute 2023
Made available under the GNU General Public License v 2.0

BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW.  EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.  THE ENTIRE RISK AS
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU.  SHOULD THE
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
REPAIR OR CORRECTION.
"""

import torch
import torch.nn as nn

from .positional_encoding import PositionalEncoding


class FupBERTModel(nn.Module):
    """
    A torch.nn.Module subclass that implements a custom Transformer encoder
    to create a single sequence embedding for Fup prediction.
    """

    def __init__(
        self,
        ntoken,
        ninp,
        nhead,
        nhid,
        nlayers,
        token_reduction,
        padding_idx,
        cls_idx,
        edge_idx,
        num_out,
        dropout=0.1,
    ):
        """
        Initializes a FupBERTModel object.

        Parameters
        ----------
        ntoken : int
            The maximum number of tokens the embedding layer should expect. This
            is the same as the size of the vocabulary.
        ninp : int
            The hidden dimension that should be used for embedding and input
            to the Transformer encoder.
        nhead : int
            The number of heads to use in the Transformer encoder.
        nhid : int
            The size of the hidden dimension to use throughout the Transformer
            encoder.
        nlayers : int
            The number of layers to use in a single head of the Transformer
            encoder.
        token_reduction : str
            The type of token reduction to use. This can be either 'mean' or
            'cls'.
        padding_idx : int
            The index used as padding for the input sequences.
        cls_idx : int
            The index used as the cls token for the input sequences.
        edge_idx : int
            The index used as the edge token for the input sequences.
        num_out : int
            The number of outputs to predict with the model.
        dropout : float, optional
            The fractional dropout to apply to the model. The default is 0.1.

        Returns
        -------
        None.

        """
        super(FupBERTModel, self).__init__()
        # Store the input parameters
        self.ntoken = ntoken
        self.ninp = ninp
        self.nhead = nhead
        self.nhid = nhid
        self.nlayers = nlayers
        self.token_reduction = token_reduction
        self.padding_idx = padding_idx
        self.cls_idx = cls_idx
        self.edge_idx = edge_idx
        self.num_out = num_out
        self.dropout = dropout
        # Set the model parameters
        self.model_type = "Transformer Encoder"
        self.embedding = nn.Embedding(
            self.ntoken, self.ninp, padding_idx=self.padding_idx
        )
        self.pos_encoder = PositionalEncoding(self.ninp, self.dropout)
        encoder_layers = nn.TransformerEncoderLayer(
            self.ninp,
            self.nhead,
            self.nhid,
            self.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, self.nlayers)
        self.pred_head = nn.Linear(self.ninp, self.num_out)

    def _generate_src_key_mask(self, src):
        """
        Build a boolean key padding mask that is True at padded positions so
        the Transformer encoder can ignore them during attention.
        """
        mask = src == self.padding_idx
        mask = mask.type(torch.bool)

        return mask

    def forward(self, src):
        """
        Perform a forward pass of the module.

        Parameters
        ----------
        src : tensor
            The input tensor. The shape should be (batch size, sequence length).

        Returns
        -------
        output : tensor
            The output tensor. The shape will be (batch size, num_out).

        """
        src = self.get_embeddings(src)
        output = self.pred_head(src)

        return output

    def get_embeddings(self, src):
        """
        Perform a forward pass of the module excluding the classification layers. This
        will return the embeddings from the encoder.

        Parameters
        ----------
        src : tensor
            The input tensor. The shape should be (batch size, sequence length).

        Returns
        -------
        embeds : tensor
            The output tensor of sequence embeddings. The shape will be
            (batch size, self.ninp).
        """
        src_mask = self._generate_src_key_mask(src)
        x = self.embedding(src)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x, src_key_padding_mask=src_mask)
        # Mask the data based on the token reduction strategy
        if self.token_reduction == "mean":
            pad_mask = src == self.padding_idx
            cls_mask = src == self.cls_idx
            edge_mask = src == self.edge_idx
            mask = torch.logical_or(pad_mask, cls_mask)
            mask = torch.logical_or(mask, edge_mask)
            # Apply the mask
            x[mask[:, : x.shape[1]]] = torch.nan
            # Take the mean of the embeddings
            embeds = torch.nanmean(x, dim=1)
        elif self.token_reduction == "cls":
            embeds = x[:, 0, :]
        else:
            raise ValueError(
                "Token reduction must be mean or cls. "
                "Recieved {}".format(self.token_reduction)
            )

        return embeds
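

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# vocabulary size, special-token indices, and sequence shape below are
# assumptions; real values come from the tokenizer used to build the inputs.
# Because of the relative import above, run this as part of its package
# (e.g. ``python -m <package>.<module>``) rather than as a standalone script.
if __name__ == "__main__":
    model = FupBERTModel(
        ntoken=100,              # assumed vocabulary size
        ninp=64,                 # embedding / encoder hidden dimension
        nhead=4,                 # attention heads
        nhid=128,                # feed-forward hidden dimension
        nlayers=2,               # encoder layers
        token_reduction="mean",  # average non-special token embeddings
        padding_idx=0,           # assumed special-token indices
        cls_idx=1,
        edge_idx=2,
        num_out=1,               # single Fup prediction per sequence
    )
    # Batch of 8 sequences of length 32: cls token first, 4 padding tokens
    # at the end, and random "regular" tokens (indices 3..99) in between.
    src = torch.randint(3, 100, (8, 32))
    src[:, 0] = 1
    src[:, -4:] = 0
    out = model(src)                    # shape: (8, 1)
    embeds = model.get_embeddings(src)  # shape: (8, 64)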