from data_gen.tts.emotion.params_model import *
from data_gen.tts.emotion.params_data import *
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
from torch import nn
import numpy as np
import torch


class EmotionEncoder(nn.Module):
    """LSTM encoder mapping batches of mel-spectrogram frames to embeddings.

    Layer sizes come from the star-imported ``params_model`` / ``params_data``
    modules (``mel_n_channels``, ``model_hidden_size``, ``model_num_layers``,
    ``model_embedding_size``).
    """

    # Lower bound on the L2 norm in ``forward`` (same floor F.normalize uses),
    # so an all-zero ReLU output does not turn into NaN.
    _NORM_EPS = 1e-12

    def __init__(self, device, loss_device):
        """
        :param device: device on which the LSTM, linear and ReLU layers are placed.
        :param loss_device: device on which the similarity scaling parameters
            and the cross-entropy loss module are placed.
        """
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values).
        # The tensors are created directly on loss_device. Calling .to() on an
        # nn.Parameter returns a plain non-leaf tensor whenever the device
        # changes. That tensor would not be registered as a module parameter,
        # and its .grad would stay None.
        self.similarity_weight = nn.Parameter(torch.tensor([10.], device=loss_device))
        self.similarity_bias = nn.Parameter(torch.tensor([-5.], device=loss_device))

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        """Post-process gradients in place before an optimizer step.

        Scales the similarity weight/bias gradients by 0.01, then clips the
        global L2 norm of all parameter gradients to 3.
        """
        # Gradient scale. Skip parameters that received no gradient: nothing
        # in this class uses them in a loss, so .grad may legitimately be None.
        for param in (self.similarity_weight, self.similarity_bias):
            if param.grad is not None:
                param.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: optional initial LSTM state, passed straight to nn.LSTM as the
        tuple (h_0, c_0), each of shape (num_layers, batch_size, hidden_size). nn.LSTM uses
        zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size). Each row is
        non-negative (ReLU) and L2-normalized. A row that is all zeros after the ReLU stays
        all zeros instead of becoming NaN.
        """
        # Run the LSTM; only the final hidden state is needed.
        _, (hidden, _) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it. Clamping the norm leaves non-degenerate rows
        # unchanged and avoids 0 / 0.
        norms = torch.norm(embeds_raw, dim=1, keepdim=True).clamp_min(self._NORM_EPS)
        embeds = embeds_raw / norms

        return embeds

    def inference(self, utterances, hidden_init=None):
        """
        Runs only the LSTM and returns the last layer's final hidden state.

        Unlike ``forward``, this applies neither the linear projection, the ReLU, nor the
        L2 normalization.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: optional initial LSTM state, passed straight to nn.LSTM as the
        tuple (h_0, c_0), each of shape (num_layers, batch_size, hidden_size). nn.LSTM uses
        zeros if None.
        :return: tensor of shape (batch_size, model_hidden_size)
        """
        _, (hidden, _) = self.lstm(utterances, hidden_init)
        return hidden[-1]