File size: 902 Bytes
f831146 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
import torch.nn.functional as F
from models.cross_entropy_model import FBankCrossEntropyNet
def get_cosine_distance(a, b, *, dim=1, eps=1e-8):
    """Return the cosine distance ``1 - cos_sim(a, b)`` as a NumPy array.

    Args:
        a: NumPy array.
        b: NumPy array, broadcast-compatible with ``a``.
        dim: Axis along which the cosine similarity is computed. Defaults to 1
            (row-wise for 2-D inputs, as before); pass ``dim=0`` for plain 1-D
            vectors, which the old hard-coded ``dim=1`` could not handle.
        eps: Small value forwarded to ``F.cosine_similarity`` to avoid
            division by zero for zero-norm inputs.

    Returns:
        NumPy array of distances with ``dim`` reduced away.
    """
    a_t = torch.from_numpy(a)
    b_t = torch.from_numpy(b)
    similarity = F.cosine_similarity(a_t, b_t, dim=dim, eps=eps)
    return (1 - similarity).numpy()
MODEL_PATH = 'weights/triplet_loss_trained_model.pth'

# Module-level network used by get_embeddings_instance(). It is built and
# loaded once at import time, then cast to float64 and put in eval mode.
# The map_location callable returns each storage exactly as torch.load hands
# it over (per the torch.load docs, the CPU deserialisation), so no GPU is
# needed to load the checkpoint.
model_instance = FBankCrossEntropyNet()
model_instance.load_state_dict(
    torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)
)
# Module.double() and Module.eval() both modify in place and return the
# module itself, so chaining keeps `model_instance` bound to the same object.
model_instance = model_instance.double().eval()
### I think the instance model was trained in stage 2 (contrastive learning). ###
def get_embeddings_instance(x):
    """Run the module-level ``model_instance`` on ``x`` and return its output.

    Args:
        x: NumPy array accepted by ``model_instance`` once converted to a
            tensor (expected shape depends on FBankCrossEntropyNet — TODO
            confirm). Any numeric dtype is accepted; it is cast to float64.

    Returns:
        NumPy array holding the model output.
    """
    # model_instance is cast to float64 at import time; match the input dtype
    # so float32 arrays don't fail with a dtype-mismatch RuntimeError.
    # For float64 input .double() is a no-op (no copy).
    inputs = torch.from_numpy(x).double()
    with torch.no_grad():  # inference only: no autograd graph needed
        embeddings = model_instance(inputs)
    return embeddings.numpy()
def get_embeddings(x, model):
    """Run ``model`` on ``x`` in float64 without gradients.

    Side effect: ``model`` is cast to float64 IN PLACE (``model.double()``),
    exactly as before; callers keep seeing the converted module afterwards.

    Args:
        x: NumPy array accepted by ``model`` once converted to a tensor.
            Any numeric dtype is accepted; it is cast to float64 to match the
            converted model.
        model: ``torch.nn.Module`` to evaluate.

    Returns:
        NumPy array holding the model output.
    """
    model.double()
    # Cast the input to match the float64 model; previously float32 arrays
    # raised a dtype-mismatch RuntimeError. No-op for float64 input.
    inputs = torch.from_numpy(x).double()
    with torch.no_grad():  # inference only: no autograd graph needed
        embeddings = model(inputs)
    return embeddings.numpy()