# File size: 7,278 Bytes
# 079c32c
# (line-number gutter 1-241 of the original listing)
from typing import Any, ByteString, Callable
import pickle
import cloudpickle
import zlib
import numpy as np
class CloudPickleWrapper:
    """
    Overview:
        Container whose pickled state is produced with ``cloudpickle``, so that it can carry Python objects
        such as lambda expressions.
    Interfaces:
        ``__init__``, ``__getstate__``, ``__setstate__``.
    """

    def __init__(self, data: Any) -> None:
        """
        Overview:
            Keep a reference to the object that will be dumped when the wrapper is pickled.
        Arguments:
            - data (:obj:`Any`): The object to be dumped.
        """
        self.data = data

    def __getstate__(self) -> bytes:
        """
        Overview:
            Build the state of the wrapper by dumping the wrapped object with ``cloudpickle``.
        Returns:
            - data (:obj:`bytes`): The dumped byte-like result.
        """
        payload = self.data
        return cloudpickle.dumps(payload)

    def __setstate__(self, data: bytes) -> None:
        """
        Overview:
            Restore the wrapped object from a dumped state and store it in ``self.data``.
        Arguments:
            - data (:obj:`bytes`): The dumped byte-like result.
        """
        # Tuple/list/ndarray states go through plain ``pickle`` (kept from the original, which notes that
        # pickle is faster); every other state is loaded with ``cloudpickle``.
        # NOTE(review): ``__getstate__`` of this class always yields the output of ``cloudpickle.dumps``,
        # so it is unclear when a tuple/list/ndarray state would arrive here -- confirm intent.
        use_plain_pickle = isinstance(data, (tuple, list, np.ndarray))
        loader = pickle.loads if use_plain_pickle else cloudpickle.loads
        self.data = loader(data)
def dummy_compressor(data: Any) -> Any:
    """
    Overview:
        No-op compressor: the input is handed back without any transformation.
    Arguments:
        - data (:obj:`Any`): The input data of the compressor.
    Returns:
        - output (:obj:`Any`): The very same object that was passed in.
    """
    untouched = data
    return untouched
def zlib_data_compressor(data: Any) -> bytes:
    """
    Overview:
        Pickle the input object and compress the pickled bytes with zlib.
    Arguments:
        - data (:obj:`Any`): The input data of the compressor.
    Returns:
        - output (:obj:`bytes`): The compressed byte-like result.
    Examples:
        >>> blob = zlib_data_compressor("Hello")
        >>> pickle.loads(zlib.decompress(blob))
        'Hello'
    """
    serialized = pickle.dumps(data)
    compressed = zlib.compress(serialized)
    return compressed
def lz4_data_compressor(data: Any) -> bytes:
    """
    Overview:
        Pickle the input object and compress the pickled bytes with ``lz4.block``. If ``lz4`` cannot be
        imported, a warning is logged and the process exits with status 1.
    Arguments:
        - data (:obj:`Any`): The input data of the compressor.
    Returns:
        - output (:obj:`bytes`): The compressed byte-like result.
    Examples:
        >>> blob = lz4_data_compressor("Hello")
        >>> pickle.loads(lz4.block.decompress(blob))
        'Hello'
    """
    try:
        from lz4.block import compress as lz4_compress
    except ImportError:
        import sys
        from ditk import logging
        logging.warning("Please install lz4 first, such as `pip3 install lz4`")
        sys.exit(1)
    serialized = pickle.dumps(data)
    return lz4_compress(serialized)
def jpeg_data_compressor(data: np.ndarray) -> bytes:
    """
    Overview:
        Encode an observation array as a JPEG byte string with OpenCV, so that the JPEG bytes can be stored
        in the buffer instead of the numpy array to reduce memory usage. If ``cv2`` cannot be imported,
        a warning is logged and the process exits with status 1.
    Arguments:
        - data (:obj:`np.ndarray`): The observation numpy array.
    Returns:
        - img_str (:obj:`bytes`): The compressed byte-like result.
    """
    try:
        import cv2
    except ImportError:
        import sys
        from ditk import logging
        logging.warning("Please install opencv-python first.")
        sys.exit(1)
    # ``cv2.imencode`` returns a pair; the encoded buffer is its second element.
    encoded = cv2.imencode('.jpg', data)[1]
    return encoded.tobytes()
# Registry that maps a compressor name to the function implementing it.
_COMPRESSORS_MAP = dict(
    lz4=lz4_data_compressor,
    zlib=zlib_data_compressor,
    jpeg=jpeg_data_compressor,
    none=dummy_compressor,
)
def get_data_compressor(name: str) -> Callable:
    """
    Overview:
        Get the data compressor according to the input name.
    Arguments:
        - name (:obj:`str`): Name of the compressor, support ``['lz4', 'zlib', 'jpeg', 'none']``.
    Returns:
        - compressor (:obj:`Callable`): Corresponding data compressor, taking input data and returning \
            compressed data.
    Raises:
        - KeyError: If ``name`` is not one of the registered compressor names.
    Examples:
        >>> compress_fn = get_data_compressor('lz4')
        >>> compressed_data = compress_fn(input_data)
    """
    try:
        return _COMPRESSORS_MAP[name]
    except KeyError:
        # Still a KeyError (so existing handlers keep working), but it now lists the valid names.
        supported = sorted(_COMPRESSORS_MAP)
        raise KeyError("Unsupported compressor {!r}, expected one of {}".format(name, supported)) from None
def dummy_decompressor(data: Any) -> Any:
    """
    Overview:
        No-op decompressor: the input is handed back without any transformation.
    Arguments:
        - data (:obj:`Any`): The input data of the decompressor.
    Returns:
        - output (:obj:`Any`): The decompressed result, which is exactly the input.
    """
    untouched = data
    return untouched
def lz4_data_decompressor(compressed_data: bytes) -> Any:
    """
    Overview:
        Decompress the input with ``lz4.block`` and unpickle the decompressed bytes. If ``lz4`` cannot be
        imported, a warning is logged and the process exits with status 1.
    Arguments:
        - compressed_data (:obj:`bytes`): The input data of the decompressor.
    Returns:
        - output (:obj:`Any`): The decompressed object.
    """
    try:
        from lz4.block import decompress as lz4_decompress
    except ImportError:
        import sys
        from ditk import logging
        logging.warning("Please install lz4 first, such as `pip3 install lz4`")
        sys.exit(1)
    serialized = lz4_decompress(compressed_data)
    return pickle.loads(serialized)
def zlib_data_decompressor(compressed_data: bytes) -> Any:
    """
    Overview:
        Decompress the input with zlib and unpickle the decompressed bytes.
    Arguments:
        - compressed_data (:obj:`bytes`): The input data of the decompressor.
    Returns:
        - output (:obj:`Any`): The decompressed object.
    """
    serialized = zlib.decompress(compressed_data)
    restored = pickle.loads(serialized)
    return restored
def jpeg_data_decompressor(compressed_data: bytes, gray_scale=False) -> np.ndarray:
    """
    Overview:
        Decode an observation array from JPEG bytes with OpenCV. This is the counterpart of storing JPEG
        bytes in the buffer instead of the numpy array to reduce memory usage. If ``cv2`` cannot be
        imported, a warning is logged and the process exits with status 1.
    Arguments:
        - compressed_data (:obj:`bytes`): The JPEG bytes.
        - gray_scale (:obj:`bool`): If the observation is gray, ``gray_scale=True``,
            if the observation is RGB, ``gray_scale=False``.
    Returns:
        - arr (:obj:`np.ndarray`): The decompressed numpy array. With ``gray_scale=True`` a trailing axis \
            is appended to the decoded result.
    """
    try:
        import cv2
    except ImportError:
        import sys
        from ditk import logging
        logging.warning("Please install opencv-python first.")
        sys.exit(1)
    raw = np.frombuffer(compressed_data, np.uint8)
    read_flag = cv2.IMREAD_GRAYSCALE if gray_scale else cv2.IMREAD_COLOR
    arr = cv2.imdecode(raw, read_flag)
    if gray_scale:
        # Append a trailing axis to the grayscale decode result.
        arr = np.expand_dims(arr, -1)
    return arr
# Registry that maps a decompressor name to the function implementing it.
_DECOMPRESSORS_MAP = dict(
    lz4=lz4_data_decompressor,
    zlib=zlib_data_decompressor,
    jpeg=jpeg_data_decompressor,
    none=dummy_decompressor,
)
def get_data_decompressor(name: str) -> Callable:
    """
    Overview:
        Get the data decompressor according to the input name.
    Arguments:
        - name (:obj:`str`): Name of the decompressor, support ``['lz4', 'zlib', 'jpeg', 'none']``.
    Returns:
        - decompressor (:obj:`Callable`): Corresponding data decompressor.
    Raises:
        - KeyError: If ``name`` is not one of the registered decompressor names.
    Examples:
        >>> decompress_fn = get_data_decompressor('lz4')
        >>> origin_data = decompress_fn(compressed_data)

    .. note::
        The ``'lz4'``, ``'zlib'`` and ``'jpeg'`` decompressors take a bytes-like object as input; the
        ``'none'`` decompressor returns whatever it is given.
    """
    try:
        return _DECOMPRESSORS_MAP[name]
    except KeyError:
        # Still a KeyError (so existing handlers keep working), but it now lists the valid names.
        supported = sorted(_DECOMPRESSORS_MAP)
        raise KeyError("Unsupported decompressor {!r}, expected one of {}".format(name, supported)) from None
|