# grad-svc/grad/reversal.py
# Adapted from https://github.com/ubisoft/ubisoft-laforge-daft-exprt (Apache License, Version 2.0)
# Unsupervised Domain Adaptation by Backpropagation
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn.utils import weight_norm


class GradientReversalFunction(Function):
    """Identity in the forward pass; multiplies incoming gradients by -lambda_ in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        lambda_ = ctx.lambda_
        lambda_ = grads.new_tensor(lambda_)
        dx = -lambda_ * grads
        # second return value is the (non-existent) gradient w.r.t. lambda_
        return dx, None


class GradientReversal(torch.nn.Module):
    ''' Gradient Reversal Layer
        Y. Ganin, V. Lempitsky,
        "Unsupervised Domain Adaptation by Backpropagation",
        in ICML, 2015.
        The forward pass is the identity function.
        In the backward pass, upstream gradients are multiplied by -lambda (i.e. gradients are reversed).
    '''
    def __init__(self, lambda_reversal=1):
        super(GradientReversal, self).__init__()
        self.lambda_ = lambda_reversal

    def forward(self, x):
        return GradientReversalFunction.apply(x, self.lambda_)
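

# Sanity-check sketch (illustrative addition, not part of the original module):
# the layer is the identity on the forward pass, while gradients flowing back
# through it are multiplied by -lambda_reversal. The helper name below is
# hypothetical and only serves as a usage example.
def _gradient_reversal_example():
    grl = GradientReversal(lambda_reversal=1.0)
    x = torch.randn(4, 8, requires_grad=True)
    grl(x).sum().backward()
    # d/dx of sum(identity(x)) is all ones, so the reversed gradient is all -1
    assert torch.allclose(x.grad, -torch.ones_like(x))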


class SpeakerClassifier(nn.Module):

    def __init__(self, idim, odim):
        super(SpeakerClassifier, self).__init__()
        self.classifier = nn.Sequential(
            GradientReversal(lambda_reversal=1),
            weight_norm(nn.Conv1d(idim, 1024, kernel_size=5, padding=2)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(1024, 1024, kernel_size=5, padding=2)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(1024, odim, kernel_size=5, padding=2))
        )

    def forward(self, x):
        ''' Forward function of Speaker Classifier:
            x = (B, idim, len)
        '''
        # pass through classifier
        outputs = self.classifier(x)            # (B, odim, len)
        # average over time to get utterance-level speaker logits
        outputs = torch.mean(outputs, dim=-1)   # (B, nb_speakers)
        return outputs
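

# Usage sketch (illustrative addition): the classifier takes frame-level features
# of shape (B, idim, len) and returns utterance-level speaker logits of shape
# (B, odim). Because of the GradientReversal layer, minimizing the speaker
# cross-entropy pushes the upstream encoder toward speaker-independent features.
# The shapes, speaker count, and dummy encoder output below are assumptions made
# purely for demonstration.
def _speaker_classifier_example():
    batch, idim, length, n_speakers = 4, 80, 120, 10
    encoder_out = torch.randn(batch, idim, length, requires_grad=True)
    speaker_ids = torch.randint(0, n_speakers, (batch,))

    classifier = SpeakerClassifier(idim, n_speakers)
    logits = classifier(encoder_out)                        # (B, n_speakers)
    loss = nn.functional.cross_entropy(logits, speaker_ids)
    loss.backward()   # gradients reaching encoder_out have been reversed by the GRL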