import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import pflow.models.components.commons as commons
import pflow.models.components.vits_modules as modules


class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x))."""

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding used to encode the flow/diffusion time step."""

    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class VitsWNDecoder(nn.Module):
    """WaveNet-style (VITS WN) decoder conditioned on a mean `mu` and a time step `t`."""

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        pe_scale=1000,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.pe_scale = pe_scale

        # Time-step embedding followed by a small MLP that produces the conditioning vector.
        self.time_pos_emb = SinusoidalPosEmb(hidden_channels * 2)
        dim = hidden_channels * 2
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            Mish(),
            nn.Linear(dim * 4, dim),
        )
        # Shared 1x1 projection applied to both the input x and the conditioning mean mu.
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        # WN stack operates on the channel-wise concatenation of the projected x and mu.
        self.enc = modules.WN(
            hidden_channels * 2, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels
        )
        self.proj = nn.Conv1d(hidden_channels * 2, out_channels, 1)

    def forward(self, x, x_mask, mu, t, *args, **kwargs):
        # x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)),
        #                          1).to(x.dtype)
        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.mlp(t)
        x = self.pre(x) * x_mask
        mu = self.pre(mu)
        x = torch.cat((x, mu), dim=1)
        # The processed time embedding is passed to WN as its conditioning signal `g`.
        x = self.enc(x, x_mask, g=t)
        stats = self.proj(x) * x_mask
        return stats
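

# A minimal smoke test of the self-contained pieces above (SinusoidalPosEmb and Mish).
# This is an illustrative sketch, not part of the original module: the batch size,
# embedding dimension, and scale are arbitrary assumptions, and running the file still
# requires the pflow package imports at the top to resolve.
if __name__ == "__main__":
    t = torch.rand(4)                      # one flow/diffusion time value per batch element
    pos_emb = SinusoidalPosEmb(dim=256)
    emb = pos_emb(t, scale=1000)           # -> shape (4, 256): concatenated sin/cos halves
    assert emb.shape == (4, 256)
    y = Mish()(torch.randn(2, 3))          # Mish is elementwise, so the shape is preserved
    assert y.shape == (2, 3)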