conex / espnet2 /torch_utils /device_funcs.py
tobiasc's picture
Initial commit
ad16788
raw
history blame
2.68 kB
import dataclasses
import warnings
import numpy as np
import torch
def _rebuild_dataclass(data, convert):
    """Return a new ``type(data)`` instance with ``convert`` applied to every field."""
    positional = []
    keyword = {}
    deferred = {}
    for field in dataclasses.fields(data):
        if not field.init:
            # Not accepted by the generated __init__; restore it afterwards,
            # but only if it has actually been assigned on this instance.
            if hasattr(data, field.name):
                deferred[field.name] = convert(getattr(data, field.name))
            continue
        value = convert(getattr(data, field.name))
        # ``kw_only`` exists on Field objects from Python 3.10 onwards.
        if getattr(field, "kw_only", False):
            keyword[field.name] = value
        else:
            positional.append(value)

    rebuilt = type(data)(*positional, **keyword)
    for name, value in deferred.items():
        # object.__setattr__ also works for frozen dataclasses.
        object.__setattr__(rebuilt, name, value)
    return rebuilt


def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device (and optionally dtype) of an object recursively.

    Containers are rebuilt with every element converted; ``np.ndarray`` is
    turned into a ``torch.Tensor`` first; ``torch.Tensor`` is converted with
    ``Tensor.to``; any other object is returned unchanged.

    Args:
        data: A dict, dataclass instance, namedtuple, list, tuple,
            ``np.ndarray``, ``torch.Tensor`` or arbitrary object.
        device: Forwarded to ``torch.Tensor.to``.
        dtype: Forwarded to ``torch.Tensor.to``.
        non_blocking: Forwarded to ``torch.Tensor.to``.
        copy: Forwarded to ``torch.Tensor.to``.

    Returns:
        An object with the same structure as ``data`` whose tensors/arrays
        have been converted. Dict inputs are returned as plain ``dict``.
    """

    def convert(value):
        return to_device(value, device, dtype, non_blocking, copy)

    if isinstance(data, dict):
        return {k: convert(v) for k, v in data.items()}
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Walk the fields directly instead of using dataclasses.astuple():
        # astuple() deep-copies every leaf (which fails for non-leaf tensors
        # that require grad) and flattens nested dataclasses into tuples,
        # so their type would be lost in the result.
        return _rebuild_dataclass(data, convert)
    # A tuple subclass is assumed to be a namedtuple, whose constructor takes
    # its members as separate positional arguments.
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(*[convert(o) for o in data])
    elif isinstance(data, (list, tuple)):
        return type(data)(convert(v) for v in data)
    elif isinstance(data, np.ndarray):
        return convert(torch.from_numpy(data))
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data
def force_gatherable(data, device):
    """Recursively convert ``data`` so torch.nn.DataParallel can gather it.

    Unlike to_device(), bare Python ``float`` and ``int`` values are also
    wrapped into one-element tensors.

    DataParallel restricts returned values to:
        - torch.cuda.Tensor
        - with 1 or more dimensions (a 0-dimension tensor sends a warning)
    or a list, tuple, dict of those.

    Args:
        data: A dict, namedtuple, list, tuple, set, ``np.ndarray``,
            ``torch.Tensor``, ``float``, ``int``, ``None`` or other object.
        device: Target device for every tensor that is produced.

    Returns:
        The converted structure. ``None`` stays ``None``; any unsupported
        object is returned unchanged after a warning is issued.
    """
    if isinstance(data, dict):
        gathered = {}
        for key, value in data.items():
            gathered[key] = force_gatherable(value, device)
        return gathered

    # DataParallel can't handle NamedTuple well, so rebuild it member by member.
    if isinstance(data, tuple) and type(data) is not tuple:
        members = [force_gatherable(item, device) for item in data]
        return type(data)(*members)

    if isinstance(data, (list, tuple, set)):
        container = type(data)
        return container(force_gatherable(item, device) for item in data)

    if isinstance(data, np.ndarray):
        as_tensor = torch.from_numpy(data)
        return force_gatherable(as_tensor, device)

    if isinstance(data, torch.Tensor):
        # Promote a 0-dim tensor to a 1-dim tensor of length one.
        tensor = data[None] if data.dim() == 0 else data
        return tensor.to(device)

    if isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)

    if isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)

    if data is None:
        return None

    warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
    return data