vshirasuna committed
Commit f6401dc
1 Parent(s): 8abbc76

Restore best_vloss in finetune
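Previously, fit() reset best_vloss to infinity on every call, so a run resumed from a checkpoint compared against a fresh sentinel instead of the best validation loss already achieved. best_vloss is now an instance attribute: it is written into the checkpoint's finetune_info by _save_checkpoint and read back in _load_checkpoint, and stale-checkpoint removal is keyed off self.last_filename rather than the loss sentinel.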

Files changed (1):
  smi-ted/finetune/trainers.py  +10 -8
smi-ted/finetune/trainers.py CHANGED

@@ -47,6 +47,8 @@ class Trainer:
         self.save_every_epoch = save_every_epoch
         self.save_ckpt = save_ckpt
         self.device = device
+        self.best_vloss = float('inf')
+        self.last_filename = None
         self._set_seed(seed)
 
     def _prepare_data(self):
@@ -89,8 +91,6 @@ class Trainer:
         print('Checkpoint restored!')
 
     def fit(self, max_epochs=500):
-        best_vloss = float('inf')
-
         for epoch in range(self.start_epoch, max_epochs+1):
             print(f'\n=====Epoch [{epoch}/{max_epochs}]=====')
 
@@ -106,22 +106,22 @@
                 print(f"[VALID] Evaluation {m.upper()}: {round(val_metrics[m], 4)}")
 
             ############################### Save Finetune checkpoint #######################################
-            if ((val_loss < best_vloss) or self.save_every_epoch) and self.save_ckpt:
+            if ((val_loss < self.best_vloss) or self.save_every_epoch) and self.save_ckpt:
                 # remove old checkpoint
-                if best_vloss != float('inf') and not self.save_every_epoch:
+                if (self.last_filename != None) and (not self.save_every_epoch):
                     os.remove(os.path.join(self.checkpoints_folder, self.last_filename))
 
                 # filename
                 model_name = f'{str(self.model)}-Finetune'
-                self.last_filename = f"{model_name}_epoch={epoch}_{self.dataset_name}_seed{self.seed}_valloss={round(val_loss, 4)}.pt"
+                self.last_filename = f"{model_name}_seed{self.seed}_{self.dataset_name}_epoch={epoch}_valloss={round(val_loss, 4)}.pt"
+
+                # update best loss
+                self.best_vloss = val_loss
 
                 # save checkpoint
                 print('Saving checkpoint...')
                 self._save_checkpoint(epoch, self.last_filename)
 
-                # update best loss
-                best_vloss = val_loss
-
     def evaluate(self, verbose=True):
         if verbose:
             print("\n=====Test Evaluation=====")
@@ -189,6 +189,7 @@ class Trainer:
         ckpt_dict = torch.load(ckpt_path, map_location='cpu')
         self.model.load_state_dict(ckpt_dict['MODEL_STATE'])
         self.start_epoch = ckpt_dict['EPOCHS_RUN'] + 1
+        self.best_vloss = ckpt_dict['finetune_info']['best_vloss']
 
     def _save_checkpoint(self, current_epoch, filename):
         if not os.path.exists(self.checkpoints_folder):
@@ -209,6 +210,7 @@
                 'train_size': self.df_train.shape[0],
                 'valid_size': self.df_valid.shape[0],
                 'test_size': self.df_test.shape[0],
+                'best_vloss': self.best_vloss,
             },
             'seed': self.seed,
         }
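For reference, a minimal sketch of the checkpoint round-trip this commit enables. The key names (MODEL_STATE, EPOCHS_RUN, finetune_info, best_vloss) come from the diff; the file name and values are illustrative only.

```python
import torch

# Illustrative checkpoint dict; key names follow the diff, values are made up.
ckpt = {
    'MODEL_STATE': {},                        # stand-in for model.state_dict()
    'EPOCHS_RUN': 12,
    'finetune_info': {'best_vloss': 0.0813},  # now persisted by _save_checkpoint
    'seed': 0,
}
torch.save(ckpt, 'example_ckpt.pt')

# What _load_checkpoint now restores: the epoch counter AND the best loss,
# so (val_loss < self.best_vloss) compares against the pre-resume best.
restored = torch.load('example_ckpt.pt', map_location='cpu')
start_epoch = restored['EPOCHS_RUN'] + 1              # -> 13
best_vloss = restored['finetune_info']['best_vloss']  # -> 0.0813
print(start_epoch, best_vloss)
```

Because best_vloss survives the reload, a resumed run only writes a new checkpoint when the validation loss actually improves on the pre-interruption best, instead of treating the first post-resume epoch as a new best by default.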