|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
class MaskedMSE(torch.nn.Module): |
|
|
|
def __init__(self, norm_pix_loss=False, masked=True): |
|
""" |
|
norm_pix_loss: normalize each patch by their pixel mean and variance |
|
masked: compute loss over the masked patches only |
|
""" |
|
super().__init__() |
|
self.norm_pix_loss = norm_pix_loss |
|
self.masked = masked |
|
|
|
def forward(self, pred, mask, target): |
|
|
|
if self.norm_pix_loss: |
|
mean = target.mean(dim=-1, keepdim=True) |
|
var = target.var(dim=-1, keepdim=True) |
|
target = (target - mean) / (var + 1.e-6)**.5 |
|
|
|
loss = (pred - target) ** 2 |
|
loss = loss.mean(dim=-1) |
|
if self.masked: |
|
loss = (loss * mask).sum() / mask.sum() |
|
else: |
|
loss = loss.mean() |
|
return loss |
|
|