Realcat
add: GIM (https://github.com/xuelunshen/gim)
c0283b3
raw
history blame
1.23 kB
"""
Author: Paul-Edouard Sarlin (skydes)
"""
import collections.abc as collections
import numpy as np
import torch
# Text-like types that must be treated as leaves rather than sequences.
string_classes = (str, bytes)


def map_tensor(input_, func):
    """Recursively apply ``func`` to every leaf of a nested container.

    Strings/bytes and ``None`` are returned unchanged. Any ``Mapping`` is
    rebuilt as a plain ``dict`` and any other ``Sequence`` (``list``,
    ``tuple``, ...) is rebuilt as a ``list``; every remaining value is a
    leaf and is replaced by ``func(leaf)``.
    """
    # str/bytes are Sequences too, so they are filtered out first to avoid
    # recursing into individual characters.
    if input_ is None or isinstance(input_, string_classes):
        return input_
    if isinstance(input_, collections.Mapping):
        return {key: map_tensor(value, func) for key, value in input_.items()}
    if isinstance(input_, collections.Sequence):
        return [map_tensor(item, func) for item in input_]
    return func(input_)
def batch_to_numpy(batch):
    """Convert every tensor leaf of a (possibly nested) batch to a NumPy array.

    Container traversal is delegated to ``map_tensor``. Strings and ``None``
    pass through unchanged, mappings become ``dict`` and other sequences
    become ``list``.

    Args:
        batch: a tensor, or a nested dict/list/tuple whose leaves are tensors.

    Returns:
        The same nesting with NumPy arrays at the leaves.
    """
    # ``.numpy()`` raises on tensors that require grad (e.g. model outputs);
    # ``detach()`` returns a view without copying and is a no-op otherwise.
    return map_tensor(batch, lambda tensor: tensor.detach().cpu().numpy())
def batch_to_device(batch, device, non_blocking=True):
    """Move every tensor leaf of a (possibly nested) batch to ``device``.

    Each leaf is converted with ``tensor.to(device=..., non_blocking=...)``;
    container traversal (strings/``None`` untouched, mappings -> ``dict``,
    sequences -> ``list``) is handled by ``map_tensor``.
    """
    return map_tensor(
        batch,
        lambda tensor: tensor.to(device=device, non_blocking=non_blocking),
    )
def rbd(data: dict) -> dict:
    """Remove batch dimension from elements in data.

    Every value that is a ``torch.Tensor``, ``np.ndarray`` or ``list`` is
    replaced by its first element (index 0); all other values are copied
    through as-is. A new dict is returned and the input is not modified.
    """
    indexable = (torch.Tensor, np.ndarray, list)
    unbatched = {}
    for key, value in data.items():
        if isinstance(value, indexable):
            unbatched[key] = value[0]
        else:
            unbatched[key] = value
    return unbatched
def index_batch(tensor_dict):
    """Yield the samples of a batched dict one at a time.

    The batch size is taken as ``len()`` of the first value in
    ``tensor_dict``. For each index ``i`` in that range, ``map_tensor`` is
    used to build a structure in which every leaf ``t`` is replaced by
    ``t[i]``. An empty dict yields nothing.
    """
    # Without this guard, ``next`` would raise StopIteration inside the
    # generator, which Python (PEP 479) converts into a RuntimeError.
    if not tensor_dict:
        return
    batch_size = len(next(iter(tensor_dict.values())))
    for i in range(batch_size):
        # Bind ``i`` explicitly so the indexer never depends on late binding.
        yield map_tensor(tensor_dict, lambda t, idx=i: t[idx])