Abdullah-Nazhat commited on
Commit
2897170
·
verified ·
1 Parent(s): 6a18f40

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -5
train.py CHANGED
@@ -166,7 +166,7 @@ def test(dataloader, model, loss_fn):
166
 
167
  # apply train and test
168
 
169
- logname = "/home/abdullah/Desktop/Proposal_experiments/Linearizer/Experiments_cifar10/logs_linearizer/logs_cifar10.csv"
170
  if not os.path.exists(logname):
171
  with open(logname, 'w') as logfile:
172
  logwriter = csv.writer(logfile, delimiter=',')
@@ -178,9 +178,7 @@ epochs = 100
178
  for epoch in range(epochs):
179
  print(f"Epoch {epoch+1}\n-----------------------------------")
180
  train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
181
- # learning rate scheduler
182
- #if scheduler is not None:
183
- # scheduler.step()
184
  test_loss, test_acc = test(test_dataloader, model, loss_fn)
185
  with open(logname, 'a') as logfile:
186
  logwriter = csv.writer(logfile, delimiter=',')
@@ -190,7 +188,7 @@ print("Done!")
190
 
191
  # saving trained model
192
 
193
- path = "/home/abdullah/Desktop/Proposal_experiments/Linearizer/Experiments_cifar10/weights_linearizer"
194
  model_name = "linearizerImageClassification_cifar10"
195
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
196
  print(f"Saved Model State to {path}/{model_name}.pth ")
 
166
 
167
  # apply train and test
168
 
169
+ logname = "/PATH/Experiments_cifar10/logs_linearizer/logs_cifar10.csv"
170
  if not os.path.exists(logname):
171
  with open(logname, 'w') as logfile:
172
  logwriter = csv.writer(logfile, delimiter=',')
 
178
  for epoch in range(epochs):
179
  print(f"Epoch {epoch+1}\n-----------------------------------")
180
  train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
181
+
 
 
182
  test_loss, test_acc = test(test_dataloader, model, loss_fn)
183
  with open(logname, 'a') as logfile:
184
  logwriter = csv.writer(logfile, delimiter=',')
 
188
 
189
  # saving trained model
190
 
191
+ path = "/PATH/Linearizer/Experiments_cifar10/weights_linearizer"
192
  model_name = "linearizerImageClassification_cifar10"
193
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
194
  print(f"Saved Model State to {path}/{model_name}.pth ")