Spaces:
Sleeping
Sleeping
import torch | |
def Pdist2(x, y): | |
"""compute the paired distance between x and y.""" | |
x_norm = (x ** 2).sum(1).view(-1, 1) | |
if y is not None: | |
y_norm = (y ** 2).sum(1).view(1, -1) | |
else: | |
y = x | |
y_norm = x_norm.view(1, -1) | |
Pdist = x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)) | |
Pdist[Pdist<0]=0 | |
return Pdist | |
def MMD_batch2(Fea, len_s, Fea_org, sigma, sigma0=0.1, epsilon = 10**(-10), is_smooth=True, is_var_computed=True, use_1sample_U=True, coeff_xy=2): | |
X = Fea[0:len_s, :] | |
Y = Fea[len_s:, :] | |
if is_smooth: | |
X_org = Fea_org[0:len_s, :] | |
Y_org = Fea_org[len_s:, :] | |
L = 1 # generalized Gaussian (if L>1) | |
nx = X.shape[0] | |
ny = Y.shape[0] | |
Dxx = Pdist2(X, X) | |
Dyy = torch.zeros(Fea.shape[0] - len_s, 1).to(Dxx.device) | |
# Dyy = Pdist2(Y, Y) | |
Dxy = Pdist2(X, Y).transpose(0,1) | |
if is_smooth: | |
Dxx_org = Pdist2(X_org, X_org) | |
Dyy_org = torch.zeros(Fea.shape[0] - len_s, 1).to(Dxx.device) | |
# Dyy_org = Pdist2(Y_org, Y_org) # 1,1 0 | |
Dxy_org = Pdist2(X_org, Y_org).transpose(0,1) | |
if is_smooth: | |
Kx = (1-epsilon) * torch.exp(-(Dxx / sigma0)**L -Dxx_org / sigma) + epsilon * torch.exp(-Dxx_org / sigma) | |
Ky = (1-epsilon) * torch.exp(-(Dyy / sigma0)**L -Dyy_org / sigma) + epsilon * torch.exp(-Dyy_org / sigma) | |
Kxy = (1-epsilon) * torch.exp(-(Dxy / sigma0)**L -Dxy_org / sigma) + epsilon * torch.exp(-Dxy_org / sigma) | |
else: | |
Kx = torch.exp(-Dxx / sigma0) | |
Ky = torch.exp(-Dyy / sigma0) | |
Kxy = torch.exp(-Dxy / sigma0) | |
nx = Kx.shape[0] | |
is_unbiased = False | |
if 1: | |
xx = torch.div((torch.sum(Kx)), (nx * nx)) | |
yy = Ky.reshape(-1) | |
# one-sample U-statistic. | |
xy = torch.div(torch.sum(Kxy, dim = 1), (nx )) | |
mmd2 = xx - 2 * xy + yy | |
return mmd2 | |