szukevin's picture
upload
7900c16
raw
history blame
1.19 kB
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
    """
    Image to Patch Embedding for Vision Transformer.

    Splits an image into non-overlapping ``patch_size x patch_size`` patches,
    projects each patch to ``emb_size`` with a bias-free ``nn.Conv2d`` whose
    kernel size and stride both equal the patch size, and prepends a learnable
    classification embedding (initialised to zeros).

    Args:
        args: Configuration object. This class reads ``emb_size``,
            ``image_height``, ``image_width``, ``patch_size`` and
            ``channels_num`` from it.
        _: Unused by this class.
    """

    def __init__(self, args, _):
        super().__init__()
        # Learnable embedding prepended to every sequence; shape (1, 1, emb_size).
        self.cls_emb = nn.Parameter(torch.zeros(1, 1, args.emb_size))
        self.image_height = args.image_height
        self.image_width = args.image_width
        patch_size = (args.patch_size, args.patch_size)
        channels_num = args.channels_num
        # kernel_size == stride => each output position sees exactly one patch.
        # If the image size is not a multiple of patch_size, the conv's floor
        # division silently drops the trailing rows/columns of pixels.
        self.projection = nn.Conv2d(
            channels_num, args.emb_size, kernel_size=patch_size, stride=patch_size, bias=False
        )

    def forward(self, src, _):
        """
        Embed a batch of images as a sequence of patch embeddings.

        Args:
            src: Tensor of shape (batch_size, channels_num, height, width).
            _: Unused by this method.

        Returns:
            Tensor of shape (batch_size, 1 + num_patches, emb_size), where
            num_patches = (height // patch_size) * (width // patch_size).
            Position 0 holds the classification embedding; the patches follow
            in row-major order.

        Raises:
            ValueError: if ``src`` is not 4-D, or if its spatial size differs
                from the configured ``image_height`` x ``image_width``.
        """
        # Reject wrong-rank input with a clear message rather than letting the
        # shape unpacking below fail with a generic "values to unpack" error.
        if src.dim() != 4:
            raise ValueError(
                f"Expected a 4-D input (batch_size, channels_num, height, width), got a {src.dim()}-D tensor."
            )
        batch_size, _, height, width = src.shape
        if height != self.image_height or width != self.image_width:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_height}*{self.image_width})."
            )
        # (B, emb, H', W') -> (B, emb, H'*W') -> (B, H'*W', emb)
        patch_emb = self.projection(src).flatten(2).transpose(1, 2)
        # expand() broadcasts without copying; .to() is a no-op when the
        # parameter already lives on the same device as the patches.
        cls_emb = self.cls_emb.expand(batch_size, -1, -1).to(patch_emb.device)
        patch_emb = torch.cat((cls_emb, patch_emb), dim=1)
        return patch_emb