JerryLiJinyi's picture
Upload 127 files
10b912d verified
raw
history blame
1.32 kB
import json
from dataclasses import dataclass, make_dataclass, asdict, field
from typing import List
@dataclass
class Config:
    """Default settings for a run.

    `load_config` starts from an instance of this class and overlays
    values from a JSON config file and from parsed command-line
    arguments, so every field here needs a usable default.
    """

    device: str = "cpu"

    # paths
    config: str = "config/default.json"
    loader: str = "loaders/google_sc.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # A fresh list per instance; a shared mutable default is rejected by
    # dataclasses, hence the factory.
    validation_datasets: List = field(default_factory=list)

    # training settings/hyperparams
    batch_size: int = 4
    verbose: bool = True

    # pretrained models
    encoder_model_id: str = "distilroberta-base"

    # reward settings
    # presumably names of reward classes resolved elsewhere in the
    # project -- TODO confirm against the code that consumes them
    rewards: tuple = (
        "FluencyReward",
        "CrossSimilarityReward",
    )
def load_config(args):
    """
    Loads settings into a dataclass object, from the following sources:
    - defaults defined above by Config
    - args.config (path to a JSON config file), when it is set
    - args (from using argparse in a script)
    Overlapping fields are overwritten in that order.

    Every attribute stored on `args` is applied last, so an attribute
    whose value is None still replaces the default / JSON value.

    Args:
        args: object with a `config` attribute and an instance
            `__dict__` of settings (e.g. an argparse namespace).

    Returns:
        An instance of a dataclass built on the fly (named "Config")
        with one field per merged setting, including keys that only
        appear in the JSON file or in `args`.

    Example usage:
        (...)
        args = load_config(parser.parse_args())
        args.batch_size
    """
    config = asdict(Config())
    if args.config:
        # JSON is UTF-8 by specification; do not depend on the locale's
        # default encoding.
        with open(args.config, encoding="utf-8") as f:
            config.update(json.load(f))
    config.update(vars(args))
    # make_dataclass expects (name, type) pairs. Passing the raw
    # (name, value) items would use each setting's value as its
    # annotation, and string values would be treated as string
    # annotations.
    fields = [(name, type(value)) for name, value in config.items()]
    Config_ = make_dataclass("Config", fields=fields)
    return Config_(**config)