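"""
Micro-benchmarks for ding.data.buffer.DequeBuffer: measures push and sample
throughput for a plain buffer and for buffers wrapped with the clone_object
and PriorityExperienceReplay middleware.

A typical invocation (assumed; adjust to your environment):
    pytest -m benchmark test_buffer_benchmark.py -s
"""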
import sys
import timeit
import torch
import random
import pytest
import numpy as np
from ding.data.buffer import DequeBuffer
from ding.data.buffer.middleware import clone_object, PriorityExperienceReplay
# test different buffer sizes, e.g. 1000, 10000, 100000
size_list = [1000, 10000]
# test different tensor dims, e.g. 32*32, 128*128, 512*512
data_dim_list = [32, 128]
# number of op executions per timeit repetition
repeats = 100

class BufferBenchmark:

    def __init__(self, buffer_size, data_dim, buffer_type='base') -> None:
        self._buffer = DequeBuffer(size=buffer_size)
        self._meta = dict()
        # optionally wrap the buffer with middleware before benchmarking
        if buffer_type == "clone":
            self._buffer.use(clone_object())
        if buffer_type == "priority":
            self._buffer.use(PriorityExperienceReplay(self._buffer))
            self._meta["priority"] = 2.0
        self._data = {"obs": torch.rand(data_dim, data_dim)}

    def data_storage(self) -> float:
        # size of a single observation tensor's storage, in KB
        return sys.getsizeof(self._data["obs"].storage()) / 1024

    def count(self) -> int:
        return self._buffer.count()

    def push_op(self) -> None:
        self._buffer.push(self._data, meta=self._meta)

    def push_with_group_info(self, num_keys=256) -> None:
        # attach a random group key so that groupby sampling can be benchmarked
        meta = self._meta.copy()
        meta['group'] = int(random.random() * num_keys)
        self._buffer.push(self._data, meta=meta)

    def sample_op(self) -> None:
        self._buffer.sample(128, replace=False)

    def replace_sample_op(self) -> None:
        self._buffer.sample(128, replace=True)

    def groupby_sample_op(self) -> None:
        self._buffer.sample(128, groupby="group")

def get_mean_std(res):
    # convert timeit.repeat results (seconds per `repeats` calls) into the
    # mean and std of the time per 1000 operations, in seconds
    return np.mean(res) * 1000.0 / repeats, np.std(res) * 1000.0 / repeats
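# Example (illustrative numbers only): with repeats=100, timeit.repeat(op, number=100)
# might return [0.021, 0.020, 0.022] -- total seconds per 100 calls each -- so
# get_mean_std reports a mean of roughly 0.21 s per 1000 calls.
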
@pytest.mark.benchmark
@pytest.mark.parametrize('buffer_type', ['base', 'clone', 'priority'])
def test_benchmark(buffer_type):
    for size in size_list:
        for dim in data_dim_list:
            assert size >= 128, "size is too small, please set an int no less than 128!"
            buffer_test = BufferBenchmark(size, dim, buffer_type)
            print("exp-buffer_{}_{}-data_{:.2f}_KB".format(buffer_type, size, buffer_test.data_storage()))

            # test pushing into an (initially empty) buffer
            mean, std = get_mean_std(timeit.repeat(buffer_test.push_op, number=repeats))
            print("Empty Push Test: mean {:.4f} s, std {:.4f} s".format(mean, std))

            # fill the buffer before the sampling tests
            for _ in range(size):
                buffer_test.push_with_group_info()
            assert buffer_test.count() == size, "buffer is not full when testing sampling!"

            # test sampling without replacement
            mean, std = get_mean_std(timeit.repeat(buffer_test.sample_op, number=repeats))
            print("No-Replace Sample Test: mean {:.4f} s, std {:.4f} s".format(mean, std))

            # test sampling with replacement
            mean, std = get_mean_std(timeit.repeat(buffer_test.replace_sample_op, number=repeats))
            print("Replace Sample Test: mean {:.4f} s, std {:.4f} s".format(mean, std))

            # test groupby sampling (skipped for the priority buffer)
            if buffer_type != 'priority':
                mean, std = get_mean_std(timeit.repeat(buffer_test.groupby_sample_op, number=repeats))
                print("Groupby Sample Test: mean {:.4f} s, std {:.4f} s".format(mean, std))

            print("=" * 100)