#!/usr/bin/env python
# coding: utf-8
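# Trains two MNIST classifiers on the Hugging Face 'mnist' dataset:
# first an MLP baseline, then a small all-convolutional CNN.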
import torch
from torch import nn
import torch.nn.functional as F
from datasets import load_dataset
import fastcore.all as fc
import matplotlib.pyplot as plt
import matplotlib as mpl
import torchvision.transforms.functional as TF
from torch.utils.data import default_collate, DataLoader
import torch.optim as optim
import pickle
try:
    get_ipython().run_line_magic('matplotlib', 'inline')  # notebook-only magic
except NameError:
    pass  # running as a plain script
plt.rcParams['figure.figsize'] = [2, 2]
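
# 'mnist': 60,000 train / 10,000 test examples; PIL images in the 'image'
# column, integer labels 0-9 in the 'label' column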
dataset_nm = 'mnist'
x,y = 'image', 'label'
ds = load_dataset(dataset_nm)
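
# convert the PIL images to 1x28x28 float tensors lazily, at access time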
def transform_ds(b):
b[x] = [TF.to_tensor(ele) for ele in b[x]]
return b
dst = ds.with_transform(transform_ds)
plt.imshow(dst['train'][0]['image'].permute(1,2,0));
bs = 1024
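
# thin wrapper pairing a shuffled train loader with a validation loader;
# eval holds no gradients, so its batches can be twice as large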
class DataLoaders:
def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)
        self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)
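
# default_collate turns a list of {'image': ..., 'label': ...} dicts into
# batched tensors; unpack them into an (inputs, labels) tuple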
def collate_fn(b):
collate = default_collate(b)
return (collate[x], collate[y])
dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)
xb,yb = next(iter(dls.train))
xb.shape, yb.shape
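# expect (torch.Size([1024, 1, 28, 28]), torch.Size([1024]))

# nn.Sequential has no built-in reshape step, so wrap tensor.reshape in a Module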
class Reshape(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
return x.reshape(self.dim)
# model definition: a two-hidden-layer MLP baseline (784 -> 50 -> 50 -> 10)
def linear_classifier():
return nn.Sequential(
Reshape((-1, 784)),
nn.Linear(784, 50),
nn.ReLU(),
nn.Linear(50, 50),
nn.ReLU(),
nn.Linear(50, 10)
)
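
# quick sanity check (illustrative): a dummy batch of two images should map to (2, 10) logits
assert linear_classifier()(torch.zeros(2, 1, 28, 28)).shape == (2, 10)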
model = linear_classifier()
lr = 0.1
max_lr = 0.1
epochs = 5
opt = optim.AdamW(model.parameters(), lr=lr)
# OneCycleLR sets the lr itself on every step (the optimizer's initial lr is
# overridden), and the schedule must cover every batch of every epoch
sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, steps_per_epoch=len(dls.train), epochs=epochs)
for epoch in range(epochs):
    for train in (True, False):
        accuracy = 0
        dl = dls.train if train else dls.valid
        model.train(train)  # toggle batchnorm/dropout between train and eval behaviour
        for xb, yb in dl:
            with torch.set_grad_enabled(train):  # no autograd bookkeeping during eval
                preds = model(xb)
                loss = F.cross_entropy(preds, yb)
            if train:
                loss.backward()
                opt.step()
                opt.zero_grad()
                sched.step()  # OneCycleLR expects one step per optimizer update
            with torch.no_grad():
                accuracy += (preds.argmax(1) == yb).float().mean()
        accuracy /= len(dl)
        print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
def cnn_classifier():
ks,stride = 3,2
return nn.Sequential(
nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),
nn.BatchNorm2d(8),
nn.ReLU(),
nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),
nn.Flatten(),
)
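
# quick sanity check (illustrative): the CNN should also map a dummy batch to (2, 10) logits
assert cnn_classifier()(torch.zeros(2, 1, 28, 28)).shape == (2, 10)

# Kaiming-normal init matches the ReLU activations used after each conv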
def kaiming_init(m):
if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
nn.init.kaiming_normal_(m.weight)
model = cnn_classifier()
model.apply(kaiming_init)
lr = 0.1
max_lr = 0.3
epochs = 5
opt = optim.AdamW(model.parameters(), lr=lr)
sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, steps_per_epoch=len(dls.train), epochs=epochs)
for epoch in range(epochs):
    for train in (True, False):
        accuracy = 0
        dl = dls.train if train else dls.valid
        model.train(train)
        for xb, yb in dl:
            with torch.set_grad_enabled(train):
                preds = model(xb)
                loss = F.cross_entropy(preds, yb)
            if train:
                loss.backward()
                opt.step()
                opt.zero_grad()
                sched.step()
            with torch.no_grad():
                accuracy += (preds.argmax(1) == yb).float().mean()
        accuracy /= len(dl)
        print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")