Spaces:
Running
Running
File size: 5,576 Bytes
02259d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
# This is Multi-reference timbre encoder
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm
from module.attentions import MultiHeadAttention
class MRTE(nn.Module):
def __init__(self,
content_enc_channels=192,
hidden_size=512,
out_channels=192,
kernel_size=5,
n_heads=4,
ge_layer = 2
):
super(MRTE, self).__init__()
self.cross_attention = MultiHeadAttention(hidden_size,hidden_size,n_heads)
self.c_pre = nn.Conv1d(content_enc_channels,hidden_size, 1)
self.text_pre = nn.Conv1d(content_enc_channels,hidden_size, 1)
self.c_post = nn.Conv1d(hidden_size,out_channels, 1)
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
if(ge==None):ge=0
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
if test != None:
if test == 0:
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
elif test == 1:
x = ssl_enc + ge
elif test ==2:
x = self.cross_attention(ssl_enc*0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
else:
raise ValueError("test should be 0,1,2")
else:
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = self.c_post(x * ssl_mask)
return x
class SpeakerEncoder(torch.nn.Module):
def __init__(self, mel_n_channels=80, model_num_layers=2, model_hidden_size=256, model_embedding_size=256):
super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()
def forward(self, mels):
self.lstm.flatten_parameters()
_, (hidden, _) = self.lstm(mels.transpose(-1, -2))
embeds_raw = self.relu(self.linear(hidden[-1]))
return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
class MELEncoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_layers):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers)
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
def forward(self, x):
# print(x.shape,x_lengths.shape)
x = self.pre(x)
x = self.enc(x)
x = self.proj(x)
return x
class WN(torch.nn.Module):
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers):
super(WN, self).__init__()
assert(kernel_size % 2 == 1)
self.hidden_channels =hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
for i in range(n_layers):
dilation = dilation_rate ** i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = weight_norm(in_layer)
self.in_layers.append(in_layer)
# last one is not necessary
if i < n_layers - 1:
res_skip_channels = 2 * hidden_channels
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = weight_norm(res_skip_layer, name='weight')
self.res_skip_layers.append(res_skip_layer)
def forward(self, x):
output = torch.zeros_like(x)
n_channels_tensor = torch.IntTensor([self.hidden_channels])
for i in range(self.n_layers):
x_in = self.in_layers[i](x)
acts = fused_add_tanh_sigmoid_multiply(
x_in,
n_channels_tensor)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = res_skip_acts[:,:self.hidden_channels,:]
x = (x + res_acts)
output = output + res_skip_acts[:,self.hidden_channels:,:]
else:
output = output + res_skip_acts
return output
def remove_weight_norm(self):
for l in self.in_layers:
remove_weight_norm(l)
for l in self.res_skip_layers:
remove_weight_norm(l)
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input, n_channels):
n_channels_int = n_channels[0]
t_act = torch.tanh(input[:, :n_channels_int, :])
s_act = torch.sigmoid(input[:, n_channels_int:, :])
acts = t_act * s_act
return acts
if __name__ == '__main__':
content_enc = torch.randn(3,192,100)
content_mask = torch.ones(3,1,100)
ref_mel = torch.randn(3,128,30)
ref_mask = torch.ones(3,1,30)
model = MRTE()
out = model(content_enc,content_mask,ref_mel,ref_mask)
print(out.shape) |