# DI-engine: ding/model/wrapper/test_model_wrappers.py
import copy
from copy import deepcopy
from collections import OrderedDict
import pytest
import torch
import torch.nn as nn
from ditk import logging
from ding.torch_utils import get_lstm
from ding.torch_utils.network.gtrxl import GTrXL
from ding.model import model_wrap, register_wrapper, IModelWrapper
from ding.model.wrapper.model_wrappers import BaseModelWrapper
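# Minimal MLP (3 -> 4 -> 6) with BatchNorm, used as the plain model in the target-network wrapper test.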
class TempMLP(torch.nn.Module):
def __init__(self):
super(TempMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 6)
self.act = nn.ReLU()
def forward(self, x):
x = self.fc1(x)
x = self.bn1(x)
x = self.act(x)
x = self.fc2(x)
x = self.act(x)
return x
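# Discrete-policy actor: returns a softmax 'logit', a noisy 'action', the passthrough 'tmp' kwarg,
# and 'action_mask' when a 'mask' is provided in the inputs.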
class ActorMLP(torch.nn.Module):
def __init__(self):
super(ActorMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 6)
self.act = nn.ReLU()
self.out = nn.Softmax(dim=-1)
def forward(self, inputs, tmp=0):
x = self.fc1(inputs['obs'])
x = self.bn1(x)
x = self.act(x)
x = self.fc2(x)
x = self.act(x)
x = self.out(x)
ret = {'logit': x, 'tmp': tmp, 'action': x + torch.rand_like(x)}
if 'mask' in inputs:
ret['action_mask'] = inputs['mask']
return ret
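# Hybrid actor: a discrete softmax head ('logit') plus a continuous head ('action_args') sharing one trunk.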
class HybridActorMLP(torch.nn.Module):
def __init__(self):
super(HybridActorMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 6)
self.act = nn.ReLU()
self.out = nn.Softmax(dim=-1)
self.fc2_cont = nn.Linear(4, 6)
self.act_cont = nn.ReLU()
def forward(self, inputs, tmp=0):
x = self.fc1(inputs['obs'])
x = self.bn1(x)
x_ = self.act(x)
x = self.fc2(x_)
x = self.act(x)
x_disc = self.out(x)
x = self.fc2_cont(x_)
x_cont = self.act_cont(x)
ret = {'logit': x_disc, 'action_args': x_cont, 'tmp': tmp}
if 'mask' in inputs:
ret['action_mask'] = inputs['mask']
return ret
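# Hybrid actor with a reparameterized continuous head: 'logit' holds 'action_type' and
# 'action_args' = {'mu', 'sigma'}.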
class HybridReparamActorMLP(torch.nn.Module):
def __init__(self):
super(HybridReparamActorMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 6)
self.act = nn.ReLU()
self.out = nn.Softmax(dim=-1)
self.fc2_cont_mu = nn.Linear(4, 6)
self.act_cont_mu = nn.ReLU()
self.fc2_cont_sigma = nn.Linear(4, 6)
self.act_cont_sigma = nn.ReLU()
def forward(self, inputs, tmp=0):
x = self.fc1(inputs['obs'])
x = self.bn1(x)
x_ = self.act(x)
x = self.fc2(x_)
x = self.act(x)
x_disc = self.out(x)
x = self.fc2_cont_mu(x_)
x_cont_mu = self.act_cont_mu(x)
x = self.fc2_cont_sigma(x_)
x_cont_sigma = self.act_cont_sigma(x) + 1e-8
ret = {'logit': {'action_type': x_disc, 'action_args': {'mu': x_cont_mu, 'sigma': x_cont_sigma}}, 'tmp': tmp}
if 'mask' in inputs:
ret['action_mask'] = inputs['mask']
return ret
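# Continuous reparameterized actor: 'logit' is a dict with 'mu' and 'sigma'.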
class ReparamActorMLP(torch.nn.Module):
def __init__(self):
super(ReparamActorMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.fc2 = nn.Linear(4, 6)
self.act = nn.ReLU()
self.fc2_cont_mu = nn.Linear(4, 6)
self.fc2_cont_sigma = nn.Linear(4, 6)
def forward(self, inputs, tmp=0):
x = self.fc1(inputs['obs'])
x = self.bn1(x)
x_ = self.act(x)
x = self.fc2_cont_mu(x_)
x_cont_mu = self.act(x)
x = self.fc2_cont_sigma(x_)
x_cont_sigma = self.act(x) + 1e-8
ret = {'logit': {'mu': x_cont_mu, 'sigma': x_cont_sigma}, 'tmp': tmp}
if 'mask' in inputs:
ret['action_mask'] = inputs['mask']
return ret
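# Deterministic continuous actor: 'logit' contains only 'mu'.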
class DeterministicActorMLP(torch.nn.Module):
def __init__(self):
super(DeterministicActorMLP, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.bn1 = nn.BatchNorm1d(4)
self.act = nn.ReLU()
self.fc2_cont_mu = nn.Linear(4, 6)
self.act_cont_mu = nn.ReLU()
def forward(self, inputs):
x = self.fc1(inputs['obs'])
x = self.bn1(x)
x_ = self.act(x)
x = self.fc2_cont_mu(x_)
x_cont_mu = self.act_cont_mu(x)
ret = {
'logit': {
'mu': x_cont_mu,
}
}
if 'mask' in inputs:
ret['action_mask'] = inputs['mask']
return ret
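# Two-layer LSTM (input 36, hidden 32); forward reads data['f'] and data['prev_state'],
# which the hidden_state wrapper supplies during the test below.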
class TempLSTM(torch.nn.Module):
def __init__(self):
super(TempLSTM, self).__init__()
self.model = get_lstm(lstm_type='pytorch', input_size=36, hidden_size=32, num_layers=2, norm_type=None)
def forward(self, data):
output, next_state = self.model(data['f'], data['prev_state'], list_next_state=True)
return {'output': output, 'next_state': next_state}
@pytest.fixture(scope='function')
def setup_model():
return torch.nn.Linear(3, 6)
@pytest.mark.unittest
class TestModelWrappers:
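    # hidden_state wrapper keeps one LSTM state per environment (state_num) and selects/saves it by data_id.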
def test_hidden_state_wrapper(self):
model = TempLSTM()
state_num = 4
model = model_wrap(model, wrapper_name='hidden_state', state_num=state_num, save_prev_state=True)
model.reset()
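        # full batch: (seq_len=2, batch=state_num=4, input_size=36)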
data = {'f': torch.randn(2, 4, 36)}
output = model.forward(data)
assert output['output'].shape == (2, state_num, 32)
assert len(output['prev_state']) == 4
assert output['prev_state'][0]['h'].shape == (2, 1, 32)
for item in model._state.values():
assert isinstance(item, dict) and len(item) == 2
assert all(t.shape == (2, 1, 32) for t in item.values())
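        # partial batch: only env ids 0, 1 and 3 step, selected via data_id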
data = {'f': torch.randn(2, 3, 36)}
data_id = [0, 1, 3]
output = model.forward(data, data_id=data_id)
assert output['output'].shape == (2, 3, 32)
assert all([len(s) == 2 for s in output['prev_state']])
for item in model._state.values():
assert isinstance(item, dict) and len(item) == 2
assert all(t.shape == (2, 1, 32) for t in item.values())
data = {'f': torch.randn(2, 2, 36)}
data_id = [0, 1]
output = model.forward(data, data_id=data_id)
assert output['output'].shape == (2, 2, 32)
assert all([isinstance(s, dict) and len(s) == 2 for s in model._state.values()])
model.reset()
assert all([isinstance(s, type(None)) for s in model._state.values()])
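    # target wrapper: 'assign' copies the source parameters every `freq` updates,
    # 'momentum' blends them as new = (1 - theta) * old + theta * source.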
def test_target_network_wrapper(self):
model = TempMLP()
target_model = deepcopy(model)
target_model2 = deepcopy(model)
target_model = model_wrap(target_model, wrapper_name='target', update_type='assign', update_kwargs={'freq': 2})
model = model_wrap(model, wrapper_name='base')
register_wrapper('abstract', IModelWrapper)
assert all([hasattr(target_model, n) for n in ['reset', 'forward', 'update']])
assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
model.fc1.weight.data = torch.randn_like(model.fc1.weight)
assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
target_model.update(model.state_dict(), direct=True)
assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
model.reset()
target_model.reset()
inputs = torch.randn(2, 3)
model.train()
target_model.train()
output = model.forward(inputs)
with torch.no_grad():
output_target = target_model.forward(inputs)
assert output.eq(output_target).sum() == 2 * 6
model.fc1.weight.data = torch.randn_like(model.fc1.weight)
assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
target_model.update(model.state_dict())
assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
target_model.update(model.state_dict())
assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
# test real reset update_count
assert target_model._update_count != 0
target_model.reset()
assert target_model._update_count != 0
target_model.reset(target_update_count=0)
assert target_model._update_count == 0
target_model2 = model_wrap(
target_model2, wrapper_name='target', update_type='momentum', update_kwargs={'theta': 0.01}
)
target_model2.update(model.state_dict(), direct=True)
assert model.fc1.weight.eq(target_model2.fc1.weight).sum() == 12
model.fc1.weight.data = torch.randn_like(model.fc1.weight)
old_state_dict = target_model2.state_dict()
target_model2.update(model.state_dict())
assert target_model2.fc1.weight.data.eq(
old_state_dict['fc1.weight'] * (1 - 0.01) + model.fc1.weight.data * 0.01
).all()
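    # eps_greedy_sample wrapper: mixes random and greedy actions according to eps;
    # extra kwargs such as tmp are forwarded untouched to the wrapped model.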
def test_eps_greedy_wrapper(self):
model = ActorMLP()
model = model_wrap(model, wrapper_name='eps_greedy_sample')
model.eval()
eps_threshold = 0.5
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data, eps=eps_threshold)
assert output['tmp'] == 0
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, tmp=1)
assert isinstance(output, dict)
assert output['tmp'] == 1
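    # multinomial_sample wrapper: draws one discrete action per sample from the softmax logit,
    # applying the action mask when one is provided.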
def test_multinomial_sample_wrapper(self):
model = model_wrap(ActorMLP(), wrapper_name='multinomial_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
assert output['action'].shape == (4, )
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
output = model.forward(data)
assert output['action'].shape == (4, )
def test_eps_greedy_multinomial_wrapper(self):
model = ActorMLP()
model = model_wrap(model, wrapper_name='eps_greedy_multinomial_sample')
model.eval()
eps_threshold = 0.5
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, alpha=0.2)
assert output['tmp'] == 0
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, tmp=1, alpha=0.2)
assert isinstance(output, dict)
assert output['tmp'] == 1
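    # hybrid wrappers sample the discrete action_type and pass the continuous action_args through.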
def test_hybrid_eps_greedy_wrapper(self):
model = HybridActorMLP()
model = model_wrap(model, wrapper_name='hybrid_eps_greedy_sample')
model.eval()
eps_threshold = 0.5
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data, eps=eps_threshold)
# logit = output['logit']
# assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
assert isinstance(output['action']['action_args'],
torch.Tensor) and output['action']['action_args'].shape == (4, 6)
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, tmp=1)
assert isinstance(output, dict)
def test_hybrid_eps_greedy_multinomial_wrapper(self):
model = HybridActorMLP()
model = model_wrap(model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
model.eval()
eps_threshold = 0.5
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data, eps=eps_threshold)
assert isinstance(output['logit'], torch.Tensor) and output['logit'].shape == (4, 6)
assert isinstance(output['action']['action_type'],
torch.Tensor) and output['action']['action_type'].shape == (4, )
assert isinstance(output['action']['action_args'],
torch.Tensor) and output['action']['action_args'].shape == (4, 6)
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, tmp=1)
assert isinstance(output, dict)
def test_hybrid_reparam_multinomial_wrapper(self):
model = HybridReparamActorMLP()
model = model_wrap(model, wrapper_name='hybrid_reparam_multinomial_sample')
model.eval()
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data)
assert isinstance(output['logit'], dict) and output['logit']['action_type'].shape == (4, 6)
assert isinstance(output['logit']['action_args'], dict) and output['logit']['action_args']['mu'].shape == (
4, 6
) and output['logit']['action_args']['sigma'].shape == (4, 6)
assert isinstance(output['action']['action_type'],
torch.Tensor) and output['action']['action_type'].shape == (4, )
assert isinstance(output['action']['action_args'],
torch.Tensor) and output['action']['action_args'].shape == (4, 6)
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, tmp=1)
assert isinstance(output, dict)
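    # argmax_sample wrapper: action = argmax(logit); masked entries are suppressed by
    # subtracting 1e8 * (1 - mask), mirroring the reference computation below.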
def test_argmax_sample_wrapper(self):
model = model_wrap(ActorMLP(), wrapper_name='argmax_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
logit = output['logit']
assert output['action'].eq(logit.argmax(dim=-1)).all()
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
output = model.forward(data)
logit = output['logit'].sub(1e8 * (1 - data['mask']))
assert output['action'].eq(logit.argmax(dim=-1)).all()
def test_hybrid_argmax_sample_wrapper(self):
model = model_wrap(HybridActorMLP(), wrapper_name='hybrid_argmax_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
logit = output['logit']
assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
assert isinstance(output['action']['action_args'],
torch.Tensor) and output['action']['action_args'].shape == (4, 6)
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
output = model.forward(data)
logit = output['logit'].sub(1e8 * (1 - data['mask']))
assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
assert output['action']['action_args'].shape == (4, 6)
def test_hybrid_deterministic_argmax_sample_wrapper(self):
model = model_wrap(HybridReparamActorMLP(), wrapper_name='hybrid_deterministic_argmax_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
assert output['action']['action_type'].eq(output['logit']['action_type'].argmax(dim=-1)).all()
assert isinstance(output['action']['action_args'],
torch.Tensor) and output['action']['action_args'].shape == (4, 6)
        assert output['action']['action_args'].eq(output['logit']['action_args']['mu']).all()
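    # deterministic_sample wrapper: the continuous action is exactly the mean 'mu' of the logit dict.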
def test_deterministic_sample_wrapper(self):
model = model_wrap(DeterministicActorMLP(), wrapper_name='deterministic_sample')
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
assert output['action'].eq(output['logit']['mu']).all()
assert isinstance(output['action'], torch.Tensor) and output['action'].shape == (4, 6)
def test_reparam_wrapper(self):
model = ReparamActorMLP()
model = model_wrap(model, wrapper_name='reparam_sample')
model.eval()
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data)
assert isinstance(output['logit'],
dict) and output['logit']['mu'].shape == (4, 6) and output['logit']['sigma'].shape == (4, 6)
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, tmp=1)
assert isinstance(output, dict)
def test_eps_greedy_wrapper_with_list_eps(self):
model = ActorMLP()
model = model_wrap(model, wrapper_name='eps_greedy_sample')
model.eval()
eps_threshold = {i: 0.5 for i in range(4)} # for NGU
data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
with torch.no_grad():
output = model.forward(data, eps=eps_threshold)
assert output['tmp'] == 0
for i in range(10):
if i == 5:
data.pop('mask')
with torch.no_grad():
output = model.forward(data, eps=eps_threshold, tmp=1)
assert isinstance(output, dict)
assert output['tmp'] == 1
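    # action_noise wrapper: adds noise of noise_type bounded by noise_range, then clamps the
    # result into action_range, which the clamp assertion below verifies.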
def test_action_noise_wrapper(self):
model = model_wrap(
ActorMLP(),
wrapper_name='action_noise',
noise_type='gauss',
noise_range={
'min': -0.1,
'max': 0.1
},
action_range={
'min': -0.05,
'max': 0.05
}
)
data = {'obs': torch.randn(4, 3)}
output = model.forward(data)
action = output['action']
assert action.shape == (4, 6)
assert action.eq(action.clamp(-0.05, 0.05)).all()
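    # transformer_input wrapper: buffers single-step observations into a (seq_len, bs, obs_shape)
    # sequence for GTrXL; unfilled positions stay zero until seq_len steps have been observed.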
def test_transformer_input_wrapper(self):
seq_len, bs, obs_shape = 8, 8, 32
emb_dim = 64
model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim)
model = model_wrap(model, wrapper_name='transformer_input', seq_len=seq_len)
obs = []
for i in range(seq_len + 1):
obs.append(torch.randn((bs, obs_shape)))
out = model.forward(obs[0], only_last_logit=False)
assert out['logit'].shape == (seq_len, bs, emb_dim)
assert out['input_seq'].shape == (seq_len, bs, obs_shape)
assert sum(out['input_seq'][1:].flatten()) == 0
for i in range(1, seq_len - 1):
out = model.forward(obs[i])
assert out['logit'].shape == (bs, emb_dim)
assert out['input_seq'].shape == (seq_len, bs, obs_shape)
assert sum(out['input_seq'][seq_len - 1:].flatten()) == 0
assert sum(out['input_seq'][:seq_len - 1].flatten()) != 0
out = model.forward(obs[seq_len - 1])
prev_memory = torch.clone(out['input_seq'])
out = model.forward(obs[seq_len])
assert torch.all(torch.eq(out['input_seq'][seq_len - 2], prev_memory[seq_len - 1]))
# test update of single batches in the memory
model.reset(data_id=[0, 5]) # reset memory batch in position 0 and 5
assert sum(model.obs_memory[:, 0].flatten()) == 0 and sum(model.obs_memory[:, 5].flatten()) == 0
assert sum(model.obs_memory[:, 1].flatten()) != 0
assert model.memory_idx[0] == 0 and model.memory_idx[5] == 0 and model.memory_idx[1] == seq_len
# test reset
model.reset()
assert model.obs_memory is None
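    # transformer_segment wrapper: forwards a full (seq_len, bs, obs_shape) segment in one call.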
def test_transformer_segment_wrapper(self):
seq_len, bs, obs_shape = 12, 8, 32
layer_num, memory_len, emb_dim = 3, 4, 4
model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
model = model_wrap(model, wrapper_name='transformer_segment', seq_len=seq_len)
inputs1 = torch.randn((seq_len, bs, obs_shape))
out = model.forward(inputs1)
        # smoke-test the wrapper's info interface with a known and an unknown key
        model.info('info')
        model.info('x')
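    # transformer_memory wrapper: keeps the GTrXL layer memory across forward calls;
    # reset(data_id=...) clears selected batch slots, reset() clears the whole memory.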
def test_transformer_memory_wrapper(self):
seq_len, bs, obs_shape = 12, 8, 32
layer_num, memory_len, emb_dim = 3, 4, 4
model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
model1 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
model2 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
model1.show_memory_occupancy()
inputs1 = torch.randn((seq_len, bs, obs_shape))
out = model1.forward(inputs1)
new_memory1 = model1.memory
inputs2 = torch.randn((seq_len, bs, obs_shape))
out = model2.forward(inputs2)
new_memory2 = model2.memory
assert not torch.all(torch.eq(new_memory1, new_memory2))
model1.reset(data_id=[0, 5])
assert sum(model1.memory[:, :, 0].flatten()) == 0 and sum(model1.memory[:, :, 5].flatten()) == 0
assert sum(model1.memory[:, :, 1].flatten()) != 0
model1.reset()
assert sum(model1.memory.flatten()) == 0
seq_len, bs, obs_shape = 8, 8, 32
layer_num, memory_len, emb_dim = 3, 20, 4
model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
model = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
inputs1 = torch.randn((seq_len, bs, obs_shape))
out = model.forward(inputs1)
new_memory1 = model.memory
inputs2 = torch.randn((seq_len, bs, obs_shape))
out = model.forward(inputs2)
new_memory2 = model.memory
print(new_memory1.shape, inputs1.shape)
assert sum(new_memory1[:, -8:].flatten()) != 0
assert sum(new_memory1[:, :-8].flatten()) == 0
assert sum(new_memory2[:, -16:].flatten()) != 0
assert sum(new_memory2[:, :-16].flatten()) == 0
assert torch.all(torch.eq(new_memory1[:, -8:], new_memory2[:, -16:-8]))
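    # combination wrappers: pick shot_number actions per sample, greedily or by multinomial sampling.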
def test_combination_argmax_sample_wrapper(self):
model = model_wrap(ActorMLP(), wrapper_name='combination_argmax_sample')
data = {'obs': torch.randn(4, 3)}
shot_number = 2
output = model.forward(shot_number=shot_number, inputs=data)
assert output['action'].shape == (4, shot_number)
assert (output['action'] >= 0).all() and (output['action'] < 64).all()
def test_combination_multinomial_sample_wrapper(self):
model = model_wrap(ActorMLP(), wrapper_name='combination_multinomial_sample')
data = {'obs': torch.randn(4, 3)}
shot_number = 2
output = model.forward(shot_number=shot_number, inputs=data)
assert output['action'].shape == (4, shot_number)
assert (output['action'] >= 0).all() and (output['action'] < 64).all()