Realcat
add: mast3r
fb98d2a
raw
history blame
1.32 kB
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Criterion to train CroCo
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# --------------------------------------------------------
import torch
class MaskedMSE(torch.nn.Module):
    """Mean-squared-error reconstruction criterion computed per patch.

    The squared error between ``pred`` and ``target`` is averaged over the
    last dimension, then reduced to a scalar either over the patches selected
    by ``mask`` (``masked=True``) or over every patch (``masked=False``).
    """

    def __init__(self, norm_pix_loss=False, masked=True):
        """
        norm_pix_loss: normalize each patch by their pixel mean and variance
        masked: compute loss over the masked patches only
        """
        super().__init__()
        self.norm_pix_loss = norm_pix_loss
        self.masked = masked

    def forward(self, pred, mask, target):
        """Return the scalar reconstruction loss.

        pred: predicted values; must broadcast against ``target``.
        mask: per-patch weights multiplied with the per-patch loss; only read
            when ``self.masked`` is true. Presumably 1 marks a masked patch
            and 0 a visible one, as in MAE -- confirm against the caller.
        target: ground-truth values; the last dimension is the one reduced
            per patch (assumed layout [N, L, D] -- confirm against the caller).

        When ``self.masked`` is true and ``mask`` sums to zero, the result is
        a finite zero loss rather than the NaN a plain ``0 / 0`` would give.
        """
        if self.norm_pix_loss:
            # Standardize each target patch over its last dimension; the small
            # epsilon keeps the division finite for constant patches.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        if not self.masked:
            return loss.mean()  # mean loss over every patch

        weighted_sum = (loss * mask).sum()
        mask_total = mask.sum()
        # An all-zero mask would make this 0 / 0 (NaN loss and NaN gradients).
        # Swapping the zero denominator for one keeps the result at a finite 0
        # and leaves every mask with a positive sum unaffected.
        safe_total = torch.where(
            mask_total > 0, mask_total, torch.ones_like(mask_total)
        )
        return weighted_sum / safe_total  # mean loss on masked patches