vshirasuna
commited on
Commit
•
f6401dc
1
Parent(s):
8abbc76
Restore best_vloss in finetune
Browse files- smi-ted/finetune/trainers.py +10 -8
smi-ted/finetune/trainers.py
CHANGED
@@ -47,6 +47,8 @@ class Trainer:
|
|
47 |
self.save_every_epoch = save_every_epoch
|
48 |
self.save_ckpt = save_ckpt
|
49 |
self.device = device
|
|
|
|
|
50 |
self._set_seed(seed)
|
51 |
|
52 |
def _prepare_data(self):
|
@@ -89,8 +91,6 @@ class Trainer:
|
|
89 |
print('Checkpoint restored!')
|
90 |
|
91 |
def fit(self, max_epochs=500):
|
92 |
-
best_vloss = float('inf')
|
93 |
-
|
94 |
for epoch in range(self.start_epoch, max_epochs+1):
|
95 |
print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
|
96 |
|
@@ -106,22 +106,22 @@ class Trainer:
|
|
106 |
print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
|
107 |
|
108 |
############################### Save Finetune checkpoint #######################################
|
109 |
-
if ((val_loss < best_vloss) or self.save_every_epoch) and self.save_ckpt:
|
110 |
# remove old checkpoint
|
111 |
-
if
|
112 |
os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
|
113 |
|
114 |
# filename
|
115 |
model_name = f'{str(self.model)}-Finetune'
|
116 |
-
self.last_filename = f"{model_name}
|
|
|
|
|
|
|
117 |
|
118 |
# save checkpoint
|
119 |
print('Saving checkpoint...')
|
120 |
self._save_checkpoint(epoch, self.last_filename)
|
121 |
|
122 |
-
# update best loss
|
123 |
-
best_vloss = val_loss
|
124 |
-
|
125 |
def evaluate(self, verbose=True):
|
126 |
if verbose:
|
127 |
print("\n=====Test Evaluation=====")
|
@@ -189,6 +189,7 @@ class Trainer:
|
|
189 |
ckpt_dict = torch.load(ckpt_path, map_location='cpu')
|
190 |
self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
|
191 |
self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
|
|
|
192 |
|
193 |
def _save_checkpoint(self, current_epoch, filename):
|
194 |
if not os.path.exists(self.checkpoints_folder):
|
@@ -209,6 +210,7 @@ class Trainer:
|
|
209 |
'train_size': self.df_train.shape[0],
|
210 |
'valid_size': self.df_valid.shape[0],
|
211 |
'test_size': self.df_test.shape[0],
|
|
|
212 |
},
|
213 |
'seed': self.seed,
|
214 |
}
|
|
|
47 |
self.save_every_epoch = save_every_epoch
|
48 |
self.save_ckpt = save_ckpt
|
49 |
self.device = device
|
50 |
+
self.best_vloss = float('inf')
|
51 |
+
self.last_filename = None
|
52 |
self._set_seed(seed)
|
53 |
|
54 |
def _prepare_data(self):
|
|
|
91 |
print('Checkpoint restored!')
|
92 |
|
93 |
def fit(self, max_epochs=500):
|
|
|
|
|
94 |
for epoch in range(self.start_epoch, max_epochs+1):
|
95 |
print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
|
96 |
|
|
|
106 |
print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
|
107 |
|
108 |
############################### Save Finetune checkpoint #######################################
|
109 |
+
if ((val_loss < self.best_vloss) or self.save_every_epoch) and self.save_ckpt:
|
110 |
# remove old checkpoint
|
111 |
+
if (self.last_filename != None) and (not self.save_every_epoch):
|
112 |
os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
|
113 |
|
114 |
# filename
|
115 |
model_name = f'{str(self.model)}-Finetune'
|
116 |
+
self.last_filename = f"{model_name}_seed{self.seed}_{self.dataset_name}_epoch={epoch}_valloss={round(val_loss, 4)}.pt"
|
117 |
+
|
118 |
+
# update best loss
|
119 |
+
self.best_vloss = val_loss
|
120 |
|
121 |
# save checkpoint
|
122 |
print('Saving checkpoint...')
|
123 |
self._save_checkpoint(epoch, self.last_filename)
|
124 |
|
|
|
|
|
|
|
125 |
def evaluate(self, verbose=True):
|
126 |
if verbose:
|
127 |
print("\n=====Test Evaluation=====")
|
|
|
189 |
ckpt_dict = torch.load(ckpt_path, map_location='cpu')
|
190 |
self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
|
191 |
self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
|
192 |
+
self.best_vloss = ckpt_dict['finetune_info']['best_vloss']
|
193 |
|
194 |
def _save_checkpoint(self, current_epoch, filename):
|
195 |
if not os.path.exists(self.checkpoints_folder):
|
|
|
210 |
'train_size': self.df_train.shape[0],
|
211 |
'valid_size': self.df_valid.shape[0],
|
212 |
'test_size': self.df_test.shape[0],
|
213 |
+
'best_vloss': self.best_vloss,
|
214 |
},
|
215 |
'seed': self.seed,
|
216 |
}
|