File size: 2,463 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import torch
import numpy as np
from typing import Any, List
class _FastCopy:
"""
Overview:
The idea of this class comes from this article \
https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
We use recursive calls to copy each object that needs to be copied, which will be 5x faster \
than copy.deepcopy.
Interfaces:
``__init__``, ``_copy_list``, ``_copy_dict``, ``_copy_tensor``, ``_copy_ndarray``, ``copy``.
"""
def __init__(self):
"""
Overview:
Initialize the _FastCopy object.
"""
dispatch = {}
dispatch[list] = self._copy_list
dispatch[dict] = self._copy_dict
dispatch[torch.Tensor] = self._copy_tensor
dispatch[np.ndarray] = self._copy_ndarray
self.dispatch = dispatch
def _copy_list(self, l: List) -> dict:
"""
Overview:
Copy the list.
Arguments:
- l (:obj:`List`): The list to be copied.
"""
ret = l.copy()
for idx, item in enumerate(ret):
cp = self.dispatch.get(type(item))
if cp is not None:
ret[idx] = cp(item)
return ret
def _copy_dict(self, d: dict) -> dict:
"""
Overview:
Copy the dict.
Arguments:
- d (:obj:`dict`): The dict to be copied.
"""
ret = d.copy()
for key, value in ret.items():
cp = self.dispatch.get(type(value))
if cp is not None:
ret[key] = cp(value)
return ret
def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
"""
Overview:
Copy the tensor.
Arguments:
- t (:obj:`torch.Tensor`): The tensor to be copied.
"""
return t.clone()
def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
"""
Overview:
Copy the ndarray.
Arguments:
- a (:obj:`np.ndarray`): The ndarray to be copied.
"""
return np.copy(a)
def copy(self, sth: Any) -> Any:
"""
Overview:
Copy the object.
Arguments:
- sth (:obj:`Any`): The object to be copied.
"""
cp = self.dispatch.get(type(sth))
if cp is None:
return sth
else:
return cp(sth)
# Shared module-level instance; callers use ``fastcopy.copy(obj)``. It holds only the dispatch table built in
# ``__init__``, and no copy method mutates it, so one instance can be reused everywhere.
fastcopy: "_FastCopy" = _FastCopy()
|