File size: 741 Bytes
f53b39e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Masking utils
# --------------------------------------------------------
import torch
import torch.nn as nn
class RandomMask(nn.Module):
    """
    Random masking of a fixed number of patch positions per sample.

    For every sample in the batch, ``int(mask_ratio * num_patches)`` patch
    positions are selected uniformly at random. The selection is exact when
    that count lies in ``[0, num_patches]``. A larger count selects every
    position, and a negative count selects none.

    Args:
        num_patches: number of patch positions per sample.
        mask_ratio: fraction of positions to select. The resulting count is
            truncated with ``int()``.
    """

    def __init__(self, num_patches, mask_ratio):
        super().__init__()
        self.num_patches = num_patches
        # Number of positions set to True in every row of the returned mask.
        self.num_mask = int(mask_ratio * self.num_patches)

    # Implemented as ``forward`` (not ``__call__``) so that
    # ``nn.Module.__call__`` dispatches here and still runs any registered
    # forward / pre-forward hooks.
    def forward(self, x):
        """
        Draw a fresh random mask for the batch contained in ``x``.

        Only ``x.size(0)`` (the batch size) and ``x.device`` are read. The
        values of ``x`` are ignored.

        Returns:
            A bool tensor of shape ``(x.size(0), num_patches)`` on
            ``x.device``, with ``num_mask`` True entries per row. Presumably
            True marks a masked patch; confirm against the consumers of this
            mask.
        """
        noise = torch.rand(x.size(0), self.num_patches, device=x.device)
        # argsort of i.i.d. uniform noise is a uniformly random permutation of
        # 0..num_patches-1 in each row. Thresholding it at num_mask therefore
        # selects a random subset of exactly num_mask positions per row.
        argsort = torch.argsort(noise, dim=1)
        return argsort < self.num_mask
|