import dataclasses
import warnings

import numpy as np
import torch


def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Recursively move the tensors contained in ``data`` to ``device``.

    Containers are rebuilt with their original type:
    - dicts are rebuilt as plain dicts;
    - dataclass instances are rebuilt from their fields;
    - namedtuples are rebuilt positionally;
    - lists and tuples are rebuilt element-wise.

    ``np.ndarray`` values are first converted with ``torch.from_numpy``.
    ``torch.Tensor`` values go through
    ``Tensor.to(device, dtype, non_blocking, copy)``. Any other value is
    returned unchanged.

    Args:
        data: Arbitrary, possibly nested, object.
        device: Target device, forwarded to ``Tensor.to``.
        dtype: Target dtype, forwarded to ``Tensor.to``.
        non_blocking: Forwarded to ``Tensor.to``.
        copy: Forwarded to ``Tensor.to``.

    Returns:
        An object with the same structure as ``data``, with every tensor (and
        every converted ndarray) moved/cast as requested.
    """
    if isinstance(data, dict):
        return {
            k: to_device(v, device, dtype, non_blocking, copy)
            for k, v in data.items()
        }
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Walk the fields shallowly rather than using dataclasses.astuple():
        # astuple() recursively flattens nested dataclasses into plain tuples
        # (losing their type) and deep-copies every other value, including
        # tensors. Fields declared with init=False cannot be passed to the
        # constructor, so they are left for the dataclass to recompute.
        return type(data)(
            **{
                f.name: to_device(
                    getattr(data, f.name), device, dtype, non_blocking, copy
                )
                for f in dataclasses.fields(data)
                if f.init
            }
        )
    elif isinstance(data, tuple) and type(data) is not tuple:
        # A tuple subclass (typically a namedtuple): its constructor takes the
        # elements as separate positional arguments, not a single iterable.
        return type(data)(
            *[to_device(o, device, dtype, non_blocking, copy) for o in data]
        )
    elif isinstance(data, (list, tuple)):
        return type(data)(
            to_device(v, device, dtype, non_blocking, copy) for v in data
        )
    elif isinstance(data, np.ndarray):
        return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data


def force_gatherable(data, device):
    """Recursively convert ``data`` into something DataParallel can gather.

    Unlike ``to_device()``, Python ``float``/``int`` values and numpy scalars
    are turned into 1-element tensors, and 0-dim tensors are expanded to
    1-dim.

    Restrictions that torch.nn.DataParallel places on returned values:
    each leaf must be a torch.cuda.Tensor with at least one dimension
    (a 0-dim tensor triggers a warning), possibly nested inside lists,
    tuples, or dicts.

    Args:
        data: Arbitrary, possibly nested, object.
        device: Device that every produced tensor is placed on.

    Returns:
        An object with the same structure as ``data``. Unsupported leaf
        types are returned unchanged after emitting a warning.
    """
    if isinstance(data, dict):
        return {k: force_gatherable(v, device) for k, v in data.items()}
    # DataParallel can't handle NamedTuple well
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(*[force_gatherable(o, device) for o in data])
    elif isinstance(data, (list, tuple, set)):
        return type(data)(force_gatherable(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        if data.dim() == 0:
            # To 1-dim array
            data = data[None]
        return data.to(device)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    elif isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)
    elif isinstance(data, (np.number, np.bool_)):
        # numpy scalars such as np.float32 / np.int64 are not Python
        # float/int subclasses (np.float64 is, and is handled above); route
        # them through a 0-dim ndarray so they become 1-dim tensors.
        return force_gatherable(np.asarray(data), device)
    elif data is None:
        return None
    else:
        warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
        return data