""" encoder """ #import functions import numpy as np from torch import nn import torch import torchvision from einops import rearrange, reduce from argparse import ArgumentParser from pytorch_lightning import LightningModule, Trainer, Callback from pytorch_lightning.loggers import WandbLogger from torch.optim import Adam from torch.optim.lr_scheduler import CosineAnnealingLR class Encoder(nn.Module): def __init__(self, n_features=3, kernel_size=3, n_filters=16, feature_dim=1024): super().__init__() self.conv1 = nn.Conv2d(n_features, n_filters, kernel_size=kernel_size, stride=2) self.conv2 = nn.Conv2d(n_filters, n_filters*2, kernel_size=kernel_size, stride=2) self.conv3 = nn.Conv2d(n_filters*2, n_filters*4, kernel_size=kernel_size, stride=2) self.fc1 = nn.Linear(576, feature_dim) def forward(self, x): y = nn.ReLU()(self.conv1(x)) y = nn.ReLU()(self.conv2(y)) y = nn.ReLU()(self.conv3(y)) y = rearrange(y, 'b c h w -> b (c h w)') y = self.fc1(y) return y