7900c16
1
2
3
4
5
6
7
8
9
10
11
12
import torch def save_model(model, model_path): """ Save model weights to file. """ if hasattr(model, "module"): torch.save(model.module.state_dict(), model_path) else: torch.save(model.state_dict(), model_path)