Spaces:
Running
Running
# Adapted from https://github.com/ubisoft/ubisoft-laforge-daft-exprt (Apache License, Version 2.0)
# Gradient reversal layer from "Unsupervised Domain Adaptation by Backpropagation" (Ganin & Lempitsky, ICML 2015)
import torch | |
import torch.nn as nn | |
from torch.autograd import Function | |
from torch.nn.utils import weight_norm | |
class GradientReversalFunction(Function):
    """Autograd function: identity in the forward pass, reversed gradient in the backward pass.

    Call it as ``GradientReversalFunction.apply(x, lambda_)``. ``lambda_`` is
    stored on the autograd context and receives no gradient of its own.
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        """Return a copy of ``x`` and remember ``lambda_`` for ``backward``.

        Args:
            ctx: autograd context object supplied by PyTorch.
            x: input tensor.
            lambda_: scale applied (with a sign flip) to the incoming gradient.

        Returns:
            ``x.clone()`` -- same values, but a new tensor rather than the
            input object itself.
        """
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        """Return ``-lambda_ * grads`` for ``x`` and ``None`` for ``lambda_``.

        Args:
            ctx: autograd context holding ``lambda_`` from ``forward``.
            grads: upstream gradient with respect to the forward output.

        Returns:
            Tuple ``(dx, None)``, one entry per ``forward`` input.
        """
        # new_tensor builds the scalar with the same dtype/device as grads.
        lambda_ = grads.new_tensor(ctx.lambda_)
        dx = -lambda_ * grads
        # lambda_ is a plain hyper-parameter, so it gets no gradient.
        return dx, None
class GradientReversal(torch.nn.Module):
    """Gradient Reversal Layer (GRL).

    Reference: Y. Ganin and V. Lempitsky, "Unsupervised Domain Adaptation by
    Backpropagation", ICML 2015.

    Forward: values pass through unchanged (identity).
    Backward: the upstream gradient is multiplied by ``-lambda_reversal``,
    i.e. its sign is reversed.

    Args:
        lambda_reversal: scale of the reversed gradient (default 1). It is
            stored as ``self.lambda_``.
    """

    def __init__(self, lambda_reversal=1):
        super().__init__()
        # Read on every forward call and handed to the autograd function.
        self.lambda_ = lambda_reversal

    def forward(self, x):
        """Apply gradient reversal to ``x``; the returned values equal ``x``."""
        reverse = GradientReversalFunction.apply
        return reverse(x, self.lambda_)
class SpeakerClassifier(nn.Module):
    """Speaker classifier placed behind a gradient reversal layer.

    The input first passes through ``GradientReversal``. It then goes through
    three weight-normalised ``Conv1d`` layers with ReLU between them. Finally,
    the per-frame outputs are averaged over the last (time) axis. Gradients
    flowing back into the input are therefore reversed (scaled by
    ``-lambda_reversal``). This presumably serves to remove speaker information
    from the upstream features; confirm with the training code.

    Args:
        idim: number of input channels.
        odim: number of output channels (one logit per class).
        hidden_dim: channels of the two hidden conv layers (default 1024,
            the previously hard-coded value).
        kernel_size: temporal kernel size of every conv layer (default 5).
            Padding is ``kernel_size // 2``, so odd sizes keep the length.
        lambda_reversal: gradient reversal scale (default 1).
    """

    def __init__(self, idim, odim, hidden_dim=1024, kernel_size=5,
                 lambda_reversal=1):
        super().__init__()
        padding = kernel_size // 2
        # NOTE: torch.nn.utils.weight_norm is deprecated in favour of
        # torch.nn.utils.parametrizations.weight_norm. It is kept here because
        # switching renames the state_dict entries (weight_g / weight_v) and
        # would break loading existing checkpoints.
        # The Sequential layout (indices 0..5) is unchanged, so parameter
        # names stay the same.
        self.classifier = nn.Sequential(
            GradientReversal(lambda_reversal=lambda_reversal),
            weight_norm(nn.Conv1d(idim, hidden_dim,
                                  kernel_size=kernel_size, padding=padding)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(hidden_dim, hidden_dim,
                                  kernel_size=kernel_size, padding=padding)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(hidden_dim, odim,
                                  kernel_size=kernel_size, padding=padding)),
        )

    def forward(self, x):
        """Return one output vector per utterance.

        Args:
            x: tensor of shape (B, idim, len).

        Returns:
            Tensor of shape (B, odim): frame-level outputs averaged over time.
        """
        # (B, odim, len) for odd kernel_size; even sizes add one frame.
        outputs = self.classifier(x)
        # NOTE(review): plain mean over every frame. Padded frames in a batch
        # of variable-length inputs are included; confirm this is intended.
        return torch.mean(outputs, dim=-1)  # (B, odim)