# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import hashlib import json import subprocess import tempfile from typing import Hashable try: import pyarrow.plasma as plasma PYARROW_AVAILABLE = True except ImportError: plasma = None PYARROW_AVAILABLE = False class PlasmaArray: """ Wrapper around numpy arrays that automatically moves the data to shared memory upon serialization. This is particularly helpful when passing numpy arrays through multiprocessing, so that data is not unnecessarily duplicated or pickled. """ def __init__(self, array): super().__init__() self.array = array self.disable = array.nbytes < 134217728 # disable for arrays <128MB self.object_id = None self.path = None # variables with underscores shouldn't be pickled self._client = None self._server = None self._server_tmp = None self._plasma = None @property def plasma(self): if self._plasma is None and not self.disable: self._plasma = plasma return self._plasma def start_server(self): if self.plasma is None or self._server is not None: return assert self.object_id is None assert self.path is None self._server_tmp = tempfile.NamedTemporaryFile() self.path = self._server_tmp.name self._server = subprocess.Popen( ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path] ) @property def client(self): if self._client is None: assert self.path is not None self._client = self.plasma.connect(self.path, num_retries=200) return self._client def __getstate__(self): """Called on pickle load""" if self.plasma is None: return self.__dict__ if self.object_id is None: self.start_server() self.object_id = self.client.put(self.array) state = self.__dict__.copy() del state["array"] state["_client"] = None state["_server"] = None state["_server_tmp"] = None state["_plasma"] = None return state def __setstate__(self, state): """Called on pickle save""" self.__dict__.update(state) if self.plasma is None: return self.array = self.client.get(self.object_id) def __del__(self): if self._server is not None: self._server.kill() self._server = None self._server_tmp.close() self._server_tmp = None DEFAULT_PLASMA_PATH = "/tmp/plasma" class PlasmaView: """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization, PlasmaView writes to shared memory on instantiation.""" def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None): """ Args: array: numpy array to store. This can be read with ``PlasmaView().array`` split_path: the path whence the data was read, used for hashing hash_data: other metadata about the array that can be used to create a unique key. as of writing, the 3 callers in ``TokenBlockDataset`` use:: hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2) """ assert PYARROW_AVAILABLE assert split_path is not None if plasma_path is None: plasma_path = DEFAULT_PLASMA_PATH self.path = plasma_path self.split_path = split_path self._client = None # Initialize lazily for pickle. plasma clients should not be deep copied or serialized. self._n = None self.object_id = self.get_object_id(self.split_path, hash_data) try: self.client.put(array, object_id=self.object_id) except plasma.PlasmaObjectExists: pass @property def client(self): if self._client is None: self._client = plasma.connect(self.path, num_retries=200) return self._client @property def array(self): """Fetch a read only view of an np.array, stored in plasma.""" ret = self.client.get(self.object_id) return ret @staticmethod def get_object_id(split_path: str, hash_data: Hashable): """Returns plasma.ObjectID from hashing split_path and object_num.""" hash = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20) harg = json.dumps(hash_data).encode("utf-8") hash.update(harg) return plasma.ObjectID(hash.digest()) def __getstate__(self): """Called on pickle save""" self.disconnect() state = self.__dict__.copy() assert state["_client"] is None assert "object_id" in state return state def __setstate__(self, state): """Called on pickle load""" self.__dict__.update(state) def __del__(self): self.disconnect() def disconnect(self): if self._client is not None: self._client.disconnect() self._client = None def __len__(self): """Save reads by caching len""" if self._n is None: self._n = len(self.array) return self._n GB100 = (1024**3) * 100 class PlasmaStore: def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100): self.server = self.start(path, nbytes) def __del__(self): self.server.kill() @staticmethod def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen: if not PYARROW_AVAILABLE: raise ImportError("please run pip install pyarrow to use --use_plasma_view") # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path]) plasma.connect(path, num_retries=200) # If we can't connect we fail immediately return _server