Denoising_CIFAR100 / encoder.py
abbylagar's picture
Upload encoder.py
6a98841
"""
Convolutional encoder: three strided Conv2d layers followed by a linear projection.
"""
# Imports
import numpy as np
from torch import nn
import torch
import torchvision
from einops import rearrange, reduce
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, Callback
from pytorch_lightning.loggers import WandbLogger
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
class Encoder(nn.Module):
    """Encode an image batch into a flat feature vector.

    The network is three stride-2 ``nn.Conv2d`` layers (each followed by a
    ReLU) whose output is flattened and projected by one ``nn.Linear`` layer.

    Args:
        n_features: Number of channels of the input image.
        kernel_size: Kernel size shared by the three convolutions.
        n_filters: Output channels of the first convolution; the second and
            third convolutions use ``2 * n_filters`` and ``4 * n_filters``.
        feature_dim: Size of the output feature vector.
        input_size: Optional spatial size of the input, either one int for a
            square image or a ``(height, width)`` pair. When given, the input
            width of the linear layer is derived from it and from the
            convolution hyper-parameters. When ``None`` (default), the
            original fixed width of 576 is kept; that value equals
            ``64 * 3 * 3``, i.e. what the default arguments produce for a
            32x32 input.

    Raises:
        ValueError: If ``input_size`` is given and is too small for the
            three convolutions to produce at least one output pixel.
    """

    # Flattened size the original implementation hard-coded for ``fc1``.
    _LEGACY_FLAT_FEATURES = 576

    def __init__(self, n_features=3, kernel_size=3, n_filters=16,
                 feature_dim=1024, input_size=None):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # parameter initialization is unchanged.
        self.conv1 = nn.Conv2d(n_features, n_filters,
                               kernel_size=kernel_size, stride=2)
        self.conv2 = nn.Conv2d(n_filters, n_filters * 2,
                               kernel_size=kernel_size, stride=2)
        self.conv3 = nn.Conv2d(n_filters * 2, n_filters * 4,
                               kernel_size=kernel_size, stride=2)

        if input_size is None:
            flat_features = self._LEGACY_FLAT_FEATURES
        else:
            flat_features = self._flat_features(input_size)
        self.fc1 = nn.Linear(flat_features, feature_dim)

    def _flat_features(self, input_size):
        """Return channels * height * width of the ``conv3`` output."""
        if isinstance(input_size, int):
            spatial = (input_size, input_size)
        else:
            spatial = tuple(input_size)

        for conv in (self.conv1, self.conv2, self.conv3):
            spatial = self._conv_output_size(spatial, conv)
            if min(spatial) < 1:
                raise ValueError(
                    f"input_size={input_size!r} is too small for three "
                    f"stride-2 convolutions with kernel_size={conv.kernel_size}"
                )
        return self.conv3.out_channels * spatial[0] * spatial[1]

    @staticmethod
    def _conv_output_size(size, conv):
        """Apply the Conv2d output-shape formula to a ``(height, width)`` pair."""
        return tuple(
            (dim + 2 * pad - dil * (kernel - 1) - 1) // stride + 1
            for dim, kernel, stride, pad, dil in zip(
                size, conv.kernel_size, conv.stride, conv.padding, conv.dilation
            )
        )

    def forward(self, x):
        """Return the ``(batch, feature_dim)`` encoding of ``x``.

        ``x`` is a ``(batch, channels, height, width)`` tensor whose spatial
        size must match what ``fc1`` was sized for.
        """
        y = torch.relu(self.conv1(x))
        y = torch.relu(self.conv2(y))
        y = torch.relu(self.conv3(y))
        # Same result as einops 'b c h w -> b (c h w)': keep the batch axis
        # and merge channel/height/width into one.
        y = torch.flatten(y, start_dim=1)
        return self.fc1(y)