""" © Battelle Memorial Institute 2023 Made available under the GNU General Public License v 2.0 BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. """ import numpy as np import torch import torch.nn as nn class PositionalEncoding(nn.Module): """ A class that extends torch.nn.Module that applies positional encoding for use in the Transformer architecture. """ def __init__(self, d_model, dropout=0.1, max_len=5000): """ Initializes a PositionalEncoding object. Parameters ---------- d_model : int The size of the model's embedding dimension. dropout : float, optional The fractional dropout to apply to the embedding. The default is 0.1. max_len : int, optional The maximum potential input sequnce length. The default is 5000. Returns ------- None. """ super(PositionalEncoding, self).__init__() # Create the dropout self.dropout = nn.Dropout(p=dropout) # Create the encoding pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp( torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) self.register_buffer("pe", pe) def forward(self, x): """ Perform a forward pass of the module. Parameters ---------- x : tensor The input tensor to apply the positional encoding to. Returns ------- tensor The resulting tensor after applying the positional encoding to the input. """ x = x + self.pe[:, : x.size(1)] return self.dropout(x)