from abc import ABC, abstractmethod

import torch
from torch import nn

from singleVis.backend import compute_cross_entropy_tf, convert_distance_to_probability, compute_cross_entropy

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)


"""Loss modules for preserving the four properties"""

class Loss(nn.Module, ABC):
    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, *args, **kwargs):
        pass

class UmapLoss(nn.Module):
    def __init__(self, negative_sample_rate, device, _a=1.0, _b=1.0, repulsion_strength=1.0):
        super(UmapLoss, self).__init__()

        self._negative_sample_rate = negative_sample_rate
        self._a = _a
        self._b = _b
        self._repulsion_strength = repulsion_strength
        self.DEVICE = torch.device(device)

    @property
    def a(self):
        return self._a

    @property
    def b(self):
        return self._b

    def forward(self, embedding_to, embedding_from):
        batch_size = embedding_to.shape[0]

        # Build negative pairs by repeating each head embedding and shuffling the tails.
        embedding_neg_to = torch.repeat_interleave(embedding_to, self._negative_sample_rate, dim=0)
        repeat_neg = torch.repeat_interleave(embedding_from, self._negative_sample_rate, dim=0)
        randperm = torch.randperm(repeat_neg.shape[0])
        embedding_neg_from = repeat_neg[randperm]

        # Embedding-space distances for positive pairs followed by negative pairs.
        distance_embedding = torch.cat(
            (
                torch.norm(embedding_to - embedding_from, dim=1),
                torch.norm(embedding_neg_to - embedding_neg_from, dim=1),
            ),
            dim=0,
        )
        probabilities_distance = convert_distance_to_probability(
            distance_embedding, self.a, self.b
        )
        probabilities_distance = probabilities_distance.to(self.DEVICE)

        # Target membership strengths: 1 for positive pairs, 0 for negative samples.
        probabilities_graph = torch.cat(
            (torch.ones(batch_size), torch.zeros(batch_size * self._negative_sample_rate)), dim=0,
        )
        probabilities_graph = probabilities_graph.to(device=self.DEVICE)

        # Cross entropy between graph membership strengths and embedding probabilities.
        (_, _, ce_loss) = compute_cross_entropy(
            probabilities_graph,
            probabilities_distance,
            repulsion_strength=self._repulsion_strength,
        )

        return torch.mean(ce_loss)

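# Usage sketch (illustrative only, not part of the library): how UmapLoss might be called
# on a batch of edge endpoints that a visualization model has already projected to 2-D.
# The batch size, sample rate, and a/b values below are assumptions for demonstration,
# not values prescribed by this repository.
#
#   umap_criterion = UmapLoss(negative_sample_rate=5, device="cpu", _a=1.577, _b=0.895)
#   embedding_to = torch.randn(64, 2)      # low-dim embedding of edge heads
#   embedding_from = torch.randn(64, 2)    # low-dim embedding of edge tails
#   loss = umap_criterion(embedding_to, embedding_from)
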
class ReconstructionLoss(nn.Module):
    def __init__(self, beta=1.0, alpha=0.5):
        super(ReconstructionLoss, self).__init__()
        self._beta = beta
        self._alpha = alpha

    def forward(self, edge_to, edge_from, recon_to, recon_from, a_to, a_from):
        # Squared reconstruction error, re-weighted per dimension by (1 + a)^beta.
        loss1 = torch.mean(torch.mean(torch.multiply(torch.pow((1 + a_to), self._beta), torch.pow(edge_to - recon_to, 2)), 1))
        loss2 = torch.mean(torch.mean(torch.multiply(torch.pow((1 + a_from), self._beta), torch.pow(edge_from - recon_from, 2)), 1))
        return (loss1 + loss2) / 2

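# Usage sketch (illustrative only): ReconstructionLoss compares high-dimensional edge
# endpoints with their reconstructions, weighting the squared error by (1 + a)^beta.
# The shapes below, including the shape of the a_to/a_from weighting terms, are
# assumptions for demonstration.
#
#   recon_criterion = ReconstructionLoss(beta=1.0)
#   edge_to, edge_from = torch.randn(64, 512), torch.randn(64, 512)      # original features
#   recon_to, recon_from = torch.randn(64, 512), torch.randn(64, 512)    # decoder outputs
#   a_to, a_from = torch.rand(64, 512), torch.rand(64, 512)              # weighting terms
#   loss = recon_criterion(edge_to, edge_from, recon_to, recon_from, a_to, a_from)
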
class SmoothnessLoss(nn.Module):
    def __init__(self, margin=0.0):
        super(SmoothnessLoss, self).__init__()
        self._margin = margin

    def forward(self, embedding, target, Coefficient):
        # Hinge on the distance between the current embedding and its target position,
        # weighted per sample by Coefficient; deviations within the margin are ignored.
        loss = torch.mean(Coefficient * torch.clamp(torch.norm(embedding - target, dim=1) - self._margin, min=0))
        return loss

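# Usage sketch (illustrative only): SmoothnessLoss penalizes points whose new 2-D position
# drifts more than `margin` away from a reference position, scaled by a per-sample
# coefficient. All shapes and values below are assumptions for demonstration.
#
#   smooth_criterion = SmoothnessLoss(margin=0.1)
#   embedding = torch.randn(64, 2)     # current projection
#   target = torch.randn(64, 2)        # reference projection (e.g., from a previous stage)
#   coefficient = torch.rand(64)       # per-sample weight
#   loss = smooth_criterion(embedding, target, coefficient)
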
class SingleVisLoss(nn.Module):
    def __init__(self, umap_loss, recon_loss, lambd):
        super(SingleVisLoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.lambd = lambd

    def forward(self, edge_to, edge_from, a_to, a_from, outputs):
        embedding_to, embedding_from = outputs["umap"]
        recon_to, recon_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, recon_to, recon_from, a_to, a_from)
        umap_l = self.umap_loss(embedding_to, embedding_from)

        loss = umap_l + self.lambd * recon_l

        return umap_l, recon_l, loss

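# Usage sketch (illustrative only): SingleVisLoss combines the UMAP and reconstruction
# terms. It expects `outputs` to be a dict with "umap" -> (embedding_to, embedding_from)
# and "recon" -> (recon_to, recon_from), as unpacked in forward(). The weights below are
# assumptions for demonstration.
#
#   criterion = SingleVisLoss(UmapLoss(5, "cpu"), ReconstructionLoss(beta=1.0), lambd=1.0)
#   outputs = {"umap": (embedding_to, embedding_from), "recon": (recon_to, recon_from)}
#   umap_l, recon_l, loss = criterion(edge_to, edge_from, a_to, a_from, outputs)
#   loss.backward()
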
class HybridLoss(nn.Module):
    def __init__(self, umap_loss, recon_loss, smooth_loss, lambd1, lambd2):
        super(HybridLoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.smooth_loss = smooth_loss
        self.lambd1 = lambd1
        self.lambd2 = lambd2

    def forward(self, edge_to, edge_from, a_to, a_from, embeded_to, coeff, outputs):
        embedding_to, embedding_from = outputs["umap"]
        recon_to, recon_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, recon_to, recon_from, a_to, a_from)
        umap_l = self.umap_loss(embedding_to, embedding_from)
        smooth_l = self.smooth_loss(embedding_to, embeded_to, coeff)

        loss = umap_l + self.lambd1 * recon_l + self.lambd2 * smooth_l

        return umap_l, recon_l, smooth_l, loss

class TemporalLoss(nn.Module):
    def __init__(self, prev_w, device) -> None:
        super(TemporalLoss, self).__init__()
        self.prev_w = prev_w
        self.device = device
        # Move the previous model's parameters to the target device once, up front.
        for param_name in self.prev_w.keys():
            self.prev_w[param_name] = self.prev_w[param_name].to(device=self.device, dtype=torch.float32)

    def forward(self, curr_module):
        loss = torch.tensor(0., requires_grad=True).to(self.device)

        # Squared L2 distance between the current parameters and the previous ones.
        for name, curr_param in curr_module.named_parameters():
            prev_param = self.prev_w[name]
            loss = loss + torch.sum(torch.square(curr_param - prev_param))

        return loss

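# Usage sketch (illustrative only): TemporalLoss regularizes the current model toward the
# weights of a previously trained model. `prev_model` and `curr_model` below are assumed
# to be two instances of the same visualization model, so their parameter names match.
#
#   prev_w = {name: p.detach().clone() for name, p in prev_model.named_parameters()}
#   temporal_criterion = TemporalLoss(prev_w, device="cpu")
#   loss = temporal_criterion(curr_model)
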
class DummyTemporalLoss(nn.Module):
    def __init__(self, device) -> None:
        super(DummyTemporalLoss, self).__init__()
        self.device = device

    def forward(self, curr_module):
        # Placeholder that always contributes zero temporal regularization.
        loss = torch.tensor(0., requires_grad=True).to(self.device)
        return loss

class PositionRecoverLoss(nn.Module):
    def __init__(self, device) -> None:
        super(PositionRecoverLoss, self).__init__()
        self.device = device

    def forward(self, position, recover_position):
        mse_loss = nn.MSELoss().to(self.device)
        loss = mse_loss(position, recover_position)
        return loss

class DVILoss(nn.Module):
    def __init__(self, umap_loss, recon_loss, temporal_loss, lambd1, lambd2, device, umap_weight=1):
        super(DVILoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.temporal_loss = temporal_loss
        self.lambd1 = lambd1
        self.lambd2 = lambd2
        self.device = device
        self.umap_weight = umap_weight

    def forward(self, edge_to, edge_from, a_to, a_from, curr_model, outputs):
        embedding_to, embedding_from = outputs["umap"]
        recon_to, recon_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, recon_to, recon_from, a_to, a_from).to(self.device)
        umap_l = self.umap_loss(embedding_to, embedding_from).to(self.device)
        temporal_l = self.temporal_loss(curr_model).to(self.device)

        loss = self.umap_weight * umap_l + self.lambd1 * recon_l + self.lambd2 * temporal_l

        return self.umap_weight * umap_l, self.lambd1 * recon_l, self.lambd2 * temporal_l, loss

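# Usage sketch (illustrative only): DVILoss composes the UMAP, reconstruction, and temporal
# terms defined above. The weights and the DummyTemporalLoss choice below are assumptions
# for demonstration; a real run would typically pass a TemporalLoss built from the previous
# iteration's weights.
#
#   criterion = DVILoss(
#       umap_loss=UmapLoss(5, "cpu"),
#       recon_loss=ReconstructionLoss(beta=1.0),
#       temporal_loss=DummyTemporalLoss("cpu"),
#       lambd1=1.0,
#       lambd2=1.0,
#       device="cpu",
#   )
#   umap_l, recon_l, temporal_l, loss = criterion(edge_to, edge_from, a_to, a_from, model, outputs)
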
class MINE(nn.Module):
    def __init__(self, input_dim=2):
        super(MINE, self).__init__()

        # Statistics network T(x, y); `input_dim` is the size of the concatenated (x, y) pair.
        self.network = nn.Sequential(
            nn.Linear(input_dim, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
        )

    def forward(self, x, y):
        # Joint samples pair x_i with y_i; marginal samples pair x_i with a shuffled y.
        joint = torch.cat((x, y), dim=1)
        marginal = torch.cat((x, y[torch.randperm(x.size(0))]), dim=1)
        t_joint = self.network(joint)
        t_marginal = self.network(marginal)

        # Donsker-Varadhan lower bound on mutual information, negated so it can be minimized.
        mi = torch.mean(t_joint) - torch.log(torch.mean(torch.exp(t_marginal)))
        return -mi

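# Usage sketch (illustrative only): MINE returns the negative Donsker-Varadhan bound, so
# minimizing it maximizes the mutual-information estimate. The dimensions and optimizer
# setup below are assumptions for demonstration.
#
#   mine = MINE()                            # statistics network expects concat(x, y) with 2 features
#   optimizer = torch.optim.Adam(mine.parameters(), lr=1e-3)
#   x, y = torch.randn(256, 1), torch.randn(256, 1)
#   neg_mi = mine(x, y)
#   optimizer.zero_grad()
#   neg_mi.backward()
#   optimizer.step()
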
class TVILoss(nn.Module):
    def __init__(self, umap_loss, recon_loss, temporal_loss, MI_loss, lambd1, lambd2, lambd3, device):
        super(TVILoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.temporal_loss = temporal_loss
        self.MI_loss = MI_loss
        self.lambd1 = lambd1
        self.lambd2 = lambd2
        self.lambd3 = lambd3
        self.device = device

    def forward(self, edge_to, edge_from, a_to, a_from, curr_model, outputs):
        embedding_to, embedding_from = outputs["umap"]
        recon_to, recon_from = outputs["recon"]
        recon_l = self.recon_loss(edge_to, edge_from, recon_to, recon_from, a_to, a_from).to(self.device)
        umap_l = self.umap_loss(embedding_to, embedding_from).to(self.device)
        temporal_l = self.temporal_loss(curr_model).to(self.device)

        # Mutual-information terms between the paired embeddings and the paired edges.
        MI_l_embedding = self.MI_loss(embedding_to, embedding_from).to(self.device)
        MI_l_edge = self.MI_loss(edge_to, edge_from).to(self.device)

        MI_l = (MI_l_embedding + MI_l_edge) / 2
        loss = umap_l + self.lambd1 * recon_l + self.lambd2 * temporal_l + self.lambd3 * MI_l

        return umap_l, self.lambd1 * recon_l, self.lambd2 * temporal_l, loss

# TensorFlow/Keras versions of the losses.
import tensorflow as tf


def umap_loss(
    batch_size,
    negative_sample_rate,
    _a,
    _b,
    repulsion_strength=1.0,
):
    """
    Generate a Keras-compatible loss function for the UMAP loss.

    Parameters
    ----------
    batch_size : int
        size of mini-batches
    negative_sample_rate : int
        number of negative samples per positive sample to train on
    _a : float
        distance parameter in embedding space
    _b : float
        distance parameter in embedding space
    repulsion_strength : float, optional
        strength of repulsion vs attraction for cross-entropy, by default 1.0

    Returns
    -------
    loss : function
        loss function that takes in a placeholder (ignored) and the output of the Keras network
    """

    @tf.function
    def loss(placeholder_y, embed_to_from):
        # The network output concatenates embedding_to (2), embedding_from (2), and an edge weight (1).
        embedding_to, embedding_from, weights = tf.split(
            embed_to_from, num_or_size_splits=[2, 2, 1], axis=1
        )

        # Build negative pairs by repeating each head embedding and shuffling the tails.
        embedding_neg_to = tf.repeat(embedding_to, negative_sample_rate, axis=0)
        repeat_neg = tf.repeat(embedding_from, negative_sample_rate, axis=0)
        embedding_neg_from = tf.gather(
            repeat_neg, tf.random.shuffle(tf.range(tf.shape(repeat_neg)[0]))
        )

        distance_embedding = tf.concat(
            (
                tf.norm(embedding_to - embedding_from, axis=1),
                tf.norm(embedding_neg_to - embedding_neg_from, axis=1),
            ),
            axis=0,
        )

        # Convert distances to membership probabilities with the UMAP low-dimensional curve.
        probabilities_distance = 1.0 / (1.0 + _a * tf.math.pow(distance_embedding, 2 * _b))

        # Target membership strengths: 1 for positive pairs, 0 for negative samples.
        probabilities_graph = tf.concat(
            (tf.ones(batch_size), tf.zeros(batch_size * negative_sample_rate)), axis=0,
        )
        # Weighted variant of the targets (currently unused; probabilities_graph is used below).
        probabilities = tf.concat(
            (tf.squeeze(weights), tf.zeros(batch_size * negative_sample_rate)), axis=0,
        )

        (attraction_loss, repellant_loss, ce_loss) = compute_cross_entropy_tf(
            probabilities_graph,
            probabilities_distance,
            repulsion_strength=repulsion_strength,
        )

        return tf.reduce_mean(ce_loss)

    return loss

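# Usage sketch (illustrative only): the closure returned by umap_loss() follows the
# (y_true, y_pred) signature Keras expects, so it could plausibly be passed to
# model.compile(). The model and hyperparameters below are assumptions for demonstration;
# the model's output must concatenate embedding_to, embedding_from, and an edge weight
# (shape [batch, 5]) to match the tf.split above.
#
#   loss_fn = umap_loss(batch_size=128, negative_sample_rate=5, _a=1.577, _b=0.895)
#   parametric_model.compile(optimizer="adam", loss=loss_fn)
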
def regularize_loss():
    '''
    Add temporal regularization L2 loss on weights
    '''

    @tf.function
    def loss(w_prev, w_current, to_alpha):
        assert len(w_prev) == len(w_current)

        # Accumulate the alpha-scaled squared differences over all weight tensors.
        for j in range(len(w_prev)):
            diff = tf.reduce_sum(tf.math.square(w_current[j] - w_prev[j]))
            diff = tf.math.multiply(to_alpha, diff)
            if j == 0:
                alldiff = tf.reduce_mean(diff)
            else:
                alldiff += tf.reduce_mean(diff)
        return alldiff

    return loss

def reconstruction_loss(
    beta=1
):
    """
    Generate a Keras-compatible loss function for the customized reconstruction loss.

    Parameters
    ----------
    beta : float
        hyperparameter controlling the per-dimension weighting (1 + alpha)^beta

    Returns
    -------
    loss : function
    """

    @tf.function
    def loss(edge_to, edge_from, recon_to, recon_from, alpha_to, alpha_from):
        loss1 = tf.reduce_mean(tf.reduce_mean(tf.math.multiply(tf.math.pow((1 + alpha_to), beta), tf.math.pow(edge_to - recon_to, 2)), 1))
        loss2 = tf.reduce_mean(tf.reduce_mean(tf.math.multiply(tf.math.pow((1 + alpha_from), beta), tf.math.pow(edge_from - recon_from, 2)), 1))
        return (loss1 + loss2) / 2

    return loss

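# Usage sketch (illustrative only): the closure returned by reconstruction_loss() takes the
# original edge endpoints, their reconstructions, and the alpha weighting terms directly,
# so it would typically be called inside a custom training step rather than passed to
# model.compile(). The tensor names below are assumptions for demonstration.
#
#   recon_fn = reconstruction_loss(beta=1)
#   l = recon_fn(edge_to, edge_from, recon_to, recon_from, alpha_to, alpha_from)
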