File size: 619 Bytes
9ae46f4 bacf856 9ae46f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
from model import GLiNER
def save_model(current_model, path):
    """Write a model checkpoint (weights plus config) to ``path``.

    The checkpoint is a single dict handed to ``torch.save`` with two keys:

    * ``"model_weights"`` -- the result of ``current_model.state_dict()``
    * ``"config"`` -- the object stored in ``current_model.config``

    Args:
        current_model: Object exposing ``state_dict()`` and a ``config``
            attribute.
        path: Destination passed straight through to ``torch.save``.
    """
    model_config = current_model.config
    torch.save(
        {"model_weights": current_model.state_dict(), "config": model_config},
        path,
    )
def load_model(path, model_name=None, device=None):
    """Rebuild a ``GLiNER`` model from a checkpoint dict stored at ``path``.

    The checkpoint is expected to hold a ``"config"`` entry and a
    ``"model_weights"`` entry. Tensors are always mapped to CPU while
    loading; the model is moved afterwards only if ``device`` is given.

    Args:
        path: Checkpoint location passed to ``torch.load``.
        model_name: When not ``None``, overwrites ``config.model_name`` on
            the loaded config before the model is constructed.
        device: When not ``None``, the model is returned after
            ``.to(device)``; otherwise it is returned as constructed.

    Returns:
        The ``GLiNER`` instance with the stored weights loaded.
    """
    # NOTE(review): torch.load unpickles the file, and the stored config is an
    # arbitrary object -- only load checkpoints from trusted sources. Newer
    # torch releases default to weights_only=True, which may reject a
    # non-tensor config object; confirm against the torch version in use.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))

    model_config = checkpoint["config"]
    if model_name is not None:
        model_config.model_name = model_name

    restored = GLiNER(model_config)
    restored.load_state_dict(checkpoint["model_weights"])

    if device is None:
        return restored
    return restored.to(device)
|