File size: 2,463 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import numpy as np
from typing import Any, List


class _FastCopy:
    """
    Overview:
        The idea of this class comes from this article \
        https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
        We use recursive calls to copy each object that needs to be copied, which will be 5x faster \
        than copy.deepcopy.
    Interfaces:
        ``__init__``, ``_copy_list``, ``_copy_dict``, ``_copy_tensor``, ``_copy_ndarray``, ``copy``.
    """

    def __init__(self):
        """
        Overview:
            Initialize the _FastCopy object.
        """

        dispatch = {}
        dispatch[list] = self._copy_list
        dispatch[dict] = self._copy_dict
        dispatch[torch.Tensor] = self._copy_tensor
        dispatch[np.ndarray] = self._copy_ndarray
        self.dispatch = dispatch

    def _copy_list(self, l: List) -> dict:
        """
        Overview:
            Copy the list.
        Arguments:
            - l (:obj:`List`): The list to be copied.
        """

        ret = l.copy()
        for idx, item in enumerate(ret):
            cp = self.dispatch.get(type(item))
            if cp is not None:
                ret[idx] = cp(item)
        return ret

    def _copy_dict(self, d: dict) -> dict:
        """
        Overview:
            Copy the dict.
        Arguments:
            - d (:obj:`dict`): The dict to be copied.
        """

        ret = d.copy()
        for key, value in ret.items():
            cp = self.dispatch.get(type(value))
            if cp is not None:
                ret[key] = cp(value)

        return ret

    def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Copy the tensor.
        Arguments:
            - t (:obj:`torch.Tensor`): The tensor to be copied.
        """

        return t.clone()

    def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
        """
        Overview:
            Copy the ndarray.
        Arguments:
            - a (:obj:`np.ndarray`): The ndarray to be copied.
        """

        return np.copy(a)

    def copy(self, sth: Any) -> Any:
        """
        Overview:
            Copy the object.
        Arguments:
            - sth (:obj:`Any`): The object to be copied.
        """

        cp = self.dispatch.get(type(sth))
        if cp is None:
            return sth
        else:
            return cp(sth)


# Shared module-level instance: callers import it and use ``fastcopy.copy(obj)``.
fastcopy: _FastCopy = _FastCopy()