import pytest
import copy
from collections import deque
import numpy as np
import torch
from ding.rl_utils import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample


@pytest.mark.unittest
class TestAdder:
    """Unit tests for the trajectory post-processing helpers imported from ``ding.rl_utils``."""

    # GAE hyper-parameters shared by every GAE call in these tests.
    GAE_KWARGS = dict(gamma=0.99, gae_lambda=0.97, cuda=False)

    def get_transition(self):
        """Return a random single-agent transition (``value`` and ``reward`` have shape (1,))."""
        return {
            'value': torch.randn(1),
            'reward': torch.rand(1),
            'action': torch.rand(3),
            'other': np.random.randint(0, 10, size=(4, )),
            'obs': torch.randn(3),
            'done': False
        }

    def get_transition_multi_agent(self):
        """Return a random transition whose ``value`` has shape (1, 8) and ``reward`` shape (1, 1)."""
        return {
            'value': torch.randn(1, 8),
            'reward': torch.rand(1, 1),
            'action': torch.rand(3),
            'other': np.random.randint(0, 10, size=(4, )),
            'obs': torch.randn(3),
            'done': False
        }

    def _check_gae(self, make_transition, value_shape):
        """Run the GAE consistency checks shared by the single- and multi-agent tests.

        Builds 10 transitions with ``make_transition`` plus a random ``last_value`` of
        ``value_shape``, then asserts that:

        * ``get_gae`` attaches an ``adv`` tensor of ``value_shape`` to every output and
          leaves every other field equal to the corresponding input transition;
        * ``get_gae_with_default_last_value(..., True, ...)`` yields advantages that
          differ from ``get_gae``'s in every element;
        * ``get_gae_with_default_last_value(..., False, ...)`` on the transitions with a
          trailing ``{'value': last_value}`` entry reproduces ``get_gae`` exactly.

        Args:
            make_transition: zero-argument factory returning one transition dict.
            value_shape: shape tuple of ``value`` / ``adv`` for these transitions.
        """
        transitions = deque([make_transition() for _ in range(10)])
        last_value = torch.randn(*value_shape)
        output = get_gae(transitions, last_value, **self.GAE_KWARGS)
        for i, o in enumerate(output):
            assert 'adv' in o.keys()
            for k, v in o.items():
                if k == 'adv':
                    assert isinstance(v, torch.Tensor)
                    assert v.shape == value_shape
                elif k == 'done':
                    # 'done' is a plain bool, so compare it directly.
                    assert v == transitions[i][k]
                else:
                    assert (v == transitions[i][k]).all()

        # deepcopy so the helper cannot alias or mutate the transitions used above.
        output1 = get_gae_with_default_last_value(copy.deepcopy(transitions), True, **self.GAE_KWARGS)
        for i, o in enumerate(output):
            # Element-wise: every entry must differ (covers both (1,) and (1, 8) shapes).
            assert o['adv'].ne(output1[i]['adv']).all()

        data = copy.deepcopy(transitions)
        data.append({'value': last_value})
        output2 = get_gae_with_default_last_value(data, False, **self.GAE_KWARGS)
        for i, o in enumerate(output):
            assert o['adv'].eq(output2[i]['adv']).all()

    def test_get_gae(self):
        self._check_gae(self.get_transition, (1, ))

    def test_get_gae_multi_agent(self):
        self._check_gae(self.get_transition_multi_agent, (1, 8))

    def test_get_nstep_return_data(self):
        nstep = 3
        data = deque([self.get_transition() for _ in range(10)])
        output_data = get_nstep_return_data(data, nstep=nstep)
        assert len(output_data) == 10
        for i, o in enumerate(output_data):
            assert o['reward'].shape == (nstep, )
            # The trailing transitions have fewer than ``nstep`` real successors; their
            # missing reward slots are expected to be zero.
            if i >= 10 - nstep + 1:
                assert o['done'] is data[-1]['done']
                assert o['reward'][-(i - 10 + nstep):].sum() == 0
        data = deque([self.get_transition() for _ in range(12)])
        output_data = get_nstep_return_data(data, nstep=nstep)
        assert len(output_data) == 12

    def test_get_train_sample(self):
        data = [self.get_transition() for _ in range(10)]

        output = get_train_sample(data, unroll_len=1, last_fn_type='drop')
        assert len(output) == 10

        # 'drop': the incomplete trailing chunk (10 % 4 == 2 items) is discarded.
        output = get_train_sample(data, unroll_len=4, last_fn_type='drop')
        assert len(output) == 2
        for o in output:
            for v in o.values():
                assert len(v) == 4

        # 'null_padding': the trailing chunk is padded up to unroll_len.
        output = get_train_sample(data, unroll_len=4, last_fn_type='null_padding')
        assert len(output) == 3
        for o in output:
            for v in o.values():
                assert len(v) == 4
        assert output[-1]['done'] == [False, False, True, True]
        for i in range(1, 10 % 4 + 1):
            assert id(output[-1]['obs'][-i]) != id(output[-1]['obs'][0])

        # 'last': the trailing chunk is filled out with items from earlier in the data.
        output = get_train_sample(data, unroll_len=4, last_fn_type='last')
        assert len(output) == 3
        for o in output:
            for v in o.values():
                assert len(v) == 4
        miss_num = 4 - 10 % 4
        for i in range(10 % 4):
            assert id(output[-1]['obs'][i]) != id(output[-2]['obs'][miss_num + i])

        # unroll_len larger than the data itself.
        output = get_train_sample(data, unroll_len=11, last_fn_type='last')
        assert len(output) == 1
        assert len(output[0]['obs']) == 11
        assert output[-1]['done'][-1] is True
        assert output[-1]['done'][0] is False
        assert id(output[-1]['obs'][-1]) != id(output[-1]['obs'][0])


if __name__ == '__main__':
    # Manual entry point; guarded so importing/collecting this module has no side effects.
    test = TestAdder()
    test.test_get_gae_multi_agent()