File size: 5,008 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from typing import Any, Optional, Union, Tuple, Dict
from multiprocessing import Array
import ctypes
import numpy as np
import torch
# Mapping from numpy scalar types to the equivalent ctypes type.
# Used below to size/allocate the raw multiprocessing.Array backing a ShmBuffer.
# Only the dtypes listed here can be stored in shared memory (e.g. np.float16
# has no ctypes counterpart and is therefore unsupported).
_NTYPE_TO_CTYPE = {
np.bool_: ctypes.c_bool,
np.uint8: ctypes.c_uint8,
np.uint16: ctypes.c_uint16,
np.uint32: ctypes.c_uint32,
np.uint64: ctypes.c_uint64,
np.int8: ctypes.c_int8,
np.int16: ctypes.c_int16,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
}
class ShmBuffer():
    """
    Overview:
        Shared memory buffer to store numpy array, backed by a process-safe
        ``multiprocessing.Array``.
    """

    def __init__(
            self,
            dtype: Union[type, np.dtype],
            shape: Tuple[int, ...],
            copy_on_get: bool = True,
            ctype: Optional[type] = None
    ) -> None:
        """
        Overview:
            Initialize the buffer.
        Arguments:
            - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
            - shape (:obj:`Tuple[int, ...]`): The shape of the data to limit the size of the buffer.
            - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
            - ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor.
        Raises:
            - KeyError: If ``dtype`` has no ctypes counterpart (see ``_NTYPE_TO_CTYPE``).
        """
        if isinstance(dtype, np.dtype):  # normalize np.dtype instances (e.g. gym.spaces dtype) to scalar type
            dtype = dtype.type
        # Fail with a descriptive message instead of a bare KeyError from the dict lookup.
        if dtype not in _NTYPE_TO_CTYPE:
            raise KeyError("unsupported dtype for shared memory buffer: {}".format(dtype))
        # Flat, lock-protected shared array; reshaped to `shape` on each fill/get.
        self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape)))
        self.dtype = dtype
        self.shape = shape
        self.copy_on_get = copy_on_get
        self.ctype = ctype

    def fill(self, src_arr: np.ndarray) -> None:
        """
        Overview:
            Fill the shared memory buffer with a numpy array. (Replace the original one.)
        Arguments:
            - src_arr (:obj:`np.ndarray`): array to fill the buffer.
        """
        assert isinstance(src_arr, np.ndarray), type(src_arr)
        # for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten
        # for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten
        # so we reshape dst_arr rather than flatten src_arr
        dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
        np.copyto(dst_arr, src_arr)

    def get(self) -> np.ndarray:
        """
        Overview:
            Get the array stored in the buffer.
        Return:
            - data (:obj:`np.ndarray`): A copy of the data stored in the buffer \
                (or a view of it when ``copy_on_get`` is False).
        """
        data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
        if self.copy_on_get:
            data = data.copy()  # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory
        if self.ctype is torch.Tensor:
            data = torch.from_numpy(data)
        return data
class ShmBufferContainer(object):
    """
    Overview:
        Support multiple shared memory buffers. Each key-value is name-buffer.
    """

    def __init__(
            self,
            dtype: Union[Dict[Any, type], type, np.dtype],
            shape: Union[Dict[Any, tuple], tuple],
            copy_on_get: bool = True
    ) -> None:
        """
        Overview:
            Initialize the buffer container.
        Arguments:
            - dtype (:obj:`Union[Dict[Any, type], type, np.dtype]`): The dtype of the data \
                to limit the size of the buffer (a dict of dtypes when ``shape`` is a dict).
            - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \
                multiple buffers; If `tuple`, use single buffer.
            - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
        Raises:
            - RuntimeError: If ``shape`` is neither a dict nor a tuple/list.
        """
        if isinstance(shape, dict):
            # One nested container per key; recursion supports arbitrarily nested dict shapes.
            self._data = {}
            for key, sub_shape in shape.items():
                self._data[key] = ShmBufferContainer(dtype[key], sub_shape, copy_on_get)
        elif isinstance(shape, (tuple, list)):
            self._data = ShmBuffer(dtype, shape, copy_on_get)
        else:
            raise RuntimeError("not support shape: {}".format(shape))
        self._shape = shape

    def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None:
        """
        Overview:
            Fill the one or many shared memory buffer.
        Arguments:
            - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer.
        """
        if isinstance(self._shape, dict):
            for key in self._shape:
                self._data[key].fill(src_arr[key])
        elif isinstance(self._shape, (tuple, list)):
            self._data.fill(src_arr)

    def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]:
        """
        Overview:
            Get the one or many arrays stored in the buffer.
        Return:
            - data (:obj:`np.ndarray`): The array(s) stored in the buffer.
        """
        if isinstance(self._shape, dict):
            result = {}
            for key in self._shape:
                result[key] = self._data[key].get()
            return result
        elif isinstance(self._shape, (tuple, list)):
            return self._data.get()
|