|
import dataclasses |
|
import warnings |
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device of object recursively

    Containers are traversed and rebuilt:

    - ``dict``: rebuilt as a plain ``dict`` with converted values.
    - dataclass instances: rebuilt field-by-field with the same class.
    - namedtuple-like tuple subclasses: rebuilt positionally with the same class.
    - ``list`` / ``tuple``: rebuilt with the same type.
    - ``np.ndarray``: wrapped with ``torch.from_numpy`` and then converted.
    - ``torch.Tensor``: converted with ``Tensor.to``.

    Any other object is returned unchanged.

    Args:
        data: Object (possibly nested) to convert.
        device: Forwarded to ``torch.Tensor.to``.
        dtype: Forwarded to ``torch.Tensor.to``.
        non_blocking: Forwarded to ``torch.Tensor.to``.
        copy: Forwarded to ``torch.Tensor.to``.

    Returns:
        An object with the same nesting in which every tensor/ndarray leaf has
        been passed through ``Tensor.to(device, dtype, non_blocking, copy)``.
    """
    if isinstance(data, dict):
        return {
            k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()
        }
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Rebuild from the instance's own fields instead of dataclasses.astuple():
        # astuple recurses into nested dataclasses (flattening them into plain
        # tuples) and deep-copies every other leaf, including tensors.
        # Keyword arguments also work for keyword-only fields; fields declared
        # with init=False are not accepted by the constructor, so they are left
        # for the class to initialize itself.
        return type(data)(
            **{
                f.name: to_device(
                    getattr(data, f.name), device, dtype, non_blocking, copy
                )
                for f in dataclasses.fields(data)
                if f.init
            }
        )
    elif isinstance(data, tuple) and type(data) is not tuple:
        # namedtuple-style classes take their fields as positional arguments
        return type(data)(
            *[to_device(o, device, dtype, non_blocking, copy) for o in data]
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data)
    elif isinstance(data, np.ndarray):
        return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data
|
|
|
|
|
def force_gatherable(data, device):
    """Change object to gatherable in torch.nn.DataParallel recursively

    The difference from to_device() is changing to torch.Tensor if float or int
    value is found.

    The restriction to the returned value in DataParallel:
        The object must be
        - torch.cuda.Tensor
        - 1 or more dimension. 0-dimension-tensor sends warning.
        or a list, tuple, dict.

    Args:
        data: Object to convert. Dicts, namedtuple-like tuple subclasses,
            lists, tuples and sets are traversed recursively; ``np.ndarray`` is
            wrapped with ``torch.from_numpy``.
        device: Target device for ``Tensor.to`` / ``torch.tensor``.

    Returns:
        The converted object:

        - 0-dim tensors are reshaped to shape ``(1,)``.
        - Python floats become 1-element ``torch.float`` tensors.
        - Python ints (bool included) become 1-element ``torch.long`` tensors.
        - NumPy bool/integer/floating scalars are reduced with ``.item()`` and
          then handled like the Python value.
        - ``None`` is passed through.
        - Anything else is returned unchanged with a warning.
    """
    if isinstance(data, dict):
        return {k: force_gatherable(v, device) for k, v in data.items()}
    elif isinstance(data, tuple) and type(data) is not tuple:
        # namedtuple-style classes take their fields as positional arguments
        return type(data)(*[force_gatherable(o, device) for o in data])
    elif isinstance(data, (list, tuple, set)):
        return type(data)(force_gatherable(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        if data.dim() == 0:
            # DataParallel warns on 0-dim outputs (see docstring), so add a
            # leading axis of size 1
            data = data[None]
        return data.to(device)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    elif isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)
    elif data is None:
        return None
    elif isinstance(data, (np.bool_, np.integer, np.floating)):
        # NumPy scalars such as np.float32 / np.int64 are not subclasses of
        # Python float/int, so reduce them to a Python value and reuse the
        # branches above. If .item() does not yield an int/float (e.g. an
        # extended-precision float), fall through to the warning below.
        value = data.item()
        if isinstance(value, (int, float)):
            return force_gatherable(value, device)

    warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
    return data
|
|