import sys
import time
import torch
import unittest
import basetest

from greedrl import Solver
from greedrl.function import *

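# device reported by the greedrl Solver (assumed to be CUDA when available, otherwise CPU)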
device = Solver().device


class TestFunction(basetest.TestCase):

    def test_task_group_split(self):
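        # 8 rows of 8 tasks; columns 0-3 carry group id 0, columns 4-7 carry group id 1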
        group = torch.ones((8, 8), dtype=torch.int32)
        group[:, 0:4] = 0
        value = torch.zeros((8, 8), dtype=torch.bool)
        value[:, 0:4] = True
        result = task_group_split(group, value)
        assert not torch.any(result)

        value[:, 0:2] = False
        result = task_group_split(group, value)
        assert torch.all(result)

    def test_task_group_split2(self):
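        # randomized inputs; do_test checks that the CPU and device results agree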
        group = torch.randint(48, (1024, 1000), dtype=torch.int32)
        value = torch.randint(2, (1024, 1000), dtype=torch.int8) <= 0
        self.do_test(task_group_split, group, value)

    def test_task_group_priority(self):
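        # same two-group layout; priorities 0-3 repeat within each group, value is True only for columns 4-5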
        group = torch.ones((8, 8), dtype=torch.int32)
        group[:, 0:4] = 0
        priority = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32)
        priority = priority[None, :].expand(8, -1).clone()
        value = torch.zeros((8, 8), dtype=torch.bool)
        value[:, 4:6] = True

        result = task_group_priority(group, priority, value)
        expected = torch.tensor([False, True, True, True, True, True, False, True])
        expected = expected[None, :].expand(8, -1)
        assert torch.all(result == expected)

    def test_task_group_priority2(self):
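        # larger randomized inputs, compared across devices via do_test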
        group = torch.randint(48, (1024, 1000), dtype=torch.int32)
        value = torch.randint(2, (1024, 1000), dtype=torch.int8) < 1
        priority = torch.randint(2, (1024, 1000), dtype=torch.int32)
        self.do_test(task_group_priority, group, priority, value)

    def do_test(self, function, *args):
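        # run `function` on the original (CPU) tensors and on the solver device,
        # time both runs, and assert the results match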
        print("\ntest {} ...".format(function.__name__))
        start = time.time()
        result1 = function(*args)
        print("time: {:.6f}s, device: {}".format(time.time() - start, args[0].device))

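        # move the inputs and the reference result to the solver device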
        args = [arg.to(device) for arg in args]
        result1 = result1.to(device)

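        # warm-up run on the target device, then synchronize before timing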
        function(*args)
        self.sync_device(device)

        start = time.time()
        result2 = function(*args)
        self.sync_device(device)
        print("time: {:.6f}s, device: {} ".format(time.time() - start, args[0].device))

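        # results must match; allow a small absolute tolerance for floating point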
        if result1.is_floating_point():
            assert torch.all(torch.abs(result1 - result2) < 1e-6)
        else:
            assert torch.all(result1 == result2)

    def sync_device(self, device):
        if device.type == 'cuda':
            torch.cuda.synchronize(device)


if __name__ == '__main__':
    unittest.main()