|
|
|
import unittest |
|
import torch |
|
|
|
from detectron2.structures import Instances |
|
|
|
|
|
class TestInstancesIndexing(unittest.TestCase):
    """Checks integer indexing of ``Instances`` against plain tensor indexing."""

    def test_int_indexing(self):
        """Index an ``Instances`` with every valid int and with out-of-range ints.

        For each ``i`` in ``[-len(instances), len(instances))`` the fields of
        ``instances[i]`` must match row ``i`` of the tensors they were set
        from. The first index past either end must raise ``IndexError``.
        """
        attr1 = torch.tensor(
            [[0.0, 0.0, 1.0], [0.0, 0.0, 0.5], [0.0, 0.0, 1.0], [0.0, 0.5, 0.5]]
        )
        attr2 = torch.tensor([0.1, 0.2, 0.3, 0.4])

        instances = Instances((100, 100))
        instances.attr1 = attr1
        instances.attr2 = attr2

        # Sweep negative and non-negative indices in a single pass.
        for i in range(-len(instances), len(instances)):
            # subTest tags any failure with the index that triggered it.
            with self.subTest(i=i):
                inst = instances[i]
                # Element-wise `==` (which broadcasts) is kept on purpose instead
                # of `torch.equal`: the indexed fields may not have exactly the
                # same shape as `attr[i]`.
                # NOTE(review): confirm against Instances.__getitem__.
                self.assertTrue((inst.attr1 == attr1[i]).all())
                self.assertTrue((inst.attr2 == attr2[i]).all())

        # One step beyond each end of the valid range must be rejected.
        with self.assertRaises(IndexError):
            instances[len(instances)]
        with self.assertRaises(IndexError):
            instances[-len(instances) - 1]
|
|
|
|
|
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
|