import json
from dataclasses import dataclass, make_dataclass, asdict, field
from typing import List


@dataclass
class Config:
    """Default settings; load_config() overlays a JSON file and CLI args on these."""

    device: str = "cpu"

    # --- paths ---
    config: str = "config/default.json"
    loader: str = "loaders/google_sc.py"
    dataset: str = ""
    indices: str = ""
    model_dir: str = "default_model_dir"
    # Built per instance so the list is never shared between Config objects.
    validation_datasets: List = field(default_factory=list)

    # --- training settings / hyperparameters ---
    batch_size: int = 4
    verbose: bool = True

    # --- pretrained models ---
    encoder_model_id: str = "distilroberta-base"

    # --- reward settings ---
    rewards: tuple = ("FluencyReward", "CrossSimilarityReward")


def load_config(args):
    """
    Loads settings into a dataclass object, from the following sources:
    - defaults defined above by Config
    - args.config (path to a JSON config file)
    - args (from using argparse in a script)

    Overlapping fields are overwritten in that order, i.e. values in args
    win over the JSON file, which wins over the Config defaults.

    The returned object is an instance of a dataclass generated on the fly,
    so keys that only appear in the JSON file or in args (and are not
    declared on Config) are available as attributes too.

    Args:
        args: an object with a ``config`` attribute and a ``__dict__``
            (e.g. an argparse.Namespace). If ``args.config`` is falsy,
            no JSON file is read.

    Returns:
        An instance of a dynamically created dataclass named "Config" with
        one field per merged setting.

    Raises:
        OSError: if args.config is set but the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        TypeError: if a merged key is not a valid Python identifier
            (raised by dataclasses.make_dataclass).

    Example usage:
    (...)
    args = load_config(parser.parse_args())
    args.batch_size
    """
    settings = asdict(Config())
    if args.config:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(args.config, encoding="utf-8") as f:
            settings.update(json.load(f))
    settings.update(vars(args))

    # make_dataclass expects (name, type) pairs. Passing (name, value) would
    # store each value as the field's annotation, so derive the type instead.
    field_specs = [(name, type(value)) for name, value in settings.items()]
    Config_ = make_dataclass("Config", fields=field_specs)
    return Config_(**settings)