File size: 1,708 Bytes
d0e1f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch

def Pdist2(x, y):
	"""Return the matrix of squared Euclidean distances between rows.

	Entry ``[i, j]`` is ``|x[i]|^2 + |y[j]|^2 - 2 * <x[i], y[j]>``. When
	``y`` is ``None`` the distances are taken between the rows of ``x``
	and itself. Both inputs must be 2-D, since ``torch.mm`` is used for
	the cross term.
	"""
	sq_x = x.pow(2).sum(dim=1, keepdim=True)
	if y is None:
		# Self-distance: reuse the row norms of x as the column norms.
		y = x
		sq_y = sq_x.view(1, -1)
	else:
		sq_y = y.pow(2).sum(dim=1).view(1, -1)
	cross = torch.mm(x, y.transpose(0, 1))
	dist = sq_x + sq_y - 2.0 * cross
	# Floating-point cancellation can leave small negative values; zero them.
	dist[dist < 0] = 0
	return dist

def MMD_batch2(Fea, len_s, Fea_org, sigma, sigma0=0.1, epsilon = 10**(-10), is_smooth=True, is_var_computed=True, use_1sample_U=True, coeff_xy=2):
	"""Compute one MMD^2-style value per row of the non-reference part of ``Fea``.

	The first ``len_s`` rows of ``Fea`` form the set X. Every remaining row
	is treated as a separate single-sample Y, so its self-distance is zero
	and the result holds one value per such row.

	Args:
		Fea: 2-D tensor; rows ``[0, len_s)`` are X, the rest are Y.
		len_s: number of leading rows that belong to X.
		Fea_org: tensor split at ``len_s`` the same way as ``Fea``. Only read
			when ``is_smooth`` is true.
		sigma: divisor applied to the squared distances of ``Fea_org``.
		sigma0: divisor applied to the squared distances of ``Fea``.
		epsilon: weight of the ``Fea_org``-only kernel term in the smooth
			kernel; the combined term gets weight ``1 - epsilon``.
		is_smooth: if true, use the combined kernel
			``(1 - epsilon) * exp(-D / sigma0 - D_org / sigma)
			+ epsilon * exp(-D_org / sigma)``; otherwise ``exp(-D / sigma0)``.
		is_var_computed: accepted for interface compatibility; not used.
		use_1sample_U: accepted for interface compatibility; not used.
		coeff_xy: accepted for interface compatibility; not used.

	Returns:
		1-D tensor with ``Fea.shape[0] - len_s`` entries. Entry ``j`` is
		``sum(Kx) / nx**2 - 2 * sum(Kxy[j]) / nx + Ky[j]``.
	"""
	X = Fea[:len_s, :]
	Y = Fea[len_s:, :]
	nx = X.shape[0]
	n_rest = Fea.shape[0] - len_s
	power = 1  # exponent on the scaled feature distance; generalized Gaussian if > 1

	Dxx = Pdist2(X, X)
	# Allocate the zero self-distances directly with the dtype/device of the
	# other distance matrices instead of creating float32 on the host and moving.
	Dyy = torch.zeros(n_rest, 1, dtype=Dxx.dtype, device=Dxx.device)
	Dxy = Pdist2(X, Y).transpose(0, 1)

	if is_smooth:
		X_org = Fea_org[:len_s, :]
		Y_org = Fea_org[len_s:, :]
		Dxx_org = Pdist2(X_org, X_org)
		Dyy_org = torch.zeros(n_rest, 1, dtype=Dxx_org.dtype, device=Dxx_org.device)
		Dxy_org = Pdist2(X_org, Y_org).transpose(0, 1)

		def smooth_kernel(d_fea, d_org):
			# Mixture of the combined kernel and the Fea_org-only kernel.
			return (1 - epsilon) * torch.exp(-(d_fea / sigma0) ** power - d_org / sigma) + epsilon * torch.exp(-d_org / sigma)

		Kx = smooth_kernel(Dxx, Dxx_org)
		Ky = smooth_kernel(Dyy, Dyy_org)
		Kxy = smooth_kernel(Dxy, Dxy_org)
	else:
		Kx = torch.exp(-Dxx / sigma0)
		Ky = torch.exp(-Dyy / sigma0)
		Kxy = torch.exp(-Dxy / sigma0)

	# Mean of Kx over all X pairs, diagonal included.
	xx = torch.div(torch.sum(Kx), nx * nx)
	# Per-Y-row mean of the cross kernel over X.
	xy = torch.div(torch.sum(Kxy, dim=1), nx)
	yy = Ky.reshape(-1)
	return xx - 2 * xy + yy