import math

import torch
import torch.nn as nn

from tencentpretrain.utils.constants import *


class SinusoidalposEmbedding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.

    Implementation based on "Attention Is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    Args:
        args: configuration namespace providing ``emb_size``, ``max_seq_length``,
            ``embedding``, ``max_audio_frames`` (for speech), and ``tokenizer``.
        _: unused placeholder kept for a uniform embedding constructor signature.
    """

    def __init__(self, args, _):
        super(SinusoidalposEmbedding, self).__init__()
        if "speech" in args.embedding:
            self.max_seq_length = max(args.max_seq_length, args.max_audio_frames)
            self.arrange_sincos_cross = False
        else:
            self.max_seq_length = args.max_seq_length
            self.arrange_sincos_cross = True
        self.emb_size = args.emb_size

        # Precompute the sinusoid table: positions (rows) times frequencies (columns).
        half_dim = self.emb_size // 2
        value = math.log(10000) / (half_dim - 1)
        half_exp = torch.exp(torch.arange(half_dim, dtype=torch.float) * -value)
        half_mat = torch.arange(self.max_seq_length, dtype=torch.float).unsqueeze(
            1
        ) * half_exp.unsqueeze(0)

        if not self.arrange_sincos_cross:
            # Concatenate [sin | cos] blocks, same as huggingface/transformers and tensor2tensor.
            self.emb = torch.cat([torch.sin(half_mat), torch.cos(half_mat)], dim=1).view(
                self.max_seq_length, -1
            )
        else:
            # Interleave sin/cos on even/odd dimensions, as in "Attention Is All You Need".
            self.emb = torch.zeros(self.max_seq_length, args.emb_size)
            self.emb[:, 0::2] = torch.sin(half_mat)
            self.emb[:, 1::2] = torch.cos(half_mat)
        if self.emb_size % 2 == 1:
            # Zero-pad the last dimension when emb_size is odd.
            self.emb = torch.cat([self.emb, torch.zeros(self.max_seq_length, 1)], dim=1)
        # Zero out the row indexed by the padding token id.
        self.emb[args.tokenizer.vocab.get(PAD_TOKEN), :] = 0

    def forward(self, src, seg):
        """Embed inputs.

        Args:
            src (LongTensor): token ids of shape ``(batch_size, seq_length)``.
            seg (LongTensor or NoneType): segment ids of shape
                ``(batch_size, seq_length)``; non-zero entries mark
                non-padding positions.

        Returns:
            FloatTensor: positional embeddings of shape
            ``(batch_size, seq_length, emb_size)``.
        """
        if seg is not None:
            batch_size, seq_length = seg.size()
            device = seg.device
            # Number of non-padding positions per sequence (assumes a binary seg mask).
            no_pad_num = seg.sum(dim=-1)
        else:
            batch_size, seq_length = src.size()
            device = src.device
            no_pad_num = (src != 0).sum(dim=-1)
        emb = torch.zeros(batch_size, seq_length, self.emb_size)
        for i in range(batch_size):
            # Fill non-padding positions; the sinusoid table is read with an offset of 2.
            emb[i, :no_pad_num[i], :] = self.emb[2: no_pad_num[i] + 2]
        return emb.to(device)
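

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a hypothetical `args` namespace that carries the fields the class
# reads (embedding, emb_size, max_seq_length, max_audio_frames, tokenizer) and
# a toy tokenizer whose vocab maps PAD_TOKEN to id 0; real configurations in
# the repository may differ.
if __name__ == "__main__":
    from types import SimpleNamespace

    toy_tokenizer = SimpleNamespace(vocab={PAD_TOKEN: 0})
    args = SimpleNamespace(
        embedding=["word", "sinusoidalpos"],  # no "speech", so the interleaved layout is used
        emb_size=512,
        max_seq_length=128,
        max_audio_frames=0,
        tokenizer=toy_tokenizer,
    )
    embedding = SinusoidalposEmbedding(args, None)

    src = torch.tensor([[5, 6, 7, 0, 0]])  # batch of token ids, 0 = padding
    seg = torch.tensor([[1, 1, 1, 0, 0]])  # binary segment / padding mask
    pos_emb = embedding(src, seg)
    print(pos_emb.shape)  # torch.Size([1, 5, 512])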