Spaces:
Runtime error
Runtime error
File size: 1,316 Bytes
10b912d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import json
from dataclasses import dataclass, make_dataclass, asdict, field
from typing import List
@dataclass
class Config:
    """Default settings; the base layer that ``load_config`` starts from.

    Every field has a default, so ``Config()`` can be built with no
    arguments. Values from a JSON config file and from command-line
    arguments are layered on top of these defaults by ``load_config``.
    """

    device: str = "cpu"

    # paths
    config: str = "config/default.json"
    loader: str = "loaders/google_sc.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # default_factory gives every instance its own fresh list
    validation_datasets: List = field(default_factory=list)

    # training settings/hyperparams
    batch_size: int = 4
    verbose: bool = True

    # pretrained models
    encoder_model_id: str = "distilroberta-base"

    # reward settings
    rewards: tuple = (
        "FluencyReward",
        "CrossSimilarityReward",
    )
def load_config(args):
    """
    Loads settings into a dataclass object, from the following sources:
    - defaults defined above by Config
    - args.config (path to a JSON config file), if that attribute is set
      and non-empty
    - args (from using argparse in a script)
    Overlapping fields are overwritten in that order.

    Every attribute of ``args`` is copied in the last step, including
    attributes whose value is ``None``, so such a value replaces whatever
    the defaults or the JSON file provided for the same key.

    Returns an instance of a dataclass generated on the fly (also named
    "Config") with one field per merged key. It is not an instance of the
    module-level ``Config`` class, and it may carry extra fields that only
    appear in the JSON file or in ``args``.

    Example usage:
        (...)
        args = load_config(parser.parse_args())
        args.batch_size
    """
    config = asdict(Config())

    # Tolerate namespaces that do not define a `config` attribute at all.
    config_path = getattr(args, "config", None)
    if config_path:
        with open(config_path, encoding="utf-8") as f:
            config.update(json.load(f))

    config.update(vars(args))

    # make_dataclass expects (name, type) pairs; passing config.items()
    # directly would register each *value* as the field's type.
    Config_ = make_dataclass(
        "Config",
        fields=[(name, type(value)) for name, value in config.items()],
    )
    return Config_(**config)
|