Spaces:
Build error
Build error
File size: 593 Bytes
0ab9a32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
import torch.nn as nn
class ViTPatchMaskGenerator(nn.Module):
    """Convert a pixel-level mask into a per-token mask for a ViT-style model.

    Each non-overlapping ``patch_size`` window of the pixel mask is reduced
    with a max-pool. A patch is therefore marked ``True`` when any pixel
    inside it is non-zero. A leading all-``True`` column is prepended for the
    class token.

    Args:
        patch_size: Patch side length (or an ``(h, w)`` tuple). It is used as
            both the kernel size and the stride of the pooling window.
    """

    def __init__(self, patch_size) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.pool = nn.MaxPool2d(kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_masks):
        """Return a boolean token mask: one class-token column plus one per patch.

        Args:
            pixel_masks: Pixel mask, presumably ``(B, H, W)`` or
                ``(B, 1, H, W)``. TODO confirm with callers. A 3-D input is
                treated by ``MaxPool2d`` as unbatched ``(C, H, W)``, so its
                first dim still ends up as the leading output dim. Any dtype is
                accepted; non-floating masks (``bool``/integer) are cast to
                float first.

        Returns:
            A ``torch.bool`` tensor. Its first column is all ``True`` (the
            class-token slot). The remaining columns are the pooled patch
            values flattened from dim 1 onward in row-major order. Trailing
            rows/cols that do not fill a whole patch are dropped, because
            ``MaxPool2d`` defaults to ``ceil_mode=False``.
        """
        # max_pool2d has no kernel for bool (and, on some builds, integer)
        # tensors. Casting to float keeps every value's zero/non-zero status,
        # so the final .bool() result is unchanged.
        if not pixel_masks.is_floating_point():
            pixel_masks = pixel_masks.float()
        patch_mask = self.pool(pixel_masks).bool().flatten(1)
        # new_ones inherits patch_mask's dtype (bool) and device.
        cls_token_mask = patch_mask.new_ones((patch_mask.shape[0], 1))
        return torch.cat([cls_token_mask, patch_mask], dim=-1)
|