# Spaces:
# Runtime error
# Runtime error
import json
from dataclasses import asdict, dataclass, field, make_dataclass
from typing import List, Optional
@dataclass
class Config:
    """Default settings for a run.

    ``load_config`` starts from ``asdict(Config())`` and then overrides these
    values, first with a JSON config file and then with command-line arguments.
    The ``@dataclass`` decorator is required for ``asdict`` to work.
    """

    # paths
    config: str = "config/default.json"
    loader: str = "loaders/newsroom.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # default_factory gives every instance its own list instead of a shared one.
    validation_datasets: List = field(default_factory=list)

    # training settings/hyperparams
    batch_size: int = 4
    learning_rate: float = 0.00001
    k_samples: int = 1
    # validate_config() accepts only "max" or "mean" here.
    sample_aggregation: str = "max"
    # NOTE(review): None presumably means "no limit" — confirm with the training loop.
    max_val_steps: Optional[int] = None
    max_train_steps: Optional[int] = None
    max_train_seconds: Optional[int] = None
    print_every: int = 10
    save_every: int = 100
    eval_every: int = 100
    verbose: bool = True

    # pretrained models
    encoder_model_id: str = "distilroberta-base"

    # reward settings
    # presumably names of reward classes to instantiate — TODO confirm against caller.
    rewards: tuple = (
        "FluencyReward",
        "BiEncoderSimilarity",
        "GaussianLength",
    )
def validate_config(args):
    """Check that a loaded configuration holds supported values.

    Args:
        args: Any object exposing a ``sample_aggregation`` attribute, such as
            the dataclass instance built by ``load_config``.

    Raises:
        ValueError: if ``args.sample_aggregation`` is not "max" or "mean".
    """
    allowed = ("max", "mean")
    # An explicit raise (unlike ``assert``) is not stripped when Python runs with -O.
    if args.sample_aggregation not in allowed:
        raise ValueError(
            f"sample_aggregation must be one of {allowed}, "
            f"got {args.sample_aggregation!r}"
        )
def load_config(args):
    """
    Loads settings into a dataclass object, from the following sources:
    - defaults defined by the Config dataclass
    - args.config (path to a JSON config file), if present and non-empty
    - args (from using argparse in a script)
    Overlapping fields are overwritten in that order.

    NOTE: every attribute of ``args`` is copied, including ones argparse
    filled with its own defaults (e.g. None), so such values replace the
    corresponding entries from the JSON file.

    Example usage:
        (...)
        args = load_config(parser.parse_args())
        args.batch_size

    Returns:
        An instance of a freshly built dataclass named "Config" whose fields
        are exactly the merged keys, after ``validate_config`` has checked it.

    Raises:
        TypeError: if a merged key is not a valid Python identifier
            (raised by ``make_dataclass``).
    """
    config = asdict(Config())
    # getattr keeps this usable with namespaces that have no ``config`` option.
    config_path = getattr(args, "config", None)
    if config_path:
        with open(config_path, encoding="utf-8") as f:
            config.update(json.load(f))
    config.update(vars(args))
    # Pass field names only: a bare name is annotated as typing.Any, whereas
    # (name, value) pairs would wrongly use each value as the field's type.
    Config_ = make_dataclass("Config", fields=list(config))
    config_object = Config_(**config)
    validate_config(config_object)
    return config_object