Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
# Save and load functions.
def save_checkpoint(save_path, model, valid_loss):
    """Save a model's weights together with its validation loss.

    The checkpoint is a dict with keys ``'model_state_dict'`` and
    ``'valid_loss'``, written with ``torch.save``. Nothing is written when
    ``save_path`` is None.

    Args:
        save_path: Destination path passed to ``torch.save``, or None to skip.
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        valid_loss: Validation loss stored alongside the weights.
    """
    if save_path is None:
        return
    state_dict = {
        'model_state_dict': model.state_dict(),
        'valid_loss': valid_loss,
    }
    torch.save(state_dict, save_path)
    print(f"[SAVE] Model has been saved successfully to '{save_path}'")
def load_checkpoint(load_path, model, device):
    """Restore model weights from a checkpoint and return its stored loss.

    Expects a file containing a dict with ``'model_state_dict'`` and
    ``'valid_loss'`` keys (the layout ``save_checkpoint`` writes).

    Args:
        load_path: Checkpoint path, or None to skip loading.
        model: Object exposing ``load_state_dict()``; updated in place.
        device: Passed to ``torch.load`` as ``map_location`` so tensors are
            remapped onto this device.

    Returns:
        The stored ``'valid_loss'`` value, or None when ``load_path`` is None.

    Note:
        ``torch.load`` deserializes via pickle; only load trusted files.
    """
    if load_path is None:
        return None
    state_dict = torch.load(load_path, map_location=device)
    # Apply the weights before reporting success, so the success message is
    # printed only after load_state_dict has completed without raising.
    model.load_state_dict(state_dict['model_state_dict'])
    print(f"[LOAD] Model has been loaded successfully from '{load_path}'")
    return state_dict['valid_loss']
def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Save training/validation loss histories and their step indices.

    The file is a dict with keys ``'train_loss_list'``, ``'valid_loss_list'``
    and ``'global_steps_list'``, written with ``torch.save``. Nothing is
    written when ``save_path`` is None.

    Args:
        save_path: Destination path passed to ``torch.save``, or None to skip.
        train_loss_list: Recorded training losses.
        valid_loss_list: Recorded validation losses.
        global_steps_list: Step values corresponding to the recorded losses.
    """
    if save_path is None:
        return
    state_dict = {
        'train_loss_list': train_loss_list,
        'valid_loss_list': valid_loss_list,
        'global_steps_list': global_steps_list,
    }
    torch.save(state_dict, save_path)
    print(f"[SAVE] Model with metrics has been saved successfully to '{save_path}'")
def load_metrics(load_path, device):
    """Load loss histories previously written by ``save_metrics``.

    Args:
        load_path: Metrics file path, or None to skip loading.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        A tuple ``(train_loss_list, valid_loss_list, global_steps_list)``.
        When ``load_path`` is None a bare None is returned (not a 3-tuple),
        so callers that unpack the result must check for None first.

    Note:
        ``torch.load`` deserializes via pickle; only load trusted files.
    """
    if load_path is None:
        return None
    state_dict = torch.load(load_path, map_location=device)
    print(f"[LOAD] Model with metrics has been loaded successfully from '{load_path}'")
    return (state_dict['train_loss_list'],
            state_dict['valid_loss_list'],
            state_dict['global_steps_list'])