File size: 3,027 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
from typing import Optional

from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.model_saver.torch_model_saver import DEFAULT_CHECKPOINT_NAME


def validate_existing_directories(
    output_path: str, resume: bool, force: bool, init_path: Optional[str] = None
) -> None:
    """
    Validates that if the run_id model exists, we do not overwrite it unless --force is specified.
    Throws an exception if resume isn't specified and run_id exists. Throws an exception
    if --resume is specified and run-id was not found.
    :param output_path: The output directory of this run ID; only its existence
        as a directory is checked.
    :param resume: Whether or not the --resume flag was passed.
    :param force: Whether or not the --force flag was passed.
    :param init_path: Path to run-id dir to initialize from, or None to skip
        the initialization check.
    :raises UnityTrainerException: If output_path exists and neither resume nor
        force is set, if resume is set but output_path does not exist, or if
        init_path is given but is not an existing directory.
    """
    output_path_exists = os.path.isdir(output_path)

    # Existing run data may only be reused (--resume) or overwritten (--force).
    if output_path_exists and not (resume or force):
        raise UnityTrainerException(
            "Previous data from this run ID was found. "
            "Either specify a new run ID, use --resume to resume this run, "
            "or use the --force parameter to overwrite existing data."
        )

    # Resuming requires something to resume from.
    if resume and not output_path_exists:
        raise UnityTrainerException(
            "Previous data from this run ID was not found. "
            "Train a new run by removing the --resume flag."
        )

    # Verify init path if specified.
    if init_path is not None and not os.path.isdir(init_path):
        raise UnityTrainerException(
            f"Could not initialize from {init_path}. "
            "Make sure models have already been saved with that run ID."
        )


def setup_init_path(
    behaviors: TrainerSettings.DefaultTrainerDict, init_dir: str
) -> None:
    """
    Resolve and validate the checkpoint file each behavior is initialized from.
    Every entry's init_path is updated in place: a missing value falls back to
    the default checkpoint name, and a bare file name is placed under
    <init_dir>/<behavior_name>/. A value that already has a directory part is
    left as is. The resulting path is then validated.
    :param behaviors: mapping from behavior_name to TrainerSettings
    :param init_dir: Path to run-id dir to initialize from
    """
    for name, settings in behaviors.items():
        configured = settings.init_path
        if configured is None:
            # nothing configured: use the default checkpoint file name
            checkpoint_file = DEFAULT_CHECKPOINT_NAME
        elif not os.path.dirname(configured):
            # only a file name was given: anchor it in the behavior's directory
            checkpoint_file = configured
        else:
            # already carries a directory component: keep it untouched
            checkpoint_file = None

        if checkpoint_file is not None:
            settings.init_path = os.path.join(init_dir, name, checkpoint_file)
        _validate_init_full_path(settings.init_path)


def _validate_init_full_path(init_file: str) -> None:
    """
    Check that the initialization checkpoint is an existing `.pt` file.
    :param init_file: full path to initialization checkpoint file
    :raises UnityTrainerException: if init_file is not an existing regular file
        or its name does not end with `.pt`
    """
    # Existence is checked first; the suffix is only inspected for real files.
    if os.path.isfile(init_file) and init_file.endswith(".pt"):
        return

    message = f"Could not initialize from {init_file}. file does not exists or is not a `.pt` file"
    raise UnityTrainerException(message)