File size: 1,737 Bytes
10b912d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import json
from dataclasses import dataclass, make_dataclass, asdict, field
from typing import List, Optional


@dataclass
class Config:
    """
    Default values for every setting known to this module.

    load_config() starts from an instance of this class and lets a JSON
    config file and command-line arguments override individual fields.
    """

    # paths
    config: str = "config/default.json"  # JSON config file read by load_config
    loader: str = "loaders/newsroom.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # default_factory gives every instance its own fresh list
    validation_datasets: List = field(default_factory=list)

    # training settings/hyperparams
    batch_size: int = 4
    learning_rate: float = 1e-5
    k_samples: int = 1
    sample_aggregation: str = "max"  # validate_config accepts "max" or "mean"
    # The three limits below default to None
    # (presumably "no limit" -- confirm against the training loop).
    max_val_steps: Optional[int] = None
    max_train_steps: Optional[int] = None
    max_train_seconds: Optional[int] = None
    print_every: int = 10
    save_every: int = 100
    eval_every: int = 100
    verbose: bool = True

    # pretrained models
    encoder_model_id: str = "distilroberta-base"

    # reward settings
    rewards: tuple = (
        "FluencyReward",
        "BiEncoderSimilarity",
        "GaussianLength",
    )


def validate_config(args):
    """
    Checks that the given settings object holds supported values.

    Only `sample_aggregation` is checked: it must be "max" or "mean".

    Raises ValueError for any other value. An explicit raise is used instead
    of `assert` so the check is not stripped when Python runs with -O.
    """
    allowed = ("max", "mean")
    if args.sample_aggregation not in allowed:
        raise ValueError(
            f"sample_aggregation must be one of {allowed}, "
            f"got {args.sample_aggregation!r}"
        )


def load_config(args):
    """
    Loads settings into a dataclass object, from the following sources:
    - defaults defined above by Config
    - args.config (path to a JSON config file), if set
    - args (from using argparse in a script)

    Overlapping fields are overwritten in that order.

    Keys that are not fields of Config are kept too: the returned object is
    an instance of a dataclass built on the fly from the merged settings.
    The merged settings are passed through validate_config before returning.

    Example usage:
    (...)
    args = load_config(parser.parse_args())
    args.batch_size
    """
    settings = asdict(Config())

    config_path = getattr(args, "config", None)
    if config_path:
        # Read as UTF-8 explicitly instead of relying on the locale default.
        with open(config_path, encoding="utf-8") as f:
            settings.update(json.load(f))

    # Every attribute on args wins over defaults and the config file,
    # including attributes whose value is None.
    settings.update(vars(args))

    # make_dataclass expects (name, type) pairs, so pass each value's type
    # rather than the value itself.
    field_specs = [(name, type(value)) for name, value in settings.items()]
    Config_ = make_dataclass("Config", fields=field_specs)
    config_object = Config_(**settings)
    validate_config(config_object)
    return config_object