Abdullah-Nazhat
commited on
Update train.py
Browse files
train.py
CHANGED
@@ -162,7 +162,7 @@ def test(dataloader, model, loss_fn):
|
|
162 |
|
163 |
# apply train and test
|
164 |
|
165 |
-
logname = "/
|
166 |
if not os.path.exists(logname):
|
167 |
with open(logname, 'w') as logfile:
|
168 |
logwriter = csv.writer(logfile, delimiter=',')
|
@@ -186,7 +186,7 @@ print("Done!")
|
|
186 |
|
187 |
# saving trained model
|
188 |
|
189 |
-
path = "/
|
190 |
model_name = "AverageformerImageClassification_cifar10"
|
191 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
192 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|
|
|
162 |
|
163 |
# apply train and test
|
164 |
|
165 |
+
logname = "/PATH/Averageformer/Experiments_cifar10/logs_averageformer/logs_cifar10.csv"
|
166 |
if not os.path.exists(logname):
|
167 |
with open(logname, 'w') as logfile:
|
168 |
logwriter = csv.writer(logfile, delimiter=',')
|
|
|
186 |
|
187 |
# saving trained model
|
188 |
|
189 |
+
path = "/PATH/Averageformer/Experiments_cifar10/weights_averageformer"
|
190 |
model_name = "AverageformerImageClassification_cifar10"
|
191 |
torch.save(model.state_dict(), f"{path}/{model_name}.pth")
|
192 |
print(f"Saved Model State to {path}/{model_name}.pth ")
|