|
import copy
import random
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class LossFunction(nn.Module):
    def __init__(self, main_config):
        super().__init__()
        self.train_config = main_config["train_config"]
        self.model_config = main_config["model_config"]
        self.batch_per_epoch = self.train_config.batch_per_epoch
        self.warm_up_annealing = (
            self.train_config.warmup_epoch * self.batch_per_epoch
        )
        self.num_embed_hidden = self.model_config.num_embed_hidden
        self.batch_idx = 0
        self.loss_annealing_term = 0

        # get_triplet_loss below relies on self.triplet_loss, which was never
        # initialized; PyTorch's nn.TripletMarginLoss with its default
        # margin=1.0 and p=2 is assumed here.
        self.triplet_loss = nn.TripletMarginLoss()

        class_weights = self.model_config.class_weights
        class_weights = torch.FloatTensor([float(x) for x in class_weights])

        # Multi-class: weighted, label-smoothed cross entropy returning
        # per-sample losses (reduction='none'). Binary: the focal loss below.
        if self.model_config.num_classes > 2:
            self.clf_loss_fn = nn.CrossEntropyLoss(
                weight=class_weights,
                label_smoothing=self.train_config.label_smoothing_clf,
                reduction='none',
            )
        else:
            self.clf_loss_fn = self.focal_loss

    def cosine_similarity_matrix(
        self, gene_embedd: torch.Tensor, second_input_embedd: torch.Tensor,
        annealing: bool = True,
    ) -> torch.Tensor:
        '''
        Computes the pairwise similarity matrix between the two embedding
        batches (cosine similarity when the embeddings are L2-normalized
        upstream). During warm-up, the matrix is scaled by an annealing
        term that grows linearly from 1 to 1 + sqrt(num_embed_hidden).
        '''
        assert gene_embedd.size(0) == second_input_embedd.size(0)

        cosine_similarity = torch.matmul(gene_embedd, second_input_embedd.T)

        if annealing:
            if self.batch_idx < self.warm_up_annealing:
                self.loss_annealing_term = 1 + (
                    self.batch_idx / self.warm_up_annealing
                ) * torch.sqrt(torch.tensor(self.num_embed_hidden, dtype=torch.float))

            cosine_similarity *= self.loss_annealing_term

        return cosine_similarity
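
    # Example of the annealing schedule (hypothetical numbers): with
    # num_embed_hidden = 100 and warm_up_annealing = 1000 batches, the
    # scaling term at batch 500 is 1 + (500/1000) * sqrt(100) = 6.0; it
    # reaches 1 + 10 = 11.0 at the end of warm-up and stays there, since
    # the term is only updated while batch_idx < warm_up_annealing.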
|
    def get_similar_labels(self, y: torch.Tensor):
        '''
        This function receives y, the labels tensor.
        It creates a list of lists containing, at every index, a list
        (min length 2) of the indices of the samples that share a label.
        '''
        labels_y = y[:, 0].cpu().detach().numpy()

        # Group sample indices by label: sort, then split at the first
        # occurrence of each unique label.
        idx_sort = np.argsort(labels_y)
        sorted_records_array = labels_y[idx_sort]
        vals, idx_start, count = np.unique(
            sorted_records_array, return_index=True, return_counts=True
        )
        res = np.split(idx_sort, idx_start[1:])

        # Keep only labels that occur more than once.
        vals = vals[count > 1]
        res = filter(lambda x: x.size > 1, res)

        indices_similar_labels = []
        similar_labels = []
        for r in res:
            indices_similar_labels.append(list(r))
            similar_labels.append(list(labels_y[r]))

        return indices_similar_labels, similar_labels
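
    # Worked example (hypothetical labels): y[:, 0] = [3, 1, 3, 2, 1] yields
    # indices_similar_labels = [[1, 4], [0, 2]] and
    # similar_labels = [[1, 1], [3, 3]]; label 2 occurs once and is dropped.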
|
|
|
    def get_triplet_samples(self, indices_similar_labels, similar_labels):
        '''
        This function creates three lists: positives, anchors, and negatives.
        Each index across the three lists corresponds to a single triplet.
        '''
        positives, anchors, negatives = [], [], []
        # From every group of samples sharing a label, draw two distinct
        # members: one positive and one anchor.
        for idx_similar_labels in indices_similar_labels:
            random_indices = random.sample(range(len(idx_similar_labels)), 2)
            positives.append(idx_similar_labels[random_indices[0]])
            anchors.append(idx_similar_labels[random_indices[1]])

        # Negatives are a shuffled copy of the positives; reshuffle until no
        # negative coincides with its own positive.
        negatives = copy.deepcopy(positives)
        random.shuffle(negatives)
        while (np.array(positives) == np.array(negatives)).any():
            random.shuffle(negatives)

        return positives, anchors, negatives
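
    # Note: the reshuffle loop above searches for a derangement of the
    # positives, which exists only when there are at least two groups; the
    # caller (get_triplet_loss) guarantees this by checking
    # len(indices_similar_labels) > 1 before sampling triplets.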
|
    def get_triplet_loss(self, y, gene_embedd, second_input_embedd):
        '''
        This function computes the triplet loss by creating triplet samples
        of positives, negatives, and anchors.
        The objective is to decrease the distance between the embeddings of
        the anchors and the positives, while increasing the distance between
        the anchors and the negatives.
        This is done separately for both embeddings: gene embedds (0) and
        ss embedds (1).
        '''
        indices_similar_labels, similar_labels = self.get_similar_labels(y)

        # Triplets require at least two groups of duplicated labels;
        # otherwise no valid negative exists and the contribution is zero.
        if len(indices_similar_labels) > 1:
            positives, anchors, negatives = self.get_triplet_samples(
                indices_similar_labels, similar_labels
            )

            gene_embedd_triplet_loss = self.triplet_loss(
                gene_embedd[positives, :],
                gene_embedd[anchors, :],
                gene_embedd[negatives, :],
            )
            second_input_embedd_triplet_loss = self.triplet_loss(
                second_input_embedd[positives, :],
                second_input_embedd[anchors, :],
                second_input_embedd[negatives, :],
            )
            return gene_embedd_triplet_loss + second_input_embedd_triplet_loss
        else:
            return 0
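
    # Shape sketch: with y of shape (batch, 1) and both embedding tensors of
    # shape (batch, num_embed_hidden), the return value is a scalar tensor
    # (the sum of two TripletMarginLoss terms), or the int 0 when fewer than
    # two labels are duplicated in the batch.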
|
|
|
    def focal_loss(self, predicted_labels, y):
        # One-hot encode the binary labels on the same device as y (the
        # original torch.cuda.FloatTensor cast breaks on CPU-only runs).
        y = y.unsqueeze(dim=1)
        y_new = torch.zeros(y.shape[0], 2, device=y.device, dtype=torch.float)
        y_new[range(y.shape[0]), y[:, 0]] = 1
        BCE_loss = F.binary_cross_entropy_with_logits(
            predicted_labels.float(), y_new.float(), reduction='none'
        )
        # Down-weight easy examples: pt is the probability assigned to the
        # true class, so (1 - pt)**2 shrinks the loss of confident samples.
        pt = torch.exp(-BCE_loss)
        F_loss = (1 - pt) ** 2 * BCE_loss
        loss = 10 * F_loss.mean()
        return loss
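
    # The formula above is the focal loss of Lin et al. (2017),
    # FL = (1 - pt)^gamma * BCE with gamma = 2, where pt = exp(-BCE) is the
    # model's probability for the true class; the factor of 10 and gamma = 2
    # are hard-coded choices of this implementation.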
|
|
|
    def contrastive_loss(self, cosine_similarity, batch_size):
        # Reward the matched (diagonal) pairs.
        j = -torch.sum(torch.diagonal(cosine_similarity))

        # Zero the diagonal so the off-diagonal (mismatched) entries can be
        # treated separately below.
        cosine_similarity.diagonal().copy_(torch.zeros(cosine_similarity.size(0)))

        # Label smoothing: move a fraction of the target mass onto the
        # off-diagonal entries.
        j = (1 - self.train_config.label_smoothing_sim) * j + (
            self.train_config.label_smoothing_sim
            / (cosine_similarity.size(0) * (cosine_similarity.size(0) - 1))
        ) * torch.sum(cosine_similarity)

        # Partition-function term of the softmax over each column.
        j += torch.sum(torch.logsumexp(cosine_similarity, dim=0))

        # Clamp negative totals to zero (j - j yields a zero that still
        # carries the autograd graph).
        if j < 0:
            j = j - j
        return j / batch_size
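
    # Up to the label smoothing and the clamping above, this is the usual
    # softmax contrastive objective, roughly
    # -(1/N) * sum_i [ s_ii - logsumexp_j(s_ji) ]: a cross entropy over each
    # column of the similarity matrix with the matched pair as the target
    # (a slight variant, since the diagonal is zeroed before the logsumexp).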
|
|
|
    def forward(self, embedds: List[torch.Tensor], y=None) -> torch.Tensor:
        self.batch_idx += 1
        gene_embedd, second_input_embedd, predicted_labels, curr_epoch = embedds

        # Only the classification loss is returned here; the contrastive and
        # triplet terms above are separate helpers.
        loss = self.clf_loss_fn(predicted_labels, y.squeeze())

        return loss
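

# Minimal usage sketch (an assumption, not part of the original pipeline):
# the SimpleNamespace configs below are hypothetical stand-ins exposing only
# the attributes this module reads; the real project supplies its own
# config objects.
if __name__ == '__main__':
    from types import SimpleNamespace

    main_config = {
        "train_config": SimpleNamespace(
            batch_per_epoch=100,
            warmup_epoch=5,
            label_smoothing_clf=0.1,
            label_smoothing_sim=0.1,
        ),
        "model_config": SimpleNamespace(
            num_embed_hidden=64,
            num_classes=3,
            class_weights=[1.0, 1.0, 1.0],
        ),
    }
    loss_fn = LossFunction(main_config)

    batch_size, num_classes = 8, 3
    gene_embedd = torch.randn(batch_size, 64)
    second_input_embedd = torch.randn(batch_size, 64)
    predicted_labels = torch.randn(batch_size, num_classes)
    y = torch.randint(0, num_classes, (batch_size, 1))

    loss = loss_fn([gene_embedd, second_input_embedd, predicted_labels, 0], y)
    print(loss.shape)  # per-sample losses, because reduction='none'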
|
|