Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,549 Bytes
7262fda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
# -*- coding: utf-8 -*-
import torch
import numpy as np
def worker_init_fn(_):
    """Reseed NumPy's global RNG so each DataLoader worker gets its own stream.

    PyTorch seeds its own RNG per worker, but other libraries' global RNGs
    (including NumPy's) may be duplicated across worker processes. This
    derives a per-worker NumPy seed from the first word of the current NumPy
    RNG state plus the worker id.

    NOTE(review): because the seed comes from the NumPy state the worker
    inherited, workers re-created each epoch (``persistent_workers=False``) will
    reuse the same seeds unless the parent's NumPy state advances between
    epochs. Consider mixing in ``torch.initial_seed()`` if per-epoch variety is
    needed.

    Args:
        _: Worker id passed by ``torch.utils.data.DataLoader``. Only used as a
            fallback when called outside a worker process, where
            ``torch.utils.data.get_worker_info()`` returns ``None``.

    Returns:
        None
    """
    worker_info = torch.utils.data.get_worker_info()
    # Inside a worker, worker_info.id matches the id DataLoader passes in.
    # Outside one, worker_info is None, so use the argument instead.
    worker_id = worker_info.id if worker_info is not None else _
    # Do the arithmetic on Python ints and wrap into [0, 2**32).
    # The state word can already be 2**32 - 1, and np.random.seed rejects
    # anything larger. (Adding to the raw uint32 scalar either overflows or
    # errors, depending on the NumPy version.)
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + int(worker_id)) % 2**32)
def collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of per-sample dicts into one dict of per-key batches.

    The keys of the first sample decide which fields are collected. Every
    sample is indexed with each of those keys, so a missing key raises
    ``KeyError``. Values are gathered per key in sample order. Each field is
    then converted based on the type of its first entry only:

    * ``int`` / ``float``  -> ``np.array(values)``  (when ``combine_scalars``)
    * ``torch.Tensor``     -> ``torch.stack(values)`` (when ``combine_tensors``)
    * ``np.ndarray``       -> ``np.stack(values)``    (when ``combine_tensors``)

    Any other type, or a disabled flag, leaves the field as a plain list.

    Args:
        samples (list[dict]): Per-sample records. Must be non-empty, since the
            first record supplies the key set.
        combine_tensors (bool): Stack tensor / ndarray fields.
        combine_scalars (bool): Turn int / float fields into a NumPy array.

    Returns:
        dict: Each key mapped to its combined batch or list of values.
    """
    field_names = samples[0].keys()
    columns = {name: [] for name in field_names}
    # Visit sample-by-sample, then key-by-key.
    for record in samples:
        for name, column in columns.items():
            column.append(record[name])

    def _combine(column):
        # Dispatch on the first entry's type only.
        head = column[0]
        if isinstance(head, (int, float)):
            return np.array(column) if combine_scalars else column
        if isinstance(head, torch.Tensor):
            return torch.stack(column) if combine_tensors else column
        if isinstance(head, np.ndarray):
            return np.stack(column) if combine_tensors else column
        return column

    return {name: _combine(column) for name, column in columns.items()}
|