Abdullah-Nazhat
commited on
Update train.py
Browse files
train.py
CHANGED
@@ -6,7 +6,7 @@ from torch import nn
|
|
6 |
from torch.utils.data import DataLoader
|
7 |
from torchvision import datasets
|
8 |
from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, RandomRotation, Compose
|
9 |
-
from
|
10 |
|
11 |
transform = Compose([
|
12 |
RandomCrop(32, padding=4),
|
@@ -152,7 +152,7 @@ def test(dataloader, model, loss_fn):
|
|
152 |
|
153 |
# apply train and test
|
154 |
|
155 |
-
logname = "/
|
156 |
if not os.path.exists(logname):
|
157 |
with open(logname, 'w') as logfile:
|
158 |
logwriter = csv.writer(logfile, delimiter=',')
|
@@ -172,7 +172,7 @@ for epoch in range(epochs):
|
|
172 |
print("Done!")
|
173 |
|
174 |
# saving trained model
|
175 |
-
path = "/
|
176 |
model_name = "NiNformerImageClassification_cifar10"
|
177 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
178 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|
|
|
6 |
from torch.utils.data import DataLoader
|
7 |
from torchvision import datasets
|
8 |
from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, RandomRotation, Compose
|
9 |
+
from ninformer import NiNformer
|
10 |
|
11 |
transform = Compose([
|
12 |
RandomCrop(32, padding=4),
|
|
|
152 |
|
153 |
# apply train and test
|
154 |
|
155 |
+
logname = "/PATH/NiNformer/Experiments_cifar10/logs_ninformer/logs_cifar10.csv"
|
156 |
if not os.path.exists(logname):
|
157 |
with open(logname, 'w') as logfile:
|
158 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
|
172 |
print("Done!")
|
173 |
|
174 |
# saving trained model
|
175 |
+
path = "/PATH/NiNformer/Experiments_cifar10/weights_ninformer"
|
176 |
model_name = "NiNformerImageClassification_cifar10"
|
177 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
178 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|