Abdullah-Nazhat commited on
Commit
732c71e
·
verified ·
1 Parent(s): 77ed3dd

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -3
train.py CHANGED
@@ -6,7 +6,7 @@ from torch import nn
6
  from torch.utils.data import DataLoader
7
  from torchvision import datasets
8
  from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, RandomRotation, Compose
9
- from core import NiNformer
10
 
11
  transform = Compose([
12
  RandomCrop(32, padding=4),
@@ -152,7 +152,7 @@ def test(dataloader, model, loss_fn):
152
 
153
  # apply train and test
154
 
155
- logname = "/home/abdullah/Proposals_experiments/NiNformer/Experiments_cifar10/logs_ninformer/logs_cifar10.csv"
156
  if not os.path.exists(logname):
157
  with open(logname, 'w') as logfile:
158
  logwriter = csv.writer(logfile, delimiter=',')
@@ -172,7 +172,7 @@ for epoch in range(epochs):
172
  print("Done!")
173
 
174
  # saving trained model
175
- path = "/home/abdullah/Desktop/Proposals_experiments/NiNformer/Experiments_cifar10/weights_ninformer"
176
  model_name = "NiNformerImageClassification_cifar10"
177
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
178
  print(f"Saved Model State to {path}/{model_name}.pth ")
 
6
  from torch.utils.data import DataLoader
7
  from torchvision import datasets
8
  from torchvision.transforms import ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, RandomRotation, Compose
9
+ from ninformer import NiNformer
10
 
11
  transform = Compose([
12
  RandomCrop(32, padding=4),
 
152
 
153
  # apply train and test
154
 
155
+ logname = "/PATH/NiNformer/Experiments_cifar10/logs_ninformer/logs_cifar10.csv"
156
  if not os.path.exists(logname):
157
  with open(logname, 'w') as logfile:
158
  logwriter = csv.writer(logfile, delimiter=',')
 
172
  print("Done!")
173
 
174
  # saving trained model
175
+ path = "/PATH/NiNformer/Experiments_cifar10/weights_ninformer"
176
  model_name = "NiNformerImageClassification_cifar10"
177
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
178
  print(f"Saved Model State to {path}/{model_name}.pth ")