import math

import torch
import torch.nn as nn


class SpeechEmbedding(nn.Module):
    """Embed speech inputs via convolutional subsampling, optionally scaled
    for use with sinusoidal positional encodings."""

    def __init__(self, args, _):
        super(SpeechEmbedding, self).__init__()
        self.conv = Conv1dModule(args)
        self.sinusoidalpos = False
        self.emb_size = args.emb_size
        if "sinusoidalpos" in args.embedding:
            self.sinusoidalpos = True

    def forward(self, src, _):
        """Embed inputs.

        Args:
            src (FloatTensor): Sequence of acoustic feature vectors
                ``(batch_size, seq_len, self.dim)``
        """
        speech_emb = self.conv(src)
        if self.sinusoidalpos:
            # Scale by sqrt(emb_size) so the embeddings and the sinusoidal
            # positional encodings have comparable magnitudes.
            return speech_emb * math.sqrt(self.emb_size)
        else:
            return speech_emb


class Transpose_module(nn.Module):
    """Swap the last two dimensions, e.g. so LayerNorm can normalize over
    the channel dimension of a ``B x C x T`` tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.transpose(-2, -1)


class Conv1dModule(nn.Module):
    """Convolutional subsampler: a stack of 1D convolutions (along the
    temporal dimension) followed by non-linear activation via gated
    linear units (https://arxiv.org/abs/1911.08460).
    """

    def __init__(self, args):
        super(Conv1dModule, self).__init__()
        self.embedding_dim = args.emb_size
        self.norm_mode = None
        self.feature_grad_mult = 1.0
        self.conv_bias = True
        self.dropout_input = 0.0
        self.use_glu = args.data_processor == "s2t"
        self.padding = True
        self.conv_channels = args.conv_channels
        self.audio_feature_size = args.audio_feature_size
        self.kernel_sizes = args.conv_kernel_sizes
        # Each conv layer subsamples the time axis by a factor of 2.
        self.strides = [2 for _ in range(len(self.kernel_sizes))]
        self.conv_layers = nn.ModuleList()

        def conv_layer_block(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            norm_mode=None,
            conv_bias=False,
        ):
            def make_conv(in_channels, out_channels, kernel_size, stride,
                          padding, conv_bias):
                conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                                 stride=stride, padding=padding, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            if norm_mode == "layer":
                # LayerNorm normalizes the last dimension, so transpose
                # B x C x T to B x T x C around it.
                return nn.Sequential(
                    make_conv(in_channels, out_channels, kernel_size,
                              stride, padding, conv_bias),
                    Transpose_module(),
                    nn.LayerNorm(out_channels, eps=1e-5, elementwise_affine=True),
                    Transpose_module(),
                    nn.GELU(),
                )
            elif norm_mode == "group":
                return nn.Sequential(
                    make_conv(in_channels, out_channels, kernel_size,
                              stride, padding, conv_bias),
                    nn.GroupNorm(out_channels, out_channels, eps=1e-5, affine=True),
                    nn.GELU(),
                )
            elif self.use_glu:
                # GLU is applied in forward() via nn.functional.glu, so the
                # block itself is just the convolution.
                return nn.Sequential(
                    make_conv(in_channels, out_channels, kernel_size,
                              stride, padding, conv_bias),
                )
            else:
                return nn.Sequential(
                    make_conv(in_channels, out_channels, kernel_size,
                              stride, padding, conv_bias),
                    nn.GELU(),
                )

        assert len(self.strides) == len(self.kernel_sizes), \
            "strides and kernel_sizes do not match"
        assert len(self.strides) == len(self.conv_channels), \
            "strides and conv_channels do not match"

        # GLU halves the channel dimension, so each layer after the first
        # receives half of the previous layer's output channels. This
        # bookkeeping assumes a uniform value across conv_channels.
        in_channel = self.conv_channels[0] // 2
        for i, (k, s, c) in enumerate(
            zip(self.kernel_sizes, self.strides, self.conv_channels)
        ):
            if self.audio_feature_size == 1:
                # Raw waveform input: no GLU halving between layers.
                in_channel = c
            if self.norm_mode == "group" and i != 0:
                # GroupNorm is only applied to the first layer.
                self.norm_mode = None
            if self.padding:
                padding = k // 2
            else:
                padding = 0
            self.conv_layers.append(
                conv_layer_block(
                    self.audio_feature_size if i == 0 else in_channel,
                    c,
                    k,
                    s,
                    padding,
                    norm_mode=self.norm_mode,
                    conv_bias=self.conv_bias,
                )
            )

    def forward(self, input_features, mask_indices=None, mask_channel_indices=None):
        if input_features.dim() == 2:
            # Raw waveform: B x T -> B x 1 x T
            hidden_states = input_features.unsqueeze(1)
        else:
            # Acoustic features: B x T x (C x D) -> B x (C x D) x T
            hidden_states = input_features.transpose(1, 2).contiguous()
        for conv in self.conv_layers:
            hidden_states = conv(hidden_states)
            if self.use_glu:
                # GLU halves the channel dimension: B x C x T -> B x C/2 x T
                hidden_states = nn.functional.glu(hidden_states, dim=1)
        hidden_states = hidden_states.transpose(1, 2).contiguous()  # -> B x T x (C x D)
        return hidden_states
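

# Minimal usage sketch (illustrative, not part of the module): the args
# namespace below is a stand-in assumption for whatever config object the
# training pipeline passes in; only the fields read above are populated,
# and the values are hypothetical.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        emb_size=256,                 # model/embedding dimension
        embedding="sinusoidalpos",    # enables sqrt(emb_size) scaling
        data_processor="s2t",         # "s2t" turns on GLU activations
        conv_channels=[512, 512],     # per-layer channels (uniform; halved by GLU)
        conv_kernel_sizes=[5, 5],     # per-layer kernel sizes
        audio_feature_size=80,        # e.g. 80-dim log-mel filterbank features
    )
    emb = SpeechEmbedding(args, None)
    feats = torch.randn(4, 100, 80)   # B x T x D acoustic features
    out = emb(feats, None)
    # Two stride-2 convs subsample time 100 -> 50 -> 25; GLU leaves 512 // 2
    # = 256 channels, matching emb_size.
    print(out.shape)                  # torch.Size([4, 25, 256])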