import sys
import time
import torch
import unittest
import basetest

from greedrl import Solver
from greedrl.function import *

# Device reported by a default-constructed Solver; do_test() re-runs each
# function on this device and compares against the first run.
device = Solver().device


class TestFunction(basetest.TestCase):
    """Tests for task_group_split and task_group_priority.

    The small tests pin hand-written expected outputs; the ``*2`` tests feed
    random inputs through do_test(), which checks that running on ``device``
    gives the same result as running on the tensors' original device.

    Checks use unittest assertion methods rather than bare ``assert`` so they
    still execute when Python runs with ``-O``.
    """

    def test_task_group_split(self):
        """Check task_group_split on an 8x8 case with two groups per row.

        Columns 0-3 are group 0 and columns 4-7 are group 1.
        """
        group = torch.ones((8, 8), dtype=torch.int32)
        group[:, 0:4] = 0

        # value is True on all of group 0 and False on all of group 1:
        # the result is expected to contain no True entry.
        value = torch.zeros((8, 8), dtype=torch.bool)
        value[:, 0:4] = True
        result = task_group_split(group, value)
        self.assertFalse(
            torch.any(result).item(),
            "expected an all-False result when value is uniform within each group")

        # Clearing columns 0-1 leaves group 0 only partly True:
        # the result is expected to be True everywhere.
        value[:, 0:2] = False
        result = task_group_split(group, value)
        self.assertTrue(
            torch.all(result).item(),
            "expected an all-True result when value is mixed within a group")

    def test_task_group_split2(self):
        """Compare task_group_split across devices on random 1024x1000 input."""
        group = torch.randint(48, (1024, 1000), dtype=torch.int32)
        value = torch.randint(2, (1024, 1000), dtype=torch.int8) <= 0
        self.do_test(task_group_split, group, value)

    def test_task_group_priority(self):
        """Check task_group_priority against a hand-written expected mask.

        Columns 0-3 are group 0 and columns 4-7 are group 1; within each group
        the priorities are 0, 1, 2, 3. value is True only on columns 4 and 5.
        """
        group = torch.ones((8, 8), dtype=torch.int32)
        group[:, 0:4] = 0

        priority = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32)
        # expand() returns a view; clone() gives the function its own 8x8 storage.
        priority = priority[None, :].expand(8, -1).clone()

        value = torch.zeros((8, 8), dtype=torch.bool)
        value[:, 4:6] = True

        result = task_group_priority(group, priority, value)

        # The same expected row is repeated for all 8 rows.
        expected = torch.tensor([False, True, True, True, True, True, False, True])
        expected = expected[None, :].expand(8, -1)
        self.assertTrue(
            torch.all(result == expected).item(),
            "task_group_priority result differs from the expected mask")

    def test_task_group_priority2(self):
        """Compare task_group_priority across devices on random 1024x1000 input."""
        group = torch.randint(48, (1024, 1000), dtype=torch.int32)
        value = torch.randint(2, (1024, 1000), dtype=torch.int8) < 1
        priority = torch.randint(2, (1024, 1000), dtype=torch.int32)
        self.do_test(task_group_priority, group, priority, value)

    def do_test(self, function, *args):
        """Run ``function`` on ``args`` as given and on ``device``, then compare.

        Both runs are timed and the timings printed. The first run uses the
        tensors wherever the caller created them; the second uses copies moved
        to the module-level ``device``.

        Args:
            function: callable taking the tensors in ``args`` and returning a
                tensor.
            *args: input tensors for ``function``.

        Raises:
            AssertionError: if the two results differ (by 1e-6 or more in
                absolute value for floating-point results, or at all otherwise).
        """
        print("\ntest {} ...".format(function.__name__))
        start = time.time()
        result1 = function(*args)
        print("time: {:.6f}s, device: {}".format(time.time() - start, args[0].device))

        args = [arg.to(device) for arg in args]
        result1 = result1.to(device)

        # Untimed warm-up call so one-off start-up cost is not measured below.
        function(*args)
        self.sync_device(device)

        start = time.time()
        result2 = function(*args)
        # Wait for the device to finish before reading the clock.
        self.sync_device(device)
        print("time: {:.6f}s, device: {} ".format(time.time() - start, args[0].device))

        if result1.is_floating_point():
            # Floating-point results may differ slightly between devices.
            matches = torch.abs(result1 - result2) < 1e-6
        else:
            matches = result1 == result2
        self.assertTrue(
            torch.all(matches).item(),
            "{} results differ between devices".format(function.__name__))

    def sync_device(self, device):
        """Block until pending work on ``device`` finishes, if it is a CUDA device.

        Does nothing for any other device type.
        """
        if device.type == 'cuda':
            torch.cuda.synchronize(device)


if __name__ == '__main__':
    unittest.main()