import pytest
import torch
from ding.data.buffer import DequeBuffer
from ding.data.buffer.middleware import clone_object, use_time_check, staleness_check, sample_range_view
from ding.data.buffer.middleware import PriorityExperienceReplay, group_sample
from ding.data.buffer.middleware.padding import padding


@pytest.mark.unittest
def test_clone_object():
    buffer = DequeBuffer(size=10).use(clone_object())

    # Store a dict, a list, a tensor
    arr = [{"key": "v1"}, ["a"], torch.Tensor([1, 2, 3])]
    for o in arr:
        buffer.push(o)

    # Modify it
    for item in buffer.sample(len(arr)):
        item = item.data
        if isinstance(item, dict):
            item["key"] = "v2"
        elif isinstance(item, list):
            item.append("b")
        elif isinstance(item, torch.Tensor):
            item[0] = 3
        else:
            raise Exception("Unexpected type")

    # Resample it, and check their values
    for item in buffer.sample(len(arr)):
        item = item.data
        if isinstance(item, dict):
            assert item["key"] == "v1"
        elif isinstance(item, list):
            assert len(item) == 1
        elif isinstance(item, torch.Tensor):
            assert item[0] == 1
        else:
            raise Exception("Unexpected type")


def get_data():
    return {'obs': torch.randn(4), 'reward': torch.randn(1), 'info': 'xxx'}


@pytest.mark.unittest
def test_use_time_check():
    N = 6
    buffer = DequeBuffer(size=10)
    buffer.use(use_time_check(buffer, max_use=2))

    for _ in range(N):
        buffer.push(get_data())

    for _ in range(2):
        data = buffer.sample(size=N, replace=False)
        assert len(data) == N
    with pytest.raises(ValueError):
        buffer.sample(size=1, replace=False)


@pytest.mark.unittest
def test_staleness_check():
    N = 6
    buffer = DequeBuffer(size=10)
    buffer.use(staleness_check(buffer, max_staleness=10))

    with pytest.raises(AssertionError):
        buffer.push(get_data())
    for _ in range(N):
        buffer.push(get_data(), meta={'train_iter_data_collected': 0})
    data = buffer.sample(size=N, replace=False, train_iter_sample_data=9)
    assert len(data) == N
    data = buffer.sample(size=N, replace=False, train_iter_sample_data=10)  # edge case
    assert len(data) == N
    for _ in range(2):
        buffer.push(get_data(), meta={'train_iter_data_collected': 5})
    assert buffer.count() == 8
    with pytest.raises(ValueError):
        data = buffer.sample(size=N, replace=False, train_iter_sample_data=11)
    assert buffer.count() == 2


@pytest.mark.unittest
def test_priority():
    N = 5
    buffer = DequeBuffer(size=10)
    buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
    for _ in range(N):
        buffer.push(get_data(), meta={'priority': 2.0})
    assert buffer.count() == N
    for _ in range(N):
        buffer.push(get_data(), meta={'priority': 2.0})
    assert buffer.count() == N + N
    data = buffer.sample(size=N + N, replace=False)
    assert len(data) == N + N
    for item in data:
        meta = item.meta
        assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS']))
        meta['priority'] = 3.0
    for item in data:
        data, index, meta = item.data, item.index, item.meta
        buffer.update(index, data, meta)
    data = buffer.sample(size=1)
    assert data[0].meta['priority'] == 3.0
    buffer.delete(data[0].index)
    assert buffer.count() == N + N - 1
    buffer.clear()
    assert buffer.count() == 0


@pytest.mark.unittest
def test_priority_from_collector():
    N = 5
    buffer = DequeBuffer(size=10)
    buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
    for _ in range(N):
        tmp_data = get_data()
        tmp_data['priority'] = 2.0
        buffer.push(get_data())
    assert buffer.count() == N
    for _ in range(N):
        tmp_data = get_data()
        tmp_data['priority'] = 2.0
        buffer.push(get_data())
    assert buffer.count() == N + N
    data = buffer.sample(size=N + N, replace=False)
    assert len(data) == N + N
    for item in data:
        meta = item.meta
        assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS']))
        meta['priority'] = 3.0
    for item in data:
        data, index, meta = item.data, item.index, item.meta
        buffer.update(index, data, meta)
    data = buffer.sample(size=1)
    assert data[0].meta['priority'] == 3.0
    buffer.delete(data[0].index)
    assert buffer.count() == N + N - 1
    buffer.clear()
    assert buffer.count() == 0


@pytest.mark.unittest
def test_padding():
    buffer = DequeBuffer(size=10)
    buffer.use(padding())
    for i in range(10):
        buffer.push(i, {"group": i & 5})  # [3,3,2,2]
    sampled_data = buffer.sample(4, groupby="group")
    assert len(sampled_data) == 4
    for grouped_data in sampled_data:
        assert len(grouped_data) == 3


@pytest.mark.unittest
def test_group_sample():
    buffer = DequeBuffer(size=10)
    buffer.use(padding(policy="none")).use(group_sample(size_in_group=5, ordered_in_group=True, max_use_in_group=True))
    for i in range(4):
        buffer.push(i, {"episode": 0})
    for i in range(6):
        buffer.push(i, {"episode": 1})
    sampled_data = buffer.sample(2, groupby="episode")
    assert len(sampled_data) == 2

    def check_group0(grouped_data):
        # In group0 should find only last record with data as None
        n_none = 0
        for item in grouped_data:
            if item.data is None:
                n_none += 1
        assert n_none == 1

    def check_group1(grouped_data):
        # In group1 every record should have data and meta
        for item in grouped_data:
            assert item.data is not None

    for grouped_data in sampled_data:
        assert len(grouped_data) == 5
        meta = grouped_data[0].meta
        if meta and "episode" in meta and meta["episode"] == 1:
            check_group1(grouped_data)
        else:
            check_group0(grouped_data)


@pytest.mark.unittest
def test_sample_range_view():
    buffer_ = DequeBuffer(size=10)
    for i in range(5):
        buffer_.push({'data': 'x'})
    for i in range(5, 5 + 3):
        buffer_.push({'data': 'y'})
    for i in range(8, 8 + 2):
        buffer_.push({'data': 'z'})

    buffer1 = buffer_.view()
    buffer1.use(sample_range_view(buffer1, start=-5, end=-2))
    for _ in range(10):
        sampled_data = buffer1.sample(1)
        assert sampled_data[0].data['data'] == 'y'

    buffer2 = buffer_.view()
    buffer2.use(sample_range_view(buffer1, start=-2))
    for _ in range(10):
        sampled_data = buffer2.sample(1)
        assert sampled_data[0].data['data'] == 'z'