# Spaces: Runtime error
import json
from dataclasses import asdict, dataclass, field, fields, make_dataclass
from typing import Any, List


@dataclass
class Config:
    """Default settings, used as the lowest-priority layer by ``load_config``.

    Every field has a default, so ``Config()`` can be built with no arguments.
    """

    device: str = "cpu"
    # paths
    config: str = "config/default.json"  # path to a JSON file of overrides
    loader: str = "loaders/google_sc.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # default_factory gives each instance its own list rather than one shared list.
    validation_datasets: List = field(default_factory=list)
    # training settings/hyperparams
    batch_size: int = 4
    verbose: bool = True
    # pretrained models
    encoder_model_id: str = "distilroberta-base"
    # reward settings
    rewards: tuple = (
        "FluencyReward",
        "CrossSimilarityReward",
    )


def load_config(args):
    """
    Load settings into a dataclass object, merging these sources in order:

    - the defaults declared on ``Config``
    - ``args.config`` (path to a JSON file whose top level is an object),
      read only when that value is truthy
    - ``args`` itself (e.g. the namespace returned by argparse)

    Later sources overwrite earlier ones on overlapping keys, so every
    attribute of ``args`` wins, including attributes whose value is ``None``.

    Example usage:
        (...)
        args = load_config(parser.parse_args())
        args.batch_size

    Returns:
        An instance of a dataclass named ``Config``, generated at call time
        with one field per merged key.

    Raises:
        TypeError: if the JSON file's top level is not an object, or (from
            ``make_dataclass``) if a merged key is not a valid, non-keyword
            Python identifier.
        OSError, json.JSONDecodeError: propagated from reading the JSON file.
    """
    defaults = Config()
    config = asdict(defaults)
    # Reuse the annotations declared on Config for the generated class.
    declared_types = {f.name: f.type for f in fields(defaults)}

    if args.config:
        with open(args.config, encoding="utf-8") as f:
            overrides = json.load(f)
        if not isinstance(overrides, dict):
            raise TypeError(
                f"config file {args.config!r} must contain a JSON object, "
                f"got {type(overrides).__name__}"
            )
        config.update(overrides)

    config.update(vars(args))

    # make_dataclass expects (name, type) pairs, not (name, value) pairs.
    # Keys not declared on Config (extra JSON keys or CLI options) get Any.
    spec = [(name, declared_types.get(name, Any)) for name in config]
    Config_ = make_dataclass("Config", fields=spec)
    return Config_(**config)