File size: 572 Bytes
3ef28b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import json
import torch
import transformers
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file
def load_config_hf(model_name):
    """Load a model's configuration file and return it as parsed JSON.

    The file location is resolved by calling ``cached_file`` with
    ``model_name`` and ``CONFIG_NAME``.

    Args:
        model_name: Model identifier, passed unchanged to ``cached_file``.

    Returns:
        Whatever ``json.load`` produces for the resolved config file
        (presumably a dict for a standard model config).
    """
    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
    # NOTE(review): with _raise_exceptions_for_missing_entries=False,
    # cached_file presumably returns None instead of raising when the file is
    # missing; open(None) then fails with TypeError -- confirm against the
    # installed transformers version.
    # Use a context manager so the handle is closed deterministically, and
    # read as UTF-8 explicitly rather than relying on the locale default.
    with open(resolved_archive_file, encoding="utf-8") as config_file:
        return json.load(config_file)
def load_state_dict_hf(model_name, device="cpu"):
    """Load a model's saved weights onto the requested device.

    The weights file location is resolved by calling ``cached_file`` with
    ``model_name`` and ``WEIGHTS_NAME``; the file is then read with
    ``torch.load``.

    Args:
        model_name: Model identifier, passed unchanged to ``cached_file``.
        device: Forwarded to ``torch.load`` as ``map_location``.
            Defaults to ``"cpu"``.

    Returns:
        The object returned by ``torch.load`` for the resolved weights file.
    """
    weights_path = cached_file(
        model_name,
        WEIGHTS_NAME,
        _raise_exceptions_for_missing_entries=False,
    )
    # NOTE(review): depending on the torch version and its weights_only
    # default, torch.load may unpickle arbitrary objects -- only use this
    # with checkpoints from trusted sources.
    state_dict = torch.load(weights_path, map_location=device)
    return state_dict
|