Spaces:
Runtime error
Runtime error
File size: 1,737 Bytes
10b912d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import json
from dataclasses import asdict, dataclass, field, fields, make_dataclass
from typing import Any, List, Optional
@dataclass
class Config:
    """Default settings for a run.

    ``load_config`` starts from ``asdict(Config())`` and then overlays a JSON
    config file and command-line arguments on top of these values.
    """

    # paths
    config: str = "config/default.json"
    loader: str = "loaders/newsroom.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # default_factory gives every instance its own fresh, empty list.
    validation_datasets: List = field(default_factory=list)
    # training settings/hyperparams
    batch_size: int = 4
    learning_rate: float = 0.00001
    k_samples: int = 1
    # validate_config only accepts "max" or "mean" here.
    sample_aggregation: str = "max"
    # These default to None, hence Optional.
    # NOTE(review): None presumably means "no limit" -- confirm where these
    # are consumed.
    max_val_steps: Optional[int] = None
    max_train_steps: Optional[int] = None
    max_train_seconds: Optional[int] = None
    print_every: int = 10
    save_every: int = 100
    eval_every: int = 100
    verbose: bool = True
    # pretrained models
    encoder_model_id: str = "distilroberta-base"
    # reward settings
    rewards: tuple = (
        "FluencyReward",
        "BiEncoderSimilarity",
        "GaussianLength",
    )
class InvalidConfigError(ValueError, AssertionError):
    """Raised when a loaded configuration violates a constraint.

    It is a ``ValueError``, and also an ``AssertionError`` so that existing
    ``except AssertionError`` handlers still catch it.
    """


def validate_config(args):
    """Check constraints on a loaded config object.

    Args:
        args: Object exposing a ``sample_aggregation`` attribute (normally the
            dataclass instance built by ``load_config``).

    Raises:
        InvalidConfigError: if ``sample_aggregation`` is not "max" or "mean".
    """
    # An explicit raise (unlike ``assert``) still runs under ``python -O``.
    if args.sample_aggregation not in ("max", "mean"):
        raise InvalidConfigError(
            "sample_aggregation must be 'max' or 'mean', "
            f"got {args.sample_aggregation!r}"
        )
def load_config(args):
    """
    Load settings into a dataclass object, layering these sources in order:

    - the defaults declared on the ``Config`` dataclass
    - ``args.config`` (path to a JSON config file), when it is truthy
    - ``args`` itself (e.g. the namespace returned by argparse)

    Later sources overwrite earlier ones on overlapping keys. Every attribute
    of ``args`` wins, including argparse defaults such as ``None``. Keys that
    ``Config`` does not declare are kept as additional fields.

    Example usage:
        (...)
        args = load_config(parser.parse_args())
        args.batch_size

    Returns:
        An instance of a dataclass (named "Config") built dynamically with one
        field per merged key.

    Raises:
        OSError / json.JSONDecodeError: if the JSON config file cannot be read
            or parsed.
        Whatever ``validate_config`` raises if the merged settings are invalid.
    """
    settings = asdict(Config())

    if args.config:
        # JSON text is UTF-8 (RFC 8259), so don't depend on the platform locale.
        with open(args.config, encoding="utf-8") as fh:
            settings.update(json.load(fh))

    settings.update(vars(args))

    # make_dataclass expects (name, type) pairs, not (name, value). Reuse the
    # types declared on Config, and use Any for keys that only come from the
    # JSON file or from args.
    declared = {fld.name: fld.type for fld in fields(Config)}
    spec = [(name, declared.get(name, Any)) for name in settings]
    Config_ = make_dataclass("Config", spec)
    config_object = Config_(**settings)
    validate_config(config_object)
    return config_object
|