|
import io |
|
import os |
|
from unittest import mock |
|
import numpy as np |
|
import pytest |
|
import tempfile |
|
|
|
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import ( |
|
DemonstrationMetaProto, |
|
) |
|
from mlagents.trainers.tests.mock_brain import ( |
|
create_mock_3dball_behavior_specs, |
|
setup_test_behavior_specs, |
|
) |
|
from mlagents.trainers.demo_loader import ( |
|
load_demonstration, |
|
demo_to_buffer, |
|
get_demo_files, |
|
write_delimited, |
|
) |
|
from mlagents.trainers.buffer import BufferKey |
|
|
|
|
|
# Behavior spec produced by the mock 3DBall helper; the demo-loading tests in
# this module pass it to demo_to_buffer as the expected spec.
BEHAVIOR_SPEC = create_mock_3dball_behavior_specs()
|
|
|
|
|
def test_load_demo():
    """Load the test.demo file next to this module and sanity-check the result.

    Checks the first observation spec's shape, that the number of loaded
    pairs equals the reported total, and the length of the action data in
    the buffer built by demo_to_buffer.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    demo_path = here + "/test.demo"

    spec, pairs, expected_total = load_demonstration(demo_path)
    assert np.sum(spec.observation_specs[0].shape) == 8
    assert len(pairs) == expected_total

    _, buffer = demo_to_buffer(demo_path, 1, BEHAVIOR_SPEC)

    # Passes when either action field is one entry shorter than the total;
    # the continuous field is inspected first and short-circuits the check.
    expected_len = expected_total - 1
    assert any(
        len(buffer[key]) == expected_len
        for key in (BufferKey.CONTINUOUS_ACTION, BufferKey.DISCRETE_ACTION)
    )
|
|
|
|
|
def test_load_demo_dir():
    """Load demonstrations from the test_demo_dir path next to this module.

    Performs the same checks as the single-file case: first observation
    spec shape, pair count versus reported total, and the length of the
    action data in the buffer built by demo_to_buffer.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    demo_dir = here + "/test_demo_dir"

    spec, pairs, expected_total = load_demonstration(demo_dir)
    assert np.sum(spec.observation_specs[0].shape) == 8
    assert len(pairs) == expected_total

    _, buffer = demo_to_buffer(demo_dir, 1, BEHAVIOR_SPEC)

    # Passes when either action field is one entry shorter than the total;
    # the continuous field is inspected first and short-circuits the check.
    expected_len = expected_total - 1
    assert any(
        len(buffer[key]) == expected_len
        for key in (BufferKey.CONTINUOUS_ACTION, BufferKey.DISCRETE_ACTION)
    )
|
|
|
|
|
def test_demo_mismatch():
    """Check that demo_to_buffer raises RuntimeError for mismatched specs.

    Four behavior specs are built with setup_test_behavior_specs using
    different arguments, and each one is passed to demo_to_buffer together
    with test.demo; every call must raise RuntimeError.

    Each spec is constructed *outside* its ``pytest.raises`` block so that a
    RuntimeError coming from the fixture helper itself cannot be mistaken
    for the rejection that this test is meant to observe.
    """
    demo_path = os.path.dirname(os.path.abspath(__file__)) + "/test.demo"

    # NOTE(review): every spec below uses vector_obs_space=9, so the later
    # cases may be rejected for the observation size rather than for the
    # aspect they are named after -- confirm against demo_loader.

    # Observation mismatch.
    mismatch_obs = setup_test_behavior_specs(
        False, False, vector_action_space=2, vector_obs_space=9
    )
    with pytest.raises(RuntimeError):
        demo_to_buffer(demo_path, 1, mismatch_obs)

    # Action mismatch.
    mismatch_act = setup_test_behavior_specs(
        False, False, vector_action_space=3, vector_obs_space=9
    )
    with pytest.raises(RuntimeError):
        demo_to_buffer(demo_path, 1, mismatch_act)

    # Action type mismatch (first flag flipped, action space given as a list).
    mismatch_act_type = setup_test_behavior_specs(
        True, False, vector_action_space=[2], vector_obs_space=9
    )
    with pytest.raises(RuntimeError):
        demo_to_buffer(demo_path, 1, mismatch_act_type)

    # Observation-number mismatch (second flag flipped).
    mismatch_obs_number = setup_test_behavior_specs(
        False, True, vector_action_space=2, vector_obs_space=9
    )
    with pytest.raises(RuntimeError):
        demo_to_buffer(demo_path, 1, mismatch_obs_number)
|
|
|
|
|
def test_edge_cases():
    """Exercise get_demo_files with missing paths and a scratch directory.

    Covers, in order: a nonexistent file, a nonexistent directory, an empty
    temporary directory, a file with a non-.demo extension (queried directly
    and via its directory), and finally a .demo file (queried directly and
    via its directory).
    """
    here = os.path.dirname(os.path.abspath(__file__))

    # Paths that do not exist are expected to raise FileNotFoundError.
    for missing_name in ("nonexistent_file.demo", "nonexistent_directory"):
        with pytest.raises(FileNotFoundError):
            get_demo_files(os.path.join(here, missing_name))

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Nothing has been written into the directory yet.
        with pytest.raises(ValueError):
            get_demo_files(scratch_dir)

        # Add a file with the wrong extension; both the file itself and the
        # directory containing it are expected to raise ValueError.
        not_a_demo = os.path.join(scratch_dir, "mydemo.notademo")
        with open(not_a_demo, "w") as handle:
            handle.write("I'm not a demo")
        with pytest.raises(ValueError):
            get_demo_files(not_a_demo)
        with pytest.raises(ValueError):
            get_demo_files(scratch_dir)

        # Add a .demo file; it should be the only entry returned, whether
        # the file or its directory is passed in.
        demo_file = os.path.join(scratch_dir, "mydemo.demo")
        with open(demo_file, "w") as handle:
            handle.write("I'm a demo file")
        assert get_demo_files(demo_file) == [demo_file]
        assert get_demo_files(scratch_dir) == [demo_file]
|
|
|
|
|
@mock.patch("mlagents.trainers.demo_loader.get_demo_files", return_value=["foo.demo"])
def test_unsupported_version_raises_error(mock_get_demo_files):
    """Check that load_demonstration raises RuntimeError for api_version 1337.

    get_demo_files is patched to return a single fake file name, and the
    builtin open is patched so that reads yield the serialized metadata.
    """
    # Metadata message carrying the api_version under test.
    meta = DemonstrationMetaProto()
    meta.api_version = 1337

    # Serialize the metadata into memory and serve those bytes from open().
    stream = io.BytesIO()
    write_delimited(stream, meta)
    fake_open = mock.mock_open(read_data=stream.getvalue())

    with mock.patch("builtins.open", fake_open), pytest.raises(RuntimeError):
        load_demonstration("foo")
|
|