dbf8b7e 358ab8f dbf8b7e 358ab8f dbf8b7e 358ab8f
1
2
3
4
5
6
7
8
9
10
import torch


def kde(x, std=0.1):
    """Estimate the density at each point of ``x`` with a Gaussian kernel.

    For every point ``x_i`` this returns
    ``sum_j exp(-||x_i - x_j||^2 / (2 * std**2))``: an unnormalised Gaussian
    kernel density estimate evaluated at the sample points themselves (each
    point's own contribution of 1 is included in its sum).

    Args:
        x: Tensor of points as accepted by ``torch.cdist`` -- presumably
            ``(N, D)`` or batched ``(B, N, D)``; TODO confirm against callers.
        std: Kernel bandwidth (standard deviation). Must be positive.

    Returns:
        Tensor of shape ``x.shape[:-1]`` holding the per-point densities, in
        ``x``'s dtype promoted to at least float32 (float64 stays float64).

    Raises:
        ValueError: If ``std`` is not positive (it would divide by zero).
    """
    # `not std > 0` also rejects NaN.
    if not std > 0:
        raise ValueError(f"std must be positive, got {std!r}")
    # Compute in at least float32. The former `x.half()` cast was unsafe:
    # for more than 25 rows cdist expands ||a - b||^2 as |a|^2 + |b|^2 - 2a.b,
    # and in float16 |a|^2 overflows to inf once coordinates exceed ~256,
    # turning densities into NaN; fp16's ~3 significant digits also make
    # that expansion's cancellation error large. promote_types leaves
    # float32/float64 untouched and lifts float16/bfloat16/integer inputs.
    x = x.to(torch.promote_types(x.dtype, torch.float32))
    sq_dists = torch.cdist(x, x).pow(2)
    scores = torch.exp(-sq_dists / (2 * std**2))
    return scores.sum(dim=-1)