Abdullah-Nazhat commited on
Commit
8f72cfc
·
verified ·
1 Parent(s): 43f5dd3

Update train_only_GEGLU.py

Browse files
Files changed (1) hide show
  1. train_only_GEGLU.py +3 -5
train_only_GEGLU.py CHANGED
@@ -162,7 +162,7 @@ def test(dataloader, model, loss_fn):
162
 
163
  # apply train and test
164
 
165
- logname = "/home/abdullah/Desktop/Proposals_experiments/Activator/Experiments_cifar10/logs_activator/logs_cifar10_only_geglu.csv"
166
  if not os.path.exists(logname):
167
  with open(logname, 'w') as logfile:
168
  logwriter = csv.writer(logfile, delimiter=',')
@@ -174,9 +174,7 @@ epochs = 100
174
  for epoch in range(epochs):
175
  print(f"Epoch {epoch+1}\n-----------------------------------")
176
  train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
177
- # learning rate scheduler
178
- #if scheduler is not None:
179
- # scheduler.step()
180
  test_loss, test_acc = test(test_dataloader, model, loss_fn)
181
  with open(logname, 'a') as logfile:
182
  logwriter = csv.writer(logfile, delimiter=',')
@@ -186,7 +184,7 @@ print("Done!")
186
 
187
  # saving trained model
188
 
189
- path = "/home/abdullah/Desktop/Proposals_experiments/Activator/Experiments_cifar10/weights_activator"
190
  model_name = "ACTIVATOR_only_GEGLUImageClassification_cifar10"
191
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
192
  print(f"Saved Model State to {path}/{model_name}.pth ")
 
162
 
163
  # apply train and test
164
 
165
+ logname = "/PATH/Activator/Experiments_cifar10/logs_activator/logs_cifar10_only_geglu.csv"
166
  if not os.path.exists(logname):
167
  with open(logname, 'w') as logfile:
168
  logwriter = csv.writer(logfile, delimiter=',')
 
174
  for epoch in range(epochs):
175
  print(f"Epoch {epoch+1}\n-----------------------------------")
176
  train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
177
+
 
 
178
  test_loss, test_acc = test(test_dataloader, model, loss_fn)
179
  with open(logname, 'a') as logfile:
180
  logwriter = csv.writer(logfile, delimiter=',')
 
184
 
185
  # saving trained model
186
 
187
+ path = "/PATH/Activator/Experiments_cifar10/weights_activator"
188
  model_name = "ACTIVATOR_only_GEGLUImageClassification_cifar10"
189
  torch.save(model.state_dict(), f"{path}/{model_name}.pth")
190
  print(f"Saved Model State to {path}/{model_name}.pth ")