File size: 1,115 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import pytest
from itertools import product
import numpy as np
import torch
from ding.rl_utils import coma_data, coma_error

# Strictly positive random weight tensor in [1, 2), shaped (T=128, B=4, A=8)
# to match the (time, batch, agent) layout used by test_coma below.
random_weight = 1 + torch.rand((128, 4, 8))
# Parametrize cases: no weighting (None) vs. the random positive weights.
weight_args = [None, random_weight]


@pytest.mark.unittest
@pytest.mark.parametrize('weight', weight_args)
def test_coma(weight):
    """Smoke-test ``coma_error``: scalar losses that backpropagate.

    Builds random COMA inputs with layout (T=time, B=batch, A=agents,
    N=actions), runs the error computation with and without an explicit
    weight tensor, and checks that every loss term is a scalar and that
    gradients reach ``logit`` and ``q_value`` only after ``backward()``.
    """
    T, B, A, N = 128, 4, 8, 32
    logit = torch.randn(T, B, A, N).requires_grad_(True)
    action = torch.randint(0, N, size=(T, B, A))
    reward = torch.rand(T, B)
    q_value = torch.randn(T, B, A, N).requires_grad_(True)
    target_q_value = torch.randn(T, B, A, N).requires_grad_(True)
    data = coma_data(logit, action, q_value, target_q_value, reward, weight)
    # 0.99 / 0.95 — presumably the discount factor (gamma) and trace-decay
    # (lambda); confirm against the coma_error signature.
    loss = coma_error(data, 0.99, 0.95)
    # Every loss component must reduce to a scalar tensor.
    assert all(l.shape == tuple() for l in loss)
    # No gradients should exist before backward().
    assert logit.grad is None
    assert q_value.grad is None
    total_loss = sum(loss)
    total_loss.backward()
    assert isinstance(logit.grad, torch.Tensor)
    assert isinstance(q_value.grad, torch.Tensor)