File size: 2,006 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os

# Opt-in checking mode to ensure that we always create numpy arrays using float32
if os.getenv("TEST_ENFORCE_NUMPY_FLOAT32"):
    # This file is importer by pytest multiple times, but this breaks the patching
    # Removing the env variable seems the easiest way to prevent this.
    del os.environ["TEST_ENFORCE_NUMPY_FLOAT32"]
    import numpy as np
    import traceback

    __old_np_array = np.array
    __old_np_zeros = np.zeros
    __old_np_ones = np.ones

    def _check_no_float64(arr, kwargs_dtype):
        if arr.dtype == np.float64:
            tb = traceback.extract_stack()
            # tb[-1] in the stack is this function.
            # tb[-2] is the wrapper function, e.g. np_array_no_float64
            # we want the calling function, so use tb[-3]
            filename = tb[-3].filename
            # Only raise if this came from mlagents code
            if (
                "ml-agents/mlagents" in filename
                or "ml-agents-envs/mlagents" in filename
            ):
                raise ValueError(
                    f"float64 array created. Set dtype=np.float32 instead of current dtype={kwargs_dtype}. "
                    f"Run pytest with TEST_ENFORCE_NUMPY_FLOAT32=1 to confirm fix."
                )

    def np_array_no_float64(*args, **kwargs):
        res = __old_np_array(*args, **kwargs)
        _check_no_float64(res, kwargs.get("dtype"))
        return res

    def np_zeros_no_float64(*args, **kwargs):
        res = __old_np_zeros(*args, **kwargs)
        _check_no_float64(res, kwargs.get("dtype"))
        return res

    def np_ones_no_float64(*args, **kwargs):
        res = __old_np_ones(*args, **kwargs)
        _check_no_float64(res, kwargs.get("dtype"))
        return res

    np.array = np_array_no_float64
    np.zeros = np_zeros_no_float64
    np.ones = np_ones_no_float64


if os.getenv("TEST_ENFORCE_BUFFER_KEY_TYPES"):
    from mlagents.trainers.buffer import AgentBuffer

    AgentBuffer.CHECK_KEY_TYPES_AT_RUNTIME = True