Spaces:
Restarting
Restarting
File size: 1,324 Bytes
4d4dd90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Criterion to train CroCo
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# --------------------------------------------------------
import torch
class MaskedMSE(torch.nn.Module):
    """Mean-squared-error criterion, optionally averaged over masked patches only.

    The squared error between ``pred`` and ``target`` is first averaged over
    the last dimension (one value per patch), then reduced to a scalar either
    as a mask-weighted mean (``masked=True``) or as a plain mean over all
    patches (``masked=False``).
    """

    def __init__(self, norm_pix_loss=False, masked=True):
        """
        norm_pix_loss: normalize each patch by their pixel mean and variance
        masked: compute loss over the masked patches only
        """
        super().__init__()
        self.norm_pix_loss = norm_pix_loss
        self.masked = masked

    def forward(self, pred, mask, target):
        """Return the scalar reconstruction loss.

        pred: predicted patch values; must broadcast against ``target``.
        mask: per-patch weights multiplied with the per-patch loss; only used
            when ``self.masked`` is true. Presumably 1 marks a masked patch and
            the shape is [N, L] (matching the per-patch loss) -- confirm
            against the caller.
        target: ground-truth patch values; the last dimension is the one that
            gets normalized (if enabled) and averaged.

        When ``self.masked`` is true and ``mask`` sums to zero, the result is 0
        (with zero gradients) rather than NaN.
        """
        if self.norm_pix_loss:
            # Standardize every target patch over its last dimension; the small
            # epsilon keeps the division finite for constant patches.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        if not self.masked:
            return loss.mean()  # mean loss over all patches

        mask_total = mask.sum()
        # An all-zero mask would give 0 / 0 = NaN and poison the gradients.
        # Swap in a denominator of 1 only in that case: the numerator is then
        # 0 as well, so the loss is 0. Any non-zero sum is left untouched.
        safe_total = torch.where(
            mask_total > 0, mask_total, torch.ones_like(mask_total)
        )
        return (loss * mask).sum() / safe_total  # mean loss on masked patches
|