Spaces:
Running
Running
File size: 1,233 Bytes
4d4dd90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
"""
Author: Paul-Edouard Sarlin (skydes)
"""
import collections.abc as collections
import numpy as np
import torch
# str/bytes are Sequences, but must be treated as atomic leaves rather than
# being iterated character-by-character.
string_classes = (str, bytes)


def map_tensor(input_, func):
    """Recursively apply ``func`` to every leaf of a nested container.

    Traversal rules, in order:

    - ``None`` and ``str``/``bytes`` values are returned unchanged.
    - Mappings are rebuilt as a plain ``dict`` with the same keys, each value
      mapped recursively.
    - Other Sequences (lists, tuples, ...) are rebuilt as a plain ``list``,
      each item mapped recursively.
    - Any other value is a leaf: ``func(input_)`` is returned.
    """
    if input_ is None or isinstance(input_, string_classes):
        return input_
    if isinstance(input_, collections.Mapping):
        return {key: map_tensor(value, func) for key, value in input_.items()}
    if isinstance(input_, collections.Sequence):
        # Note: tuples (and any other Sequence) come back as lists.
        return list(map(lambda item: map_tensor(item, func), input_))
    return func(input_)
def batch_to_numpy(batch):
    """Convert every tensor in a nested batch to a NumPy array.

    The batch is traversed with ``map_tensor`` (mappings become ``dict``,
    sequences become ``list``, strings and ``None`` are kept as-is).

    Each ``torch.Tensor`` leaf is processed in three steps:

    - it is detached from the autograd graph;
    - it is copied to host memory;
    - it is converted with ``.numpy()``.

    Leaves that are not tensors (e.g. Python numbers, or values that are
    already NumPy arrays) are returned unchanged instead of raising.
    """

    def _to_numpy(value):
        if isinstance(value, torch.Tensor):
            # .numpy() refuses tensors that require grad, so detach first.
            return value.detach().cpu().numpy()
        return value

    return map_tensor(batch, _to_numpy)
def batch_to_device(batch, device, non_blocking=True):
    """Move every tensor in a nested batch to ``device``.

    Args:
        batch: nested mappings / sequences, traversed with ``map_tensor``.
        device: target forwarded as ``.to(device=device, ...)``.
        non_blocking: forwarded as ``.to(..., non_blocking=non_blocking)``.

    Returns:
        The batch rebuilt by ``map_tensor`` (mappings as ``dict``, sequences
        as ``list``). Every leaf that has a callable ``.to`` method is
        replaced by the result of that call. Other leaves (numbers, NumPy
        arrays, ...) are returned unchanged.
    """

    def _func(value):
        move = getattr(value, "to", None)
        if not callable(move):
            # Leaves without .to() (ints, floats, NumPy arrays, ...) previously
            # raised AttributeError; keep them as-is instead.
            return value
        return move(device=device, non_blocking=non_blocking)

    return map_tensor(batch, _func)
def rbd(data: dict) -> dict:
    """Remove the leading batch dimension from the entries of ``data``.

    Every value that is a ``torch.Tensor``, ``np.ndarray`` or ``list`` is
    replaced by its first element (``value[0]``). All other values, including
    tuples, are copied through unchanged. Key order is preserved.
    """
    batched_types = (torch.Tensor, np.ndarray, list)
    unbatched = {}
    for key, value in data.items():
        if isinstance(value, batched_types):
            value = value[0]
        unbatched[key] = value
    return unbatched
def index_batch(tensor_dict):
    """Iterate over a batched mapping one sample at a time.

    The batch size is ``len()`` of the first value in ``tensor_dict``. For
    each index ``i`` in ``range(batch_size)``, this yields the structure
    rebuilt by ``map_tensor``, with every leaf ``t`` replaced by ``t[i]``.
    ``map_tensor`` leaves strings and ``None`` untouched and recurses into
    nested mappings and sequences.

    An empty ``tensor_dict`` yields nothing.
    """
    if not tensor_dict:
        # Without this guard, next() raises StopIteration inside the generator,
        # which Python 3.7+ converts into RuntimeError (PEP 479).
        return
    batch_size = len(next(iter(tensor_dict.values())))
    for i in range(batch_size):
        # map_tensor runs eagerly here, so the lambda sees the current i.
        yield map_tensor(tensor_dict, lambda t: t[i])
|