import torch
import numpy as np
from typing import Any, List


class _FastCopy:
    """
    Overview:
        Fast recursive copier for plain data containers. The idea of this class comes from this article
        https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
        We use recursive calls to copy each object that needs to be copied, which will be 5x faster
        than copy.deepcopy.

        Copy functions are selected by the *exact* type of each object (``type(obj)`` lookup in
        ``self.dispatch``). Only ``list``, ``dict``, ``torch.Tensor`` and ``np.ndarray`` are registered;
        every other object (including tuples and subclasses of the registered types) is returned
        by reference, not copied.
    Interfaces:
        ``__init__``, ``_copy_list``, ``_copy_dict``, ``_copy_tensor``, ``_copy_ndarray``, ``copy``.
    """

    def __init__(self):
        """
        Overview:
            Initialize the _FastCopy object by building the type -> copy-function dispatch table.
        """
        self.dispatch = {
            list: self._copy_list,
            dict: self._copy_dict,
            torch.Tensor: self._copy_tensor,
            np.ndarray: self._copy_ndarray,
        }

    def _copy_list(self, l: List) -> List:
        """
        Overview:
            Copy the list, recursively copying every element whose exact type is registered \
            in ``self.dispatch``.
        Arguments:
            - l (:obj:`List`): The list to be copied.
        Returns:
            - ret (:obj:`List`): A new list; elements of unregistered types are shared with ``l``.
        """
        # Bind the lookup once; it does not change during the loop.
        lookup = self.dispatch.get
        ret = l.copy()
        for idx, item in enumerate(ret):
            cp = lookup(type(item))
            if cp is not None:
                ret[idx] = cp(item)
        return ret

    def _copy_dict(self, d: dict) -> dict:
        """
        Overview:
            Copy the dict, recursively copying every value whose exact type is registered \
            in ``self.dispatch``. Keys are not copied.
        Arguments:
            - d (:obj:`dict`): The dict to be copied.
        Returns:
            - ret (:obj:`dict`): A new dict; values of unregistered types are shared with ``d``.
        """
        lookup = self.dispatch.get
        ret = d.copy()
        # Only existing keys are reassigned, so the dict size is stable while iterating.
        for key, value in ret.items():
            cp = lookup(type(value))
            if cp is not None:
                ret[key] = cp(value)
        return ret

    def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Copy the tensor with ``Tensor.clone()``.
        Arguments:
            - t (:obj:`torch.Tensor`): The tensor to be copied.
        Returns:
            - ret (:obj:`torch.Tensor`): The cloned tensor.
        """
        return t.clone()

    def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
        """
        Overview:
            Copy the ndarray with ``np.copy``.
        Arguments:
            - a (:obj:`np.ndarray`): The ndarray to be copied.
        Returns:
            - ret (:obj:`np.ndarray`): The copied array.
        """
        return np.copy(a)

    def copy(self, sth: Any) -> Any:
        """
        Overview:
            Copy the object. This is the public entry point.
        Arguments:
            - sth (:obj:`Any`): The object to be copied.
        Returns:
            - ret (:obj:`Any`): A copy of ``sth`` if its exact type is registered in \
                ``self.dispatch``; otherwise ``sth`` itself (same object).
        """
        cp = self.dispatch.get(type(sth))
        if cp is None:
            return sth
        return cp(sth)


# Shared module-level instance.
fastcopy = _FastCopy()