"""Unit tests for the temporal-difference (TD) loss helpers in ``ding.rl_utils``.

Most tests feed random tensors of fixed shapes into a TD error function and check:

* the returned loss is a scalar (``shape == ()``);
* per-sample TD errors, where returned, have shape ``(batch_size, )``;
* gradients flow back into the prediction tensor after ``loss.backward()``.

A few tests also check that two code paths agree on the loss value: the 1-step
helper versus the n-step helper with ``nstep=1``, and the joint multi-agent loss
versus the average of the per-agent losses.
"""
import pytest
import torch
from ding.rl_utils import (
    q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data, td_lambda_error,
    q_nstep_td_error_with_rescale, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_data, dqfd_nstep_td_data,
    dqfd_nstep_td_error, dist_nstep_td_error, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error,
    q_nstep_sql_td_error, iqn_nstep_td_data, iqn_nstep_td_error, fqf_nstep_td_data, fqf_nstep_td_error,
    qrdqn_nstep_td_data, qrdqn_nstep_td_error, bdq_nstep_td_error, m_q_1step_td_data, m_q_1step_td_error
)
from ding.rl_utils.td import shape_fn_dntd, shape_fn_qntd, shape_fn_td_lambda, shape_fn_qntd_rescale


@pytest.mark.unittest
def test_q_nstep_td():
    """Check q_nstep_td_error for nstep in 1..9, including the cum_reward and value_gamma variants."""
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        # Smoke check only: the cum_reward=True outputs are not asserted on.
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True)
        value_gamma = torch.tensor(0.9)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True, value_gamma=value_gamma)
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_bdq_nstep_td():
    """Check bdq_nstep_td_error with branched Q-values of shape (batch, branch_num, action_per_branch)."""
    batch_size = 8
    branch_num = 6
    action_per_branch = 3
    next_q = torch.randn(batch_size, branch_num, action_per_branch)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_per_branch, size=(batch_size, branch_num))
    next_action = torch.randint(0, action_per_branch, size=(batch_size, branch_num))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, branch_num, action_per_branch).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        # Smoke check only: the cum_reward=True outputs are not asserted on.
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True)
        value_gamma = torch.tensor(0.9)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = bdq_nstep_td_error(
            data, 0.95, nstep=nstep, cum_reward=True, value_gamma=value_gamma
        )
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_q_nstep_td_ngu():
    """Check q_nstep_td_error when gamma is a per-sample list of tensors (NGU-style)."""
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    gamma = [torch.tensor(0.95) for _ in range(batch_size)]
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, gamma, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_dist_1step_td():
    """Check dist_1step_td_error returns a scalar loss that back-propagates into dist."""
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(batch_size)
    data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
    loss = dist_1step_td_error(data, 0.95, v_min, v_max, n_atom)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)


@pytest.mark.unittest
def test_q_1step_compatible():
    """The n-step Q loss with nstep=1 must equal the dedicated 1-step Q loss."""
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    q = torch.randn(batch_size, action_dim).requires_grad_(True)
    reward = torch.rand(batch_size)
    # The n-step API expects reward shaped (nstep, batch); add the leading nstep=1 axis.
    nstep_data = q_nstep_td_data(q, next_q, action, next_action, reward.unsqueeze(0), done, None)
    onestep_data = q_1step_td_data(q, next_q, action, next_action, reward, done, None)
    nstep_loss, _ = q_nstep_td_error(nstep_data, 0.99, nstep=1)
    onestep_loss = q_1step_td_error(onestep_data, 0.99)
    assert pytest.approx(nstep_loss.item()) == onestep_loss.item()


@pytest.mark.unittest
def test_dist_nstep_td():
    """Check dist_nstep_td_error with and without an explicit weight and value_gamma."""
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    nstep = 5
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_n_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(nstep, batch_size)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)
    weight = torch.tensor([0.9])
    value_gamma = torch.tensor(0.9)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, weight)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
    assert loss.shape == ()
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)


@pytest.mark.unittest
def test_dist_nstep_multi_agent_td():
    """The joint multi-agent dist n-step loss must equal the mean of the per-agent losses."""
    batch_size = 4
    action_dim = 3
    agent_num = 2
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    nstep = 5
    dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs().requires_grad_(True)
    next_n_dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs()
    done = torch.randint(0, 2, (batch_size, ))
    action = torch.randint(0, action_dim, size=(batch_size, agent_num))
    next_action = torch.randint(0, action_dim, size=(batch_size, agent_num))
    reward = torch.randn(nstep, batch_size)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)
    weight = 0.9
    value_gamma = 0.9
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, weight)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
    assert loss.shape == ()
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)
    # Recompute the same loss one agent at a time and compare the average.
    agent_total_loss = 0
    for i in range(agent_num):
        data = dist_nstep_td_data(
            dist[:, i], next_n_dist[:, i], action[:, i], next_action[:, i], reward, done, weight
        )
        agent_loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
        agent_total_loss = agent_total_loss + agent_loss
    agent_average_loss = agent_total_loss / agent_num
    assert abs(agent_average_loss.item() - loss.item()) < 1e-5


@pytest.mark.unittest
def test_q_nstep_td_with_rescale():
    """Check q_nstep_td_error_with_rescale returns a scalar loss that back-propagates into q."""
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_q_nstep_td_with_rescale_ngu():
    """Check q_nstep_td_error_with_rescale when gamma is a per-sample list of tensors."""
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    gamma = [torch.tensor(0.95) for _ in range(batch_size)]
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, _ = q_nstep_td_error_with_rescale(data, gamma, nstep=nstep)
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_qrdqn_nstep_td():
    """Check qrdqn_nstep_td_error with quantile Q-values of shape (batch, action_dim, tau)."""
    batch_size = 4
    action_dim = 3
    tau = 3
    next_q = torch.randn(batch_size, action_dim, tau)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim, tau).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, tau, None)
        loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
        assert td_error_per_sample.shape == (batch_size, )


@pytest.mark.unittest
def test_dist_1step_compatible():
    """The dist n-step loss with nstep=1 must equal the dedicated dist 1-step loss."""
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(batch_size)
    onestep_data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
    # The n-step API expects reward shaped (nstep, batch); add the leading nstep=1 axis.
    nstep_data = dist_nstep_td_data(dist, next_dist, action, next_action, reward.unsqueeze(0), done, None)
    onestep_loss = dist_1step_td_error(onestep_data, 0.95, v_min, v_max, n_atom)
    nstep_loss, _ = dist_nstep_td_error(nstep_data, 0.95, v_min, v_max, n_atom, nstep=1)
    assert pytest.approx(nstep_loss.item()) == onestep_loss.item()


@pytest.mark.unittest
def test_dist_1step_multi_agent_td():
    """The joint multi-agent dist 1-step loss must equal the mean of the per-agent losses."""
    batch_size = 4
    action_dim = 3
    agent_num = 2
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs().requires_grad_(True)
    next_dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs()
    done = torch.randint(0, 2, (batch_size, ))
    action = torch.randint(0, action_dim, size=(batch_size, agent_num))
    next_action = torch.randint(0, action_dim, size=(batch_size, agent_num))
    reward = torch.randn(batch_size)
    data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
    loss = dist_1step_td_error(data, 0.95, v_min, v_max, n_atom)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)
    # Recompute the same loss one agent at a time and compare the average.
    agent_total_loss = 0
    for i in range(agent_num):
        data = dist_1step_td_data(dist[:, i], next_dist[:, i], action[:, i], next_action[:, i], reward, done, None)
        agent_loss = dist_1step_td_error(data, 0.95, v_min, v_max, n_atom)
        agent_total_loss = agent_total_loss + agent_loss
    agent_average_loss = agent_total_loss / agent_num
    assert abs(agent_average_loss.item() - loss.item()) < 1e-5


@pytest.mark.unittest
def test_td_lambda():
    """Check td_lambda_error on T+1 value estimates and T rewards."""
    T, B = 8, 4
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)
    loss = td_lambda_error(td_lambda_data(value, reward, None))
    assert loss.shape == ()
    assert value.grad is None
    loss.backward()
    assert isinstance(value.grad, torch.Tensor)


@pytest.mark.unittest
def test_v_1step_td():
    """Check v_1step_td_error with an explicit zero done tensor and with done=None."""
    batch_size = 5
    v = torch.randn(batch_size).requires_grad_(True)
    next_v = torch.randn(batch_size)
    reward = torch.rand(batch_size)
    done = torch.zeros(batch_size)
    data = v_1step_td_data(v, next_v, reward, done, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    assert loss.shape == ()
    assert v.grad is None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)
    data = v_1step_td_data(v, next_v, reward, None, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)


@pytest.mark.unittest
def test_v_1step_multi_agent_td():
    """Check v_1step_td_error with per-agent values of shape (batch, agent_num)."""
    batch_size = 5
    agent_num = 2
    v = torch.randn(batch_size, agent_num).requires_grad_(True)
    next_v = torch.randn(batch_size, agent_num)
    reward = torch.rand(batch_size)
    done = torch.zeros(batch_size)
    data = v_1step_td_data(v, next_v, reward, done, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    assert loss.shape == ()
    assert v.grad is None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)
    data = v_1step_td_data(v, next_v, reward, None, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)


@pytest.mark.unittest
def test_v_nstep_td():
    """Check v_nstep_td_error (nstep=5) with a scalar weight and with weight=None."""
    batch_size = 5
    v = torch.randn(batch_size).requires_grad_(True)
    next_v = torch.randn(batch_size)
    reward = torch.rand(5, batch_size)
    done = torch.zeros(batch_size)
    data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99)
    loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
    assert loss.shape == ()
    assert v.grad is None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)
    data = v_nstep_td_data(v, next_v, reward, done, None, 0.99)
    loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)


@pytest.mark.unittest
def test_dqfd_nstep_td():
    """Check dqfd_nstep_td_error (n-step TD plus supervised margin loss) on all-expert data."""
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    done_1 = torch.randn(batch_size)
    next_q_one_step = torch.randn(batch_size, action_dim)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action_one_step = torch.randint(0, action_dim, size=(batch_size, ))
    is_expert = torch.ones(batch_size)
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = dqfd_nstep_td_data(
            q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert
        )
        loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
            data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, margin_function=0.8, nstep=nstep
        )
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_q_nstep_sql_td():
    """Check q_nstep_sql_td_error (soft Q-learning), including the cum_reward and value_gamma variants."""
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        # Smoke check only: the cum_reward=True outputs are not asserted on.
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 0.5, nstep=nstep, cum_reward=True)
        value_gamma = torch.tensor(0.9)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(
            data, 0.95, 0.5, nstep=nstep, cum_reward=True, value_gamma=value_gamma
        )
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)


@pytest.mark.unittest
def test_iqn_nstep_td():
    """Check iqn_nstep_td_error with quantile Q-values of shape (tau, batch, action_dim)."""
    batch_size = 4
    action_dim = 3
    tau = 3
    next_q = torch.randn(tau, batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(tau, batch_size, action_dim).requires_grad_(True)
        # Quantile fractions live in [0, 1]; sample uniformly rather than from N(0, 1).
        replay_quantile = torch.rand([tau, batch_size, 1])
        reward = torch.rand(nstep, batch_size)
        data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None)
        loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
        assert td_error_per_sample.shape == (batch_size, )


@pytest.mark.unittest
def test_fqf_nstep_td():
    """Check fqf_nstep_td_error with quantile Q-values of shape (batch, tau, action_dim)."""
    batch_size = 4
    action_dim = 3
    tau = 3
    next_q = torch.randn(batch_size, tau, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, tau, action_dim).requires_grad_(True)
        # Quantile fractions live in [0, 1]; sample uniformly rather than from N(0, 1).
        quantiles_hats = torch.rand([batch_size, tau])
        reward = torch.rand(nstep, batch_size)
        data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None)
        loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
        assert td_error_per_sample.shape == (batch_size, )


@pytest.mark.unittest
def test_shape_fn_qntd():
    """shape_fn_qntd reports (nstep, batch, action_dim) whether data is passed positionally or by keyword."""
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        tmp = shape_fn_qntd([data, 0.95, 1], {})
        assert tmp[0] == reward.shape[0]
        assert tmp[1] == q.shape[0]
        assert tmp[2] == q.shape[1]
        tmp = shape_fn_qntd([], {'gamma': 0.95, 'nstep': 1, 'data': data})
        assert tmp[0] == reward.shape[0]
        assert tmp[1] == q.shape[0]
        assert tmp[2] == q.shape[1]


@pytest.mark.unittest
def test_shape_fn_dntd():
    """shape_fn_dntd reports (nstep, batch, action_dim, n_atom) for positional and keyword arguments."""
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    nstep = 5
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_n_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(nstep, batch_size)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
    tmp = shape_fn_dntd([data, 0.9, v_min, v_max, n_atom, nstep], {})
    assert tmp[0] == reward.shape[0]
    assert tmp[1] == dist.shape[0]
    assert tmp[2] == dist.shape[1]
    assert tmp[3] == n_atom
    tmp = shape_fn_dntd([], {'data': data, 'gamma': 0.9, 'v_min': v_min, 'v_max': v_max, 'n_atom': n_atom, 'nstep': 5})
    assert tmp[0] == reward.shape[0]
    assert tmp[1] == dist.shape[0]
    assert tmp[2] == dist.shape[1]
    assert tmp[3] == n_atom


@pytest.mark.unittest
def test_shape_fn_qntd_rescale():
    """shape_fn_qntd_rescale reports (nstep, batch, action_dim) for positional and keyword arguments."""
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        tmp = shape_fn_qntd_rescale([data, 0.95, 1], {})
        assert tmp[0] == reward.shape[0]
        assert tmp[1] == q.shape[0]
        assert tmp[2] == q.shape[1]
        tmp = shape_fn_qntd_rescale([], {'gamma': 0.95, 'nstep': 1, 'data': data})
        assert tmp[0] == reward.shape[0]
        assert tmp[1] == q.shape[0]
        assert tmp[2] == q.shape[1]


@pytest.mark.unittest
def test_fn_td_lambda():
    """Check what shape_fn_td_lambda returns for keyword and positional data."""
    T, B = 8, 4
    value = torch.randn(T + 1, B).requires_grad_(True)
    reward = torch.rand(T, B)
    data = td_lambda_data(value, reward, None)
    # NOTE(review): the keyword path is expected to return only the time dimension,
    # while the positional path returns the full reward shape. This asymmetry is
    # pinned as-is; confirm it is intentional in shape_fn_td_lambda.
    tmp = shape_fn_td_lambda([], {'data': data})
    assert tmp == reward.shape[0]
    tmp = shape_fn_td_lambda([data], {})
    assert tmp == reward.shape


@pytest.mark.unittest
def test_fn_m_q_1step_td_error():
    """Check m_q_1step_td_error returns a scalar loss, a clip fraction of at most 1 and a positive action gap."""
    batch_size = 128
    action_dim = 9
    q = torch.randn(batch_size, action_dim).requires_grad_(True)
    target_q_current = torch.randn(batch_size, action_dim).requires_grad_(False)
    target_q_next = torch.randn(batch_size, action_dim).requires_grad_(False)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(batch_size)
    data = m_q_1step_td_data(q, target_q_current, target_q_next, action, reward, done, None)
    loss, td_error_per_sample, action_gap, clip_frac = m_q_1step_td_error(data, 0.99, 0.03, 0.6)
    assert loss.shape == ()
    assert q.grad is None
    loss.backward()
    assert isinstance(q.grad, torch.Tensor)
    assert clip_frac.mean().item() <= 1
    assert action_gap.item() > 0
    assert td_error_per_sample.shape == (batch_size, )