import copy
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from einops import rearrange
from torchmetrics.functional import accuracy
from torchmetrics.functional.classification import multiclass_recall, multiclass_precision
from x_transformers import Encoder, Decoder

ON_EPOCH = True
ON_STEP = False
BATCH_SIZE = 64
TARGET_SIZE = (64, 64)
SPLIT_RATE = 0.8
ROOT_DIR_DATA = "/kaggle/input/ant-data-new/data"


class PatchEmbed(nn.Module):
    """Image to Patch Embedding: splits an image into non-overlapping patches and
    projects each patch to an embedding vector with a strided convolution."""

    def __init__(self, img_size=TARGET_SIZE[0], patch_size=4, in_chans=3, embed_dim=64):
        super().__init__()
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)

        # Number of patches along each spatial dimension.
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])

        # kernel_size == stride == patch_size embeds each patch independently.
        self.conv = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.conv(x)                          # (b, embed_dim, h, w)
        x = rearrange(x, 'b e h w -> b (h w) e')  # (b, num_patches, embed_dim)
        return x
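# Shape example (illustrative only, following the defaults above): a 64x64 RGB input
# split into 4x4 patches yields (64 // 4) * (64 // 4) = 256 tokens, so
#   PatchEmbed()(torch.randn(2, 3, 64, 64)).shape == torch.Size([2, 256, 64])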
class ViTIJEPA(nn.Module):
    def __init__(self, img_size, patch_size, in_chans, embed_dim, enc_depth, num_heads,
                 num_classes, post_emb_norm=False, layer_dropout=0.):
        super().__init__()
        self.layer_dropout = layer_dropout
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_chans=in_chans, embed_dim=embed_dim)
        self.num_tokens = self.patch_embed.patch_shape[0] * self.patch_embed.patch_shape[1]
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim))
        self.post_emb_norm = nn.LayerNorm(embed_dim) if post_emb_norm else nn.Identity()
        # x_transformers takes attention options via the `attn_` prefix, so flash
        # attention is enabled with `attn_flash=True` rather than `flash=True`.
        self.student_encoder = Encoder(
            dim=embed_dim,
            heads=num_heads,
            depth=enc_depth,
            layer_dropout=self.layer_dropout,
            attn_flash=True,
        )
        # AvgPool1d with kernel_size == embed_dim averages over the embedding dimension,
        # reducing each token to a single scalar:
        # (b, num_tokens, embed_dim) -> (b, num_tokens, 1).
        self.average_pool = nn.AvgPool1d(embed_dim, stride=1)

        # The classification head therefore sees one scalar per token.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.num_tokens),
            nn.Linear(self.num_tokens, num_classes),
        )

    def forward(self, x):
        x = self.patch_embed(x)      # (b, num_tokens, embed_dim)
        b, n, e = x.shape
        x = x + self.pos_embedding   # learned positional embeddings
        x = self.post_emb_norm(x)
        x = self.student_encoder(x)  # (b, num_tokens, embed_dim)
        x = self.average_pool(x)     # (b, num_tokens, 1)
        x = x.squeeze(-1)            # (b, num_tokens)
        x = self.mlp_head(x)         # (b, num_classes)
        return x
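
# Minimal smoke test (an illustrative sketch; the hyperparameters below are assumptions
# for this example, not values taken from the rest of the project).
if __name__ == "__main__":
    model = ViTIJEPA(
        img_size=TARGET_SIZE[0],
        patch_size=4,
        in_chans=3,
        embed_dim=64,
        enc_depth=4,     # assumed depth for the sketch
        num_heads=4,     # assumed head count for the sketch
        num_classes=10,  # assumed class count for the sketch
    )
    dummy = torch.randn(2, 3, *TARGET_SIZE)
    logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 10])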