File size: 6,199 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.torch_utils.network import one_hot
class MultiLogitsLoss(nn.Module):
    """
    Overview:
        Classification loss between a batch of logits and a batch of labels in which the pairing \
        between predictions and labels is not fixed. A pairwise loss matrix is computed between every \
        prediction and every label, a one-to-one assignment with minimal total loss is searched with \
        the Kuhn-Munkres (Hungarian) algorithm, and the mean loss of the matched pairs is returned.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(self, criterion: str = None, smooth_ratio: float = 0.1) -> None:
        """
        Overview:
            Initialization method, use cross_entropy as default criterion.
        Arguments:
            - criterion (:obj:`str`): Criterion type, supports ['cross_entropy', 'label_smooth_ce'].
            - smooth_ratio (:obj:`float`): Smoothing ratio for label smoothing, only stored when \
                ``criterion`` is 'label_smooth_ce'.
        """
        super(MultiLogitsLoss, self).__init__()
        if criterion is None:
            criterion = 'cross_entropy'
        # Kept as an assert so that callers relying on AssertionError are not broken.
        assert criterion in ['cross_entropy', 'label_smooth_ce'], criterion
        self.criterion = criterion
        if self.criterion == 'label_smooth_ce':
            self.ratio = smooth_ratio

    def _label_process(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Turn class indices into per-class target weights according to the criterion.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits, shape (B, N); only shape/dtype/device are used.
            - labels (:obj:`torch.LongTensor`): Ground truth class indices, shape (B, ).
        Returns:
            - ret (:obj:`torch.Tensor`): Target weights of shape (B, N). For 'label_smooth_ce' every row \
                sums to 1: the target class gets ``1 - ratio`` and the remaining ``N - 1`` classes \
                share ``ratio`` equally.
        """
        N = logits.shape[1]
        if self.criterion == 'cross_entropy':
            return one_hot(labels, num=N)
        elif self.criterion == 'label_smooth_ce':
            if N == 1:
                # Only one class: there is no other class to spread probability mass onto.
                return torch.ones_like(logits)
            ratio = float(self.ratio)
            ret = torch.full_like(logits, ratio / (N - 1))
            # The target keeps 1 - ratio so that each row is a valid probability distribution.
            ret.scatter_(1, labels.unsqueeze(1), 1 - ratio)
            return ret

    def _nll_loss(self, nlls: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate the negative log likelihood loss of each row.
        Arguments:
            - nlls (:obj:`torch.Tensor`): Log-probabilities (``_get_metric_matrix`` passes the output of \
                ``log_softmax``); they are negated here.
            - labels (:obj:`torch.LongTensor`): Processed target weights with the same shape as ``nlls``; \
                detached, so no gradient flows into them.
        Returns:
            - ret (:obj:`torch.Tensor`): Loss per row, i.e. the weighted sum over dim 1.
        """
        weighted = -nlls * labels.detach()
        return weighted.sum(dim=1)

    def _get_metric_matrix(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate the pairwise loss matrix between every prediction and every label.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits, shape (M, N).
            - labels (:obj:`torch.LongTensor`): Ground truth class indices.
        Returns:
            - metric (:obj:`torch.Tensor`): Matrix whose entry ``[i, j]`` is the loss of prediction ``i`` \
                evaluated against label ``j``.
        """
        M, N = logits.shape
        targets = self._label_process(logits, labels)
        log_probs = F.log_softmax(logits, dim=1)
        # ``expand`` broadcasts one prediction against all labels without copying it M times.
        metric = [self._nll_loss(row.expand(M, N), targets) for row in log_probs]
        return torch.stack(metric, dim=0)

    def _match(self, matrix: torch.Tensor):
        """
        Overview:
            Find a one-to-one assignment between rows and columns of ``matrix`` with minimal total cost, \
            using the Kuhn-Munkres algorithm (maximum weight matching on the negated matrix).
        Arguments:
            - matrix (:obj:`torch.Tensor`): Square metric matrix of shape (M, M).
        Returns:
            - index (:obj:`np.ndarray`): int32 array of length M, ``index[col]`` is the row matched to \
                column ``col``.
        Raises:
            - RuntimeError: If ``matrix`` contains NaN/inf, or if no augmenting path can be constructed.
        """
        # Negate to turn cost minimisation into weight maximisation. float64 keeps the rounding error
        # accumulated by the label updates far below the 1e-4 tightness tolerance used further down.
        mat = -matrix.detach().cpu().numpy().astype(np.float64)
        M = mat.shape[0]
        index = np.full(M, -1, dtype=np.int32)  # -1 marks a column without a matched row
        if M == 0:
            return index
        if not np.isfinite(mat).all():
            # A NaN/inf entry would make the label update below a no-op and the loop would never end.
            raise RuntimeError("match error, non-finite value in matrix: {}".format(matrix))
        lx = mat.max(axis=1)  # row labels, start at the best weight of each row
        ly = np.zeros(M, dtype=np.float64)  # column labels
        visx = np.zeros(M, dtype=np.bool_)
        visy = np.zeros(M, dtype=np.bool_)

        def has_augmented_path(t, binary_distance_matrix):
            # Depth-first search for an augmenting path starting at row ``t`` along tight edges.
            # Mutates visx, visy and index of the enclosing scope.
            visx[t] = True
            for i in range(M):
                if not visy[i] and binary_distance_matrix[t, i]:
                    visy[i] = True
                    if index[i] == -1 or has_augmented_path(index[i], binary_distance_matrix):
                        index[i] = t
                        return True
            return False

        for i in range(M):
            while True:
                visx.fill(False)
                visy.fill(False)
                distance_matrix = self._get_distance_matrix(lx, ly, mat, M)
                # An edge is tight (usable) when its slack is numerically zero.
                binary_distance_matrix = np.abs(distance_matrix) < 1e-4
                if has_augmented_path(i, binary_distance_matrix):
                    break
                # No path yet: relax the labels by the smallest slack between a visited row and an
                # unvisited column, which makes at least one more edge tight.
                masked_distance_matrix = distance_matrix[:, ~visy][visx]
                if 0 in masked_distance_matrix.shape:  # empty matrix
                    raise RuntimeError("match error, matrix: {}".format(matrix))
                else:
                    d = masked_distance_matrix.min()
                    lx[visx] -= d
                    ly[visy] += d
        return index

    @staticmethod
    def _get_distance_matrix(lx: np.ndarray, ly: np.ndarray, mat: np.ndarray, M: int) -> np.ndarray:
        """
        Overview:
            Get the slack of every (row, column) pair under the current labels.
        Arguments:
            - lx (:obj:`np.ndarray`): Row labels, shape (M, ).
            - ly (:obj:`np.ndarray`): Column labels, shape (M, ).
            - mat (:obj:`np.ndarray`): Weight matrix, shape (M, M).
            - M (:obj:`int`): Side length of ``mat``.
        Returns:
            - nret (:obj:`np.ndarray`): Matrix with ``nret[i, j] = lx[i] + ly[j] - mat[i, j]``.
        """
        row_labels = np.broadcast_to(lx, [M, M]).T
        col_labels = np.broadcast_to(ly, [M, M])
        return row_labels + col_labels - mat

    def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
        """
        Overview:
            Calculate multiple logits loss: the mean loss over the minimal-cost one-to-one matching \
            between predictions and labels.
        Arguments:
            - logits (:obj:`torch.Tensor`): Predicted logits, whose shape must be 2-dim, like (B, N).
            - labels (:obj:`torch.LongTensor`): Ground truth class indices, shape (B, ).
        Returns:
            - loss (:obj:`torch.Tensor`): Calculated scalar loss.
        """
        assert len(logits.shape) == 2
        metric_matrix = self._get_metric_matrix(logits, labels)
        index = self._match(metric_matrix)
        # index[col] is the prediction row matched to label ``col``; gather those entries in one shot.
        rows = torch.as_tensor(index, dtype=torch.long, device=metric_matrix.device)
        cols = torch.arange(metric_matrix.shape[0], device=metric_matrix.device)
        return metric_matrix[rows, cols].mean()
|