File size: 4,244 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import unittest
import json
from enum import Enum
import time
from mlagents.trainers.training_status import (
    StatusType,
    StatusMetaData,
    GlobalTrainingStatus,
)
from mlagents.trainers.policy.checkpoint_manager import (
    ModelCheckpointManager,
    ModelCheckpoint,
)


def test_globaltrainingstatus(tmpdir):
    """Round-trip GlobalTrainingStatus through JSON and verify lookups.

    Covers: save to disk, on-disk layout, reload, and the None fallback
    for unknown categories and unknown status keys.
    """
    status_path = os.path.join(tmpdir, "test.json")

    GlobalTrainingStatus.set_parameter_state("Category1", StatusType.LESSON_NUM, 3)
    GlobalTrainingStatus.save_state(status_path)

    with open(status_path) as handle:
        saved_json = json.load(handle)

    # On-disk layout is {category: {status_key: value}} plus a metadata entry.
    assert "Category1" in saved_json
    assert StatusType.LESSON_NUM.value in saved_json["Category1"]
    assert saved_json["Category1"][StatusType.LESSON_NUM.value] == 3
    assert "metadata" in saved_json

    # Reloading the file must restore the stored value exactly.
    GlobalTrainingStatus.load_state(status_path)
    assert (
        GlobalTrainingStatus.get_parameter_state("Category1", StatusType.LESSON_NUM)
        == 3
    )

    # Unknown categories and unknown keys both resolve to None, not an error.
    class FakeStatusType(Enum):
        NOTAREALKEY = "notarealkey"

    missing_category = GlobalTrainingStatus.get_parameter_state(
        "Category3", StatusType.LESSON_NUM
    )
    missing_key = GlobalTrainingStatus.get_parameter_state(
        "Category1", FakeStatusType.NOTAREALKEY
    )
    assert missing_category is None
    assert missing_key is None


def test_model_management(tmpdir):
    """Exercise checkpoint pruning and final-checkpoint tracking.

    Seeds three checkpoints, verifies that adding more under a limit of 4
    prunes to the limit, and that the final (exported) model is recorded
    under the brain's category in the global training status.
    """
    results_path = os.path.join(tmpdir, "results")
    brain_name = "Mock_brain"
    final_model_path = os.path.join(results_path, brain_name)
    # Three pre-existing checkpoints seeded directly into the global status.
    # NOTE: "auxillary_file_paths" spelling matches the project's key name.
    test_checkpoint_list = [
        {
            "steps": 1,
            "file_path": os.path.join(final_model_path, f"{brain_name}-1.nn"),
            "reward": 1.312,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
        {
            "steps": 2,
            "file_path": os.path.join(final_model_path, f"{brain_name}-2.nn"),
            "reward": 1.912,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
        {
            "steps": 3,
            "file_path": os.path.join(final_model_path, f"{brain_name}-3.nn"),
            "reward": 2.312,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
    ]
    GlobalTrainingStatus.set_parameter_state(
        brain_name, StatusType.CHECKPOINTS, test_checkpoint_list
    )

    # Adding a 4th checkpoint with keep_checkpoints=4 keeps all four.
    new_checkpoint_4 = ModelCheckpoint(
        4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678, time.time()
    )
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    # Adding a 5th under the same limit must evict one, keeping four.
    new_checkpoint_5 = ModelCheckpoint(
        5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122, time.time()
    )
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    # The final exported model is tracked separately and must not count
    # against the checkpoint limit.
    final_model_path = f"{final_model_path}.nn"
    final_model_time = time.time()
    current_step = 6
    final_model = ModelCheckpoint(
        current_step, final_model_path, 3.294, final_model_time
    )

    ModelCheckpointManager.track_final_checkpoint(brain_name, final_model)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
        StatusType.CHECKPOINTS.value
    ]
    assert check_checkpoints is not None

    # BUG FIX: the final checkpoint is stored under the brain's category
    # (track_final_checkpoint writes it via set_parameter_state(brain_name,
    # StatusType.FINAL_CHECKPOINT, ...)), so it must be read through that
    # category. The old top-level lookup
    # `saved_state[StatusType.FINAL_CHECKPOINT.value]` never saw the tracked
    # value, making the assertion vacuous.
    final_model_state = GlobalTrainingStatus.get_parameter_state(
        brain_name, StatusType.FINAL_CHECKPOINT
    )
    assert final_model_state is not None


class StatsMetaDataTest(unittest.TestCase):
    def test_metadata_compare(self):
        """Each mismatched version field should emit one compatibility warning."""
        with self.assertLogs("mlagents.trainers", level="WARNING") as captured:
            base_metadata = StatusMetaData()

            # Differing mlagents_version -> first warning.
            mismatched_mlagents = StatusMetaData(mlagents_version="test")
            base_metadata.check_compatibility(mismatched_mlagents)

            # Differing torch_version -> second warning.
            mismatched_torch = StatusMetaData(torch_version="test")
            base_metadata.check_compatibility(mismatched_torch)

        # Exactly one warning per mismatched field: two in total.
        assert len(captured.output) == 2