# Spaces:
# Runtime error
# Runtime error
"""
Encoder
"""
# Import dependencies
import numpy as np | |
from torch import nn | |
import torch | |
import torchvision | |
from einops import rearrange, reduce | |
from argparse import ArgumentParser | |
from pytorch_lightning import LightningModule, Trainer, Callback | |
from pytorch_lightning.loggers import WandbLogger | |
from torch.optim import Adam | |
from torch.optim.lr_scheduler import CosineAnnealingLR | |
class Encoder(nn.Module):
    """Convolutional encoder: three stride-2 convolutions and a linear head.

    Each convolution is followed by a ReLU, and the channel count goes
    ``n_features -> n_filters -> 2 * n_filters -> 4 * n_filters``. The last
    feature map is flattened per sample and projected to ``feature_dim``.

    Args:
        n_features: Number of channels in the input.
        kernel_size: Kernel size shared by the three convolutions (an int or
            an ``(h, w)`` pair, as accepted by ``nn.Conv2d``).
        n_filters: Number of output channels of the first convolution.
        feature_dim: Length of the vector produced for each sample.
        input_size: Spatial size of the inputs, either an int (square
            inputs) or a ``(height, width)`` pair. When given, the input
            width of the linear layer is derived from it. When ``None``
            (the default) the historical hard-coded width of 576 is kept;
            that value matches ``n_filters=16``, ``kernel_size=3`` and
            32x32 inputs.

    Raises:
        ValueError: If ``input_size`` is given but is too small for three
            unpadded stride-2 convolutions with the chosen ``kernel_size``.
    """

    # Stride used by every convolution in the stack.
    _STRIDE = 2
    # Linear input width used when no input size is supplied.
    _LEGACY_FLAT_DIM = 576

    def __init__(self, n_features=3, kernel_size=3, n_filters=16, feature_dim=1024,
                 input_size=None):
        super().__init__()
        stride = self._STRIDE
        self.conv1 = nn.Conv2d(n_features, n_filters, kernel_size=kernel_size, stride=stride)
        self.conv2 = nn.Conv2d(n_filters, n_filters * 2, kernel_size=kernel_size, stride=stride)
        self.conv3 = nn.Conv2d(n_filters * 2, n_filters * 4, kernel_size=kernel_size, stride=stride)

        if input_size is None:
            # Keep the original layer shape so existing callers and saved
            # weights continue to work unchanged.
            flat_dim = self._LEGACY_FLAT_DIM
        else:
            flat_dim = self._flattened_dim(input_size, kernel_size, n_filters * 4)
        self.fc1 = nn.Linear(flat_dim, feature_dim)

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as an ``(int, int)`` pair, duplicating a lone int."""
        if isinstance(value, int):
            return value, value
        first, second = value
        return int(first), int(second)

    @classmethod
    def _flattened_dim(cls, input_size, kernel_size, out_channels):
        """Return the per-sample element count after the three convolutions."""
        height, width = cls._as_pair(input_size)
        kernel_h, kernel_w = cls._as_pair(kernel_size)
        for _ in range(3):
            # Conv2d output size with padding=0 and dilation=1.
            height = (height - kernel_h) // cls._STRIDE + 1
            width = (width - kernel_w) // cls._STRIDE + 1
            if height < 1 or width < 1:
                raise ValueError(
                    f"input_size={input_size!r} is too small for three stride-"
                    f"{cls._STRIDE} convolutions with kernel_size={kernel_size!r}"
                )
        return out_channels * height * width

    def forward(self, x):
        """Encode a batch of inputs.

        Args:
            x: Input tensor laid out as ``(batch, channels, height, width)``.

        Returns:
            The output of the final linear layer, one row per sample.
        """
        y = torch.relu(self.conv1(x))
        y = torch.relu(self.conv2(y))
        y = torch.relu(self.conv3(y))
        # Collapse (channels, height, width) into one axis per sample.
        y = torch.flatten(y, start_dim=1)
        return self.fc1(y)