szukevin's picture
upload
7900c16
raw
history blame
250 Bytes
import torch
def save_model(model, model_path):
    """Serialize a model's ``state_dict`` to ``model_path`` with ``torch.save``.

    When ``model`` has a ``module`` attribute (the way wrappers such as
    ``torch.nn.DataParallel`` expose the wrapped network), the state dict of
    that inner object is written instead of the wrapper's own.

    NOTE(review): the check is purely attribute-based, so any model that
    happens to own an attribute named ``module`` is unwrapped too — confirm
    that no plain model passed here uses that attribute name.

    Args:
        model: The model (or wrapper) whose weights should be saved.
        model_path: Destination passed straight through to ``torch.save``.
    """
    # Fall back to the model itself when there is no wrapped ``module``.
    unwrapped = getattr(model, "module", model)
    torch.save(unwrapped.state_dict(), model_path)