# Spaces status (scraped page residue): Runtime error
#!/usr/bin/env python | |
# coding: utf-8 | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from datasets import load_dataset | |
import fastcore.all as fc | |
import matplotlib.pyplot as plt | |
import matplotlib as mpl | |
import torchvision.transforms.functional as TF | |
from torch.utils.data import default_collate, DataLoader | |
import torch.optim as optim | |
import pickle | |
# `%matplotlib inline` is an IPython/Jupyter magic. Outside a notebook the
# `get_ipython` builtin does not exist and the bare call raised NameError at
# start-up, so it is applied only when running under IPython.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass  # plain Python interpreter: keep matplotlib's default backend

plt.rcParams['figure.figsize'] = [2, 2]

dataset_nm = 'mnist'
# Keys of the image and label entries in each dataset row / collated batch.
x, y = 'image', 'label'
ds = load_dataset(dataset_nm)
def transform_ds(b):
    """Convert the image column of a dataset batch to tensors, in place.

    Registered through `ds.with_transform`. Every entry of ``b[x]`` goes
    through `TF.to_tensor`, and the same (mutated) batch mapping is returned.
    NOTE(review): assumes the entries are PIL images / arrays accepted by
    `to_tensor` (true for the HF `mnist` dataset) — confirm for other sets.
    """
    b[x] = list(map(TF.to_tensor, b[x]))
    return b
# Apply `transform_ds` lazily, whenever rows are read from any split.
dst = ds.with_transform(transform_ds)

# Visual sanity check: show the first training image, CHW tensor -> HWC.
plt.imshow(dst['train'][0]['image'].permute(1, 2, 0))

bs = 1024  # training batch size; DataLoaders uses bs * 2 for validation
class DataLoaders:
    """Hold a training and a validation `DataLoader` built from two datasets.

    Args:
        train_ds: dataset wrapped (shuffled) by ``self.train``.
        valid_ds: dataset wrapped (in fixed order) by ``self.valid``.
        bs: training batch size; the validation loader uses ``bs * 2``.
        collate_fn: turns a list of samples into one batch.
        **kwargs: forwarded unchanged to both `DataLoader` constructors.
    """

    def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
        self.train = DataLoader(train_ds, batch_size=bs, shuffle=True,
                                collate_fn=collate_fn, **kwargs)
        # Fixed: this loader used to wrap `train_ds`, so every "eval" metric
        # was actually measured on the training data.
        self.valid = DataLoader(valid_ds, batch_size=bs * 2, shuffle=False,
                                collate_fn=collate_fn, **kwargs)
def collate_fn(b):
    """Collate a list of dataset rows into an ``(inputs, targets)`` tuple.

    `default_collate` merges the per-sample mappings into one mapping of
    batched values. Only the `x` and `y` entries are returned, in that order.
    """
    batch = default_collate(b)
    inputs, targets = batch[x], batch[y]
    return inputs, targets
# Training loader over the `train` split, validation loader over `test`.
dls = DataLoaders(
    dst['train'],
    dst['test'],
    bs=bs,
    collate_fn=collate_fn,
)

# Pull one batch to eyeball tensor shapes (displayed only in a notebook).
xb, yb = next(iter(dls.train))
xb.shape, yb.shape
class Reshape(nn.Module):
    """Module form of ``Tensor.reshape``, usable as a step in `nn.Sequential`.

    Args:
        dim: target shape passed verbatim to ``reshape`` (may contain -1).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        target_shape = self.dim
        return x.reshape(target_shape)
# model definition
def linear_classifier():
    """Build a 784 -> 50 -> 50 -> 10 MLP with ReLU activations.

    The leading ``Reshape((-1, 784))`` flattens each input so the first
    linear layer sees 784 features. The last layer emits 10 unnormalised
    scores, which the training loop feeds to ``F.cross_entropy``.
    """
    hidden = 50
    return nn.Sequential(
        Reshape((-1, 784)),
        nn.Linear(784, hidden), nn.ReLU(),
        nn.Linear(hidden, hidden), nn.ReLU(),
        nn.Linear(hidden, 10),
    )
model = linear_classifier()
lr = 0.1  # effectively ignored: OneCycleLR resets the start lr to max_lr / div_factor
max_lr = 0.1
epochs = 5
opt = optim.AdamW(model.parameters(), lr=lr)
# OneCycleLR needs the number of optimizer steps in the WHOLE run. The old
# `total_steps=len(dls.train)` covered one epoch only, so `sched.step()`
# raised ValueError on the first batch of epoch 2. With `epochs` plus
# `steps_per_epoch`, the scheduler spans epochs * len(dls.train) steps.
sched = optim.lr_scheduler.OneCycleLR(
    opt, max_lr, epochs=epochs, steps_per_epoch=len(dls.train)
)
for epoch in range(epochs):
    for train in (True, False):
        dl = dls.train if train else dls.valid
        # No mode-dependent layers in this MLP; kept so train/eval is explicit.
        model.train(train)
        total_loss, correct, seen = 0.0, 0, 0
        # Record autograd graphs only for the training pass.
        with torch.set_grad_enabled(train):
            for xb, yb in dl:
                preds = model(xb)
                loss = F.cross_entropy(preds, yb)
                if train:
                    loss.backward()
                    opt.step()
                    opt.zero_grad()
                    sched.step()
                # Weight by batch size so a short last batch is not over-counted.
                n = yb.shape[0]
                total_loss += loss.item() * n
                correct += (preds.argmax(1) == yb).sum().item()
                seen += n
        seen = max(seen, 1)  # avoid division by zero on an empty loader
        avg_loss = total_loss / seen
        accuracy = correct / seen
        print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {avg_loss:.4f}, accuracy: {accuracy:.4f}")
def cnn_classifier():
    """Build a six-convolution CNN mapping 1-channel images to 10 scores.

    Channel widths go 1 -> 8 -> 16 -> 32 -> 64 -> 64, each conv followed by
    BatchNorm and ReLU. A final conv goes to 10 channels, then `nn.Flatten`.
    Every conv uses kernel 3, stride 2 and padding 1, which maps a side of
    length H to ceil(H / 2). A 28x28 input therefore ends at 1x1, and the
    output is (batch, 10).
    """
    ks, stride = 3, 2

    def conv(c_in, c_out):
        return nn.Conv2d(c_in, c_out, kernel_size=ks, stride=stride, padding=ks // 2)

    widths = [1, 8, 16, 32, 64, 64]
    layers = []
    for c_in, c_out in zip(widths[:-1], widths[1:]):
        layers += [conv(c_in, c_out), nn.BatchNorm2d(c_out), nn.ReLU()]
    layers += [conv(widths[-1], 10), nn.Flatten()]
    return nn.Sequential(*layers)
def kaiming_init(m):
    """Re-initialise a conv layer's weight in place with Kaiming-normal values.

    Meant for ``model.apply(kaiming_init)``. Modules that are not
    Conv1d/Conv2d/Conv3d are skipped, and conv biases are not touched.
    Returns None.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(m, conv_types):
        return
    nn.init.kaiming_normal_(m.weight)
model = cnn_classifier()
model.apply(kaiming_init)  # Kaiming-normal weights for every conv layer
lr = 0.1  # effectively ignored: OneCycleLR resets the start lr to max_lr / div_factor
max_lr = 0.3  # NOTE(review): high peak lr for AdamW — confirm training stays stable
epochs = 5
opt = optim.AdamW(model.parameters(), lr=lr)
# OneCycleLR needs the number of optimizer steps in the WHOLE run. The old
# `total_steps=len(dls.train)` covered one epoch only, so `sched.step()`
# raised ValueError on the first batch of epoch 2. With `epochs` plus
# `steps_per_epoch`, the scheduler spans epochs * len(dls.train) steps.
sched = optim.lr_scheduler.OneCycleLR(
    opt, max_lr, epochs=epochs, steps_per_epoch=len(dls.train)
)
for epoch in range(epochs):
    for train in (True, False):
        dl = dls.train if train else dls.valid
        # BatchNorm must use batch statistics (and update its running
        # averages) only while training. In eval mode it uses the stored
        # running statistics. Previously validation ran in train mode.
        model.train(train)
        total_loss, correct, seen = 0.0, 0, 0
        # Record autograd graphs only for the training pass.
        with torch.set_grad_enabled(train):
            for xb, yb in dl:
                preds = model(xb)
                loss = F.cross_entropy(preds, yb)
                if train:
                    loss.backward()
                    opt.step()
                    opt.zero_grad()
                    sched.step()
                # Weight by batch size so a short last batch is not over-counted.
                n = yb.shape[0]
                total_loss += loss.item() * n
                correct += (preds.argmax(1) == yb).sum().item()
                seen += n
        seen = max(seen, 1)  # avoid division by zero on an empty loader
        avg_loss = total_loss / seen
        accuracy = correct / seen
        print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {avg_loss:.4f}, accuracy: {accuracy:.4f}")