# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import pickle import random import unittest from typing import List, Tuple, Union import torch import torch.nn.functional as F from pytorch3d.io import save_obj from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap from pytorch3d.transforms.rotation_conversions import random_rotation from .common_testing import get_random_cuda_device, get_tests_dir, TestCaseMixin OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3] DATA_DIR = get_tests_dir() / "data" DEBUG = False DOT_EPS = 1e-3 AREA_EPS = 1e-4 UNIT_BOX = [ [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1], ] class TestIoU3D(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: super().setUp() torch.manual_seed(1) @staticmethod def create_box(xyz, whl): x, y, z = xyz w, h, le = whl verts = torch.tensor( [ [x - w / 2.0, y - h / 2.0, z - le / 2.0], [x + w / 2.0, y - h / 2.0, z - le / 2.0], [x + w / 2.0, y + h / 2.0, z - le / 2.0], [x - w / 2.0, y + h / 2.0, z - le / 2.0], [x - w / 2.0, y - h / 2.0, z + le / 2.0], [x + w / 2.0, y - h / 2.0, z + le / 2.0], [x + w / 2.0, y + h / 2.0, z + le / 2.0], [x - w / 2.0, y + h / 2.0, z + le / 2.0], ], device=xyz.device, dtype=torch.float32, ) return verts @staticmethod def _box3d_overlap_naive_batched(boxes1, boxes2): """ Wrapper around box3d_overlap_naive to support batched input """ N = boxes1.shape[0] M = boxes2.shape[0] vols = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device) ious = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device) for n in range(N): for m in range(M): vol, iou = box3d_overlap_naive(boxes1[n], boxes2[m]) vols[n, m] = vol ious[n, m] = iou return vols, ious @staticmethod def _box3d_overlap_sampling_batched(boxes1, boxes2, num_samples: int): """ Wrapper around box3d_overlap_sampling to support batched input """ N = boxes1.shape[0] M = boxes2.shape[0] ious = torch.zeros((N, M), dtype=torch.float32, device=boxes1.device) for n in range(N): for m in range(M): iou = box3d_overlap_sampling(boxes1[n], boxes2[m]) ious[n, m] = iou return ious def _test_iou(self, overlap_fn, device): box1 = torch.tensor( UNIT_BOX, dtype=torch.float32, device=device, ) # 1st test: same box, iou = 1.0 vol, iou = overlap_fn(box1[None], box1[None]) self.assertClose(vol, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype)) self.assertClose(iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype)) # 2nd test dd = random.random() box2 = box1 + torch.tensor([[0.0, dd, 0.0]], device=device) vol, iou = overlap_fn(box1[None], box2[None]) self.assertClose( vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype) ) # symmetry vol, iou = overlap_fn(box2[None], box1[None]) self.assertClose( vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype) ) # 3rd test dd = random.random() box2 = box1 + torch.tensor([[dd, 0.0, 0.0]], device=device) vol, _ = overlap_fn(box1[None], box2[None]) self.assertClose( vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype) ) # symmetry vol, _ = overlap_fn(box2[None], box1[None]) self.assertClose( vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype) ) # 4th test ddx, ddy, ddz = random.random(), random.random(), random.random() box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device) vol, _ = overlap_fn(box1[None], box2[None]) self.assertClose( vol, torch.tensor( [[(1 - ddx) * (1 - ddy) * (1 - ddz)]], device=vol.device, dtype=vol.dtype, ), ) # symmetry vol, _ = overlap_fn(box2[None], box1[None]) self.assertClose( vol, torch.tensor( [[(1 - ddx) * (1 - ddy) * (1 - ddz)]], device=vol.device, dtype=vol.dtype, ), ) # Also check IoU is 1 when computing overlap with the same shifted box vol, iou = overlap_fn(box2[None], box2[None]) self.assertClose(iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype)) # 5th test ddx, ddy, ddz = random.random(), random.random(), random.random() box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device) RR = random_rotation(dtype=torch.float32, device=device) box1r = box1 @ RR.transpose(0, 1) box2r = box2 @ RR.transpose(0, 1) vol, _ = overlap_fn(box1r[None], box2r[None]) self.assertClose( vol, torch.tensor( [[(1 - ddx) * (1 - ddy) * (1 - ddz)]], device=vol.device, dtype=vol.dtype, ), ) # symmetry vol, _ = overlap_fn(box2r[None], box1r[None]) self.assertClose( vol, torch.tensor( [[(1 - ddx) * (1 - ddy) * (1 - ddz)]], device=vol.device, dtype=vol.dtype, ), ) # 6th test ddx, ddy, ddz = random.random(), random.random(), random.random() box2 = box1 + torch.tensor([[ddx, ddy, ddz]], device=device) RR = random_rotation(dtype=torch.float32, device=device) TT = torch.rand((1, 3), dtype=torch.float32, device=device) box1r = box1 @ RR.transpose(0, 1) + TT box2r = box2 @ RR.transpose(0, 1) + TT vol, _ = overlap_fn(box1r[None], box2r[None]) self.assertClose( vol, torch.tensor( [[(1 - ddx) * (1 - ddy) * (1 - ddz)]], device=vol.device, dtype=vol.dtype, ), atol=1e-7, ) # symmetry vol, _ = overlap_fn(box2r[None], box1r[None]) self.assertClose( vol, torch.tensor( [[(1 - ddx) * (1 - ddy) * (1 - ddz)]], device=vol.device, dtype=vol.dtype, ), atol=1e-7, ) # 7th test: hand coded example and test with meshlab output # Meshlab procedure to compute volumes of shapes # 1. Load a shape, then Filters # -> Remeshing, Simplification, Reconstruction -> Convex Hull # 2. Select the convex hull shape (This is important!) # 3. Then Filters -> Quality Measure and Computation -> Compute Geometric Measures # 3. Check for "Mesh Volume" in the stdout box1r = torch.tensor( [ [3.1673, -2.2574, 0.4817], [4.6470, 0.2223, 2.4197], [5.2200, 1.1844, 0.7510], [3.7403, -1.2953, -1.1869], [-4.9316, 2.5724, 0.4856], [-3.4519, 5.0521, 2.4235], [-2.8789, 6.0142, 0.7549], [-4.3586, 3.5345, -1.1831], ], device=device, ) box2r = torch.tensor( [ [0.5623, 4.0647, 3.4334], [3.3584, 4.3191, 1.1791], [3.0724, -5.9235, -0.3315], [0.2763, -6.1779, 1.9229], [-2.0773, 4.6121, 0.2213], [0.7188, 4.8665, -2.0331], [0.4328, -5.3761, -3.5436], [-2.3633, -5.6305, -1.2893], ], device=device, ) # from Meshlab: vol_inters = 33.558529 vol_box1 = 65.899010 vol_box2 = 156.386719 iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters) vol, iou = overlap_fn(box1r[None], box2r[None]) self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1) self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1) # symmetry vol, iou = overlap_fn(box2r[None], box1r[None]) self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1) self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1) # 8th test: compare with sampling # create box1 ctrs = torch.rand((2, 3), device=device) whl = torch.rand((2, 3), device=device) * 10.0 + 1.0 # box8a & box8b box8a = self.create_box(ctrs[0], whl[0]) box8b = self.create_box(ctrs[1], whl[1]) RR1 = random_rotation(dtype=torch.float32, device=device) TT1 = torch.rand((1, 3), dtype=torch.float32, device=device) RR2 = random_rotation(dtype=torch.float32, device=device) TT2 = torch.rand((1, 3), dtype=torch.float32, device=device) box1r = box8a @ RR1.transpose(0, 1) + TT1 box2r = box8b @ RR2.transpose(0, 1) + TT2 vol, iou = overlap_fn(box1r[None], box2r[None]) iou_sampling = self._box3d_overlap_sampling_batched( box1r[None], box2r[None], num_samples=10000 ) self.assertClose(iou, iou_sampling, atol=1e-2) # symmetry vol, iou = overlap_fn(box2r[None], box1r[None]) self.assertClose(iou, iou_sampling, atol=1e-2) # 9th test: non overlapping boxes, iou = 0.0 box2 = box1 + torch.tensor([[0.0, 100.0, 0.0]], device=device) vol, iou = overlap_fn(box1[None], box2[None]) self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype)) self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype)) # symmetry vol, iou = overlap_fn(box2[None], box1[None]) self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype)) self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype)) # 10th test: Non coplanar verts in a plane box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device) msg = "Plane vertices are not coplanar" with self.assertRaisesRegex(ValueError, msg): overlap_fn(box10[None], box10[None]) # 11th test: Skewed bounding boxes but all verts are coplanar box_skew_1 = torch.tensor( [ [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [-2, -2, 2], [2, -2, 2], [2, 2, 2], [-2, 2, 2], ], dtype=torch.float32, device=device, ) box_skew_2 = torch.tensor( [ [2.015995, 0.695233, 2.152806], [2.832533, 0.663448, 1.576389], [2.675445, -0.309592, 1.407520], [1.858907, -0.277806, 1.983936], [-0.413922, 3.161758, 2.044343], [2.852230, 3.034615, -0.261321], [2.223878, -0.857545, -0.936800], [-1.042273, -0.730402, 1.368864], ], dtype=torch.float32, device=device, ) vol1 = 14.000 vol2 = 14.000005 vol_inters = 5.431122 iou = vol_inters / (vol1 + vol2 - vol_inters) vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None]) self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1) self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1) # symmetry vols, ious = overlap_fn(box_skew_2[None], box_skew_1[None]) self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1) self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1) # 12th test: Zero area bounding box (from GH issue #992) box12a = torch.tensor( [ [-1.0000, -1.0000, -0.5000], [1.0000, -1.0000, -0.5000], [1.0000, 1.0000, -0.5000], [-1.0000, 1.0000, -0.5000], [-1.0000, -1.0000, 0.5000], [1.0000, -1.0000, 0.5000], [1.0000, 1.0000, 0.5000], [-1.0000, 1.0000, 0.5000], ], device=device, dtype=torch.float32, ) box12b = torch.tensor( [ [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], ], device=device, dtype=torch.float32, ) msg = "Planes have zero areas" with self.assertRaisesRegex(ValueError, msg): overlap_fn(box12a[None], box12b[None]) # symmetry with self.assertRaisesRegex(ValueError, msg): overlap_fn(box12b[None], box12a[None]) # 13th test: From GH issue #992 # Zero area coplanar face after intersection ctrs = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0]]) whl = torch.tensor([[2.0, 2.0, 2.0], [2.0, 2, 2]]) box13a = TestIoU3D.create_box(ctrs[0], whl[0]) box13b = TestIoU3D.create_box(ctrs[1], whl[1]) vol, iou = overlap_fn(box13a[None], box13b[None]) self.assertClose(vol, torch.tensor([[2.0]], device=vol.device, dtype=vol.dtype)) # 14th test: From GH issue #992 # Random rotation, same boxes, iou should be 1.0 corners = ( torch.tensor( [ [-1.0, -1.0, -1.0], [1.0, -1.0, -1.0], [1.0, 1.0, -1.0], [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0], [1.0, -1.0, 1.0], [1.0, 1.0, 1.0], [-1.0, 1.0, 1.0], ], device=device, dtype=torch.float32, ) * 0.5 ) yaw = torch.tensor(0.185) Rot = torch.tensor( [ [torch.cos(yaw), 0.0, torch.sin(yaw)], [0.0, 1.0, 0.0], [-torch.sin(yaw), 0.0, torch.cos(yaw)], ], dtype=torch.float32, device=device, ) corners = (Rot.mm(corners.t())).t() vol, iou = overlap_fn(corners[None], corners[None]) self.assertClose( iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2 ) # 15th test: From GH issue #1082 box15a = torch.tensor( [ [-2.5629019, 4.13995749, -1.76344576], [1.92329434, 4.28127117, -1.86155124], [1.86994571, 5.97489644, -1.86155124], [-2.61625053, 5.83358276, -1.76344576], [-2.53123587, 4.14095496, -0.31397536], [1.95496037, 4.28226864, -0.41208084], [1.90161174, 5.97589391, -0.41208084], [-2.5845845, 5.83458023, -0.31397536], ], device=device, dtype=torch.float32, ) box15b = torch.tensor( [ [-2.6256125, 4.13036357, -1.82893437], [1.87201008, 4.25296695, -1.82893437], [1.82562476, 5.95458116, -1.82893437], [-2.67199782, 5.83197777, -1.82893437], [-2.6256125, 4.13036357, -0.40095884], [1.87201008, 4.25296695, -0.40095884], [1.82562476, 5.95458116, -0.40095884], [-2.67199782, 5.83197777, -0.40095884], ], device=device, dtype=torch.float32, ) vol, iou = overlap_fn(box15a[None], box15b[None]) self.assertClose( iou, torch.tensor([[0.91]], device=vol.device, dtype=vol.dtype), atol=1e-2 ) # symmetry vol, iou = overlap_fn(box15b[None], box15a[None]) self.assertClose( iou, torch.tensor([[0.91]], device=vol.device, dtype=vol.dtype), atol=1e-2 ) # 16th test: From GH issue 1287 box16a = torch.tensor( [ [-167.5847, -70.6167, -2.7927], [-166.7333, -72.4264, -2.7927], [-166.7333, -72.4264, -4.5927], [-167.5847, -70.6167, -4.5927], [-163.0605, -68.4880, -2.7927], [-162.2090, -70.2977, -2.7927], [-162.2090, -70.2977, -4.5927], [-163.0605, -68.4880, -4.5927], ], device=device, dtype=torch.float32, ) box16b = torch.tensor( [ [-167.5847, -70.6167, -2.7927], [-166.7333, -72.4264, -2.7927], [-166.7333, -72.4264, -4.5927], [-167.5847, -70.6167, -4.5927], [-163.0605, -68.4880, -2.7927], [-162.2090, -70.2977, -2.7927], [-162.2090, -70.2977, -4.5927], [-163.0605, -68.4880, -4.5927], ], device=device, dtype=torch.float32, ) vol, iou = overlap_fn(box16a[None], box16b[None]) self.assertClose( iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2 ) # symmetry vol, iou = overlap_fn(box16b[None], box16a[None]) self.assertClose( iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2 ) # 17th test: From GH issue 1287 box17a = torch.tensor( [ [-33.94158, -4.51639, 0.96941], [-34.67156, -2.65437, 0.96941], [-34.67156, -2.65437, -0.95367], [-33.94158, -4.51639, -0.95367], [-38.75954, -6.40521, 0.96941], [-39.48952, -4.54319, 0.96941], [-39.48952, -4.54319, -0.95367], [-38.75954, -6.40521, -0.95367], ], device=device, dtype=torch.float32, ) box17b = torch.tensor( [ [-33.94159, -4.51638, 0.96939], [-34.67158, -2.65437, 0.96939], [-34.67158, -2.65437, -0.95368], [-33.94159, -4.51638, -0.95368], [-38.75954, -6.40523, 0.96939], [-39.48953, -4.54321, 0.96939], [-39.48953, -4.54321, -0.95368], [-38.75954, -6.40523, -0.95368], ], device=device, dtype=torch.float32, ) vol, iou = overlap_fn(box17a[None], box17b[None]) self.assertClose( iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2 ) # symmetry vol, iou = overlap_fn(box17b[None], box17a[None]) self.assertClose( iou, torch.tensor([[1.0]], device=vol.device, dtype=vol.dtype), atol=1e-2 ) # 18th test: From GH issue 1287 box18a = torch.tensor( [ [-105.6248, -32.7026, -1.2279], [-106.4690, -30.8895, -1.2279], [-106.4690, -30.8895, -3.0279], [-105.6248, -32.7026, -3.0279], [-110.1575, -34.8132, -1.2279], [-111.0017, -33.0001, -1.2279], [-111.0017, -33.0001, -3.0279], [-110.1575, -34.8132, -3.0279], ], device=device, dtype=torch.float32, ) box18b = torch.tensor( [ [-105.5094, -32.9504, -1.0641], [-106.4272, -30.9793, -1.0641], [-106.4272, -30.9793, -3.1916], [-105.5094, -32.9504, -3.1916], [-110.0421, -35.0609, -1.0641], [-110.9599, -33.0899, -1.0641], [-110.9599, -33.0899, -3.1916], [-110.0421, -35.0609, -3.1916], ], device=device, dtype=torch.float32, ) # from Meshlab vol_inters = 17.108501 vol_box1 = 18.000067 vol_box2 = 23.128527 iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters) vol, iou = overlap_fn(box18a[None], box18b[None]) self.assertClose( iou, torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype), atol=1e-2, ) self.assertClose( vol, torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype), atol=1e-2, ) # symmetry vol, iou = overlap_fn(box18b[None], box18a[None]) self.assertClose( iou, torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype), atol=1e-2, ) self.assertClose( vol, torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype), atol=1e-2, ) # 19th example: From GH issue 1287 box19a = torch.tensor( [ [-59.4785, -15.6003, 0.4398], [-60.2263, -13.6928, 0.4398], [-60.2263, -13.6928, -1.3909], [-59.4785, -15.6003, -1.3909], [-64.1743, -17.4412, 0.4398], [-64.9221, -15.5337, 0.4398], [-64.9221, -15.5337, -1.3909], [-64.1743, -17.4412, -1.3909], ], device=device, dtype=torch.float32, ) box19b = torch.tensor( [ [-59.4874, -15.5775, -0.1512], [-60.2174, -13.7155, -0.1512], [-60.2174, -13.7155, -1.9820], [-59.4874, -15.5775, -1.9820], [-64.1832, -17.4185, -0.1512], [-64.9132, -15.5564, -0.1512], [-64.9132, -15.5564, -1.9820], [-64.1832, -17.4185, -1.9820], ], device=device, dtype=torch.float32, ) # from Meshlab vol_inters = 12.505723 vol_box1 = 18.918238 vol_box2 = 18.468531 iou_mesh = vol_inters / (vol_box1 + vol_box2 - vol_inters) vol, iou = overlap_fn(box19a[None], box19b[None]) self.assertClose( iou, torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype), atol=1e-2, ) self.assertClose( vol, torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype), atol=1e-2, ) # symmetry vol, iou = overlap_fn(box19b[None], box19a[None]) self.assertClose( iou, torch.tensor([[iou_mesh]], device=vol.device, dtype=vol.dtype), atol=1e-2, ) self.assertClose( vol, torch.tensor([[vol_inters]], device=vol.device, dtype=vol.dtype), atol=1e-2, ) def _test_real_boxes(self, overlap_fn, device): data_filename = "./real_boxes.pkl" with open(DATA_DIR / data_filename, "rb") as f: example = pickle.load(f) verts1 = torch.FloatTensor(example["verts1"]) verts2 = torch.FloatTensor(example["verts2"]) boxes = torch.stack((verts1, verts2)).to(device) iou_expected = torch.eye(2).to(device) vol, iou = overlap_fn(boxes, boxes) self.assertClose(iou, iou_expected) def test_iou_naive(self): device = get_random_cuda_device() self._test_iou(self._box3d_overlap_naive_batched, device) self._test_compare_objectron(self._box3d_overlap_naive_batched, device) self._test_real_boxes(self._box3d_overlap_naive_batched, device) def test_iou_cpu(self): device = torch.device("cpu") self._test_iou(box3d_overlap, device) self._test_compare_objectron(box3d_overlap, device) self._test_real_boxes(box3d_overlap, device) def test_iou_cuda(self): device = torch.device("cuda:0") self._test_iou(box3d_overlap, device) self._test_compare_objectron(box3d_overlap, device) self._test_real_boxes(box3d_overlap, device) def _test_compare_objectron(self, overlap_fn, device): # Load saved objectron data data_filename = "./objectron_vols_ious.pt" objectron_vals = torch.load(DATA_DIR / data_filename) boxes1 = objectron_vals["boxes1"] boxes2 = objectron_vals["boxes2"] vols_objectron = objectron_vals["vols"] ious_objectron = objectron_vals["ious"] boxes1 = boxes1.to(device=device, dtype=torch.float32) boxes2 = boxes2.to(device=device, dtype=torch.float32) # Convert vertex orderings from Objectron to PyTorch3D convention idx = torch.tensor( OBJECTRON_TO_PYTORCH3D_FACE_IDX, dtype=torch.int64, device=device ) boxes1 = boxes1.index_select(index=idx, dim=1) boxes2 = boxes2.index_select(index=idx, dim=1) # Run PyTorch3D version vols, ious = overlap_fn(boxes1, boxes2) # Check values match self.assertClose(vols_objectron, vols.cpu()) self.assertClose(ious_objectron, ious.cpu()) def test_batched_errors(self): N, M = 5, 10 boxes1 = torch.randn((N, 8, 3)) boxes2 = torch.randn((M, 10, 3)) with self.assertRaisesRegex(ValueError, "(8, 3)"): box3d_overlap(boxes1, boxes2) def test_box_volume(self): device = torch.device("cuda:0") box1 = torch.tensor( [ [3.1673, -2.2574, 0.4817], [4.6470, 0.2223, 2.4197], [5.2200, 1.1844, 0.7510], [3.7403, -1.2953, -1.1869], [-4.9316, 2.5724, 0.4856], [-3.4519, 5.0521, 2.4235], [-2.8789, 6.0142, 0.7549], [-4.3586, 3.5345, -1.1831], ], dtype=torch.float32, device=device, ) box2 = torch.tensor( [ [0.5623, 4.0647, 3.4334], [3.3584, 4.3191, 1.1791], [3.0724, -5.9235, -0.3315], [0.2763, -6.1779, 1.9229], [-2.0773, 4.6121, 0.2213], [0.7188, 4.8665, -2.0331], [0.4328, -5.3761, -3.5436], [-2.3633, -5.6305, -1.2893], ], dtype=torch.float32, device=device, ) box3 = torch.tensor( [ [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1], ], dtype=torch.float32, device=device, ) RR = random_rotation(dtype=torch.float32, device=device) TT = torch.rand((1, 3), dtype=torch.float32, device=device) box4 = box3 @ RR.transpose(0, 1) + TT self.assertClose(box_volume(box1).cpu(), torch.tensor(65.899010), atol=1e-3) self.assertClose(box_volume(box2).cpu(), torch.tensor(156.386719), atol=1e-3) self.assertClose(box_volume(box3).cpu(), torch.tensor(1.0), atol=1e-3) self.assertClose(box_volume(box4).cpu(), torch.tensor(1.0), atol=1e-3) def test_box_planar_dir(self): device = torch.device("cuda:0") box1 = torch.tensor( UNIT_BOX, dtype=torch.float32, device=device, ) n1 = torch.tensor( [ [0.0, 0.0, 1.0], [0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, -1.0], ], device=device, dtype=torch.float32, ) RR = random_rotation(dtype=torch.float32, device=device) TT = torch.rand((1, 3), dtype=torch.float32, device=device) box2 = box1 @ RR.transpose(0, 1) + TT n2 = n1 @ RR.transpose(0, 1) self.assertClose(box_planar_dir(box1), n1) self.assertClose(box_planar_dir(box2), n2) @staticmethod def iou_naive(N: int, M: int, device="cpu"): box = torch.tensor( [UNIT_BOX], dtype=torch.float32, device=device, ) boxes1 = box + torch.randn((N, 1, 3), device=device) boxes2 = box + torch.randn((M, 1, 3), device=device) def output(): vol, iou = TestIoU3D._box3d_overlap_naive_batched(boxes1, boxes2) return output @staticmethod def iou(N: int, M: int, device="cpu"): box = torch.tensor( [UNIT_BOX], dtype=torch.float32, device=device, ) boxes1 = box + torch.randn((N, 1, 3), device=device) boxes2 = box + torch.randn((M, 1, 3), device=device) def output(): vol, iou = box3d_overlap(boxes1, boxes2) return output @staticmethod def iou_sampling(N: int, M: int, num_samples: int, device="cpu"): box = torch.tensor( [UNIT_BOX], dtype=torch.float32, device=device, ) boxes1 = box + torch.randn((N, 1, 3), device=device) boxes2 = box + torch.randn((M, 1, 3), device=device) def output(): _ = TestIoU3D._box3d_overlap_sampling_batched(boxes1, boxes2, num_samples) return output # -------------------------------------------------- # # NAIVE IMPLEMENTATION # # -------------------------------------------------- # """ The main functions below are: * box3d_overlap_naive: which computes the exact IoU of box1 and box2 * box3d_overlap_sampling: which computes an approximate IoU of box1 and box2 by sampling points within the boxes Note that both implementations currently do not support batching. """ # -------------------------------------------------- # # Throughout this implementation, we assume that boxes # are defined by their 8 corners in the following order # # (4) +---------+. (5) # | ` . | ` . # | (0) +---+-----+ (1) # | | | | # (7) +-----+---+. (6)| # ` . | ` . | # (3) ` +---------+ (2) # # -------------------------------------------------- # # -------------------------------------------------- # # HELPER FUNCTIONS FOR EXACT SOLUTION # # -------------------------------------------------- # def get_tri_verts(box: torch.Tensor) -> torch.Tensor: """ Return the vertex coordinates forming the triangles of the box. The computation here resembles the Meshes data structure. But since we only want this tiny functionality, we abstract it out. Args: box: tensor of shape (8, 3) Returns: tri_verts: tensor of shape (12, 3, 3) """ device = box.device faces = torch.tensor(_box_triangles, device=device, dtype=torch.int64) # (12, 3) tri_verts = box[faces] # (12, 3, 3) return tri_verts def get_plane_verts(box: torch.Tensor) -> torch.Tensor: """ Return the vertex coordinates forming the planes of the box. The computation here resembles the Meshes data structure. But since we only want this tiny functionality, we abstract it out. Args: box: tensor of shape (8, 3) Returns: plane_verts: tensor of shape (6, 4, 3) """ device = box.device faces = torch.tensor(_box_planes, device=device, dtype=torch.int64) # (6, 4) plane_verts = box[faces] # (6, 4, 3) return plane_verts def get_tri_center_normal(tris: torch.Tensor) -> torch.Tensor: """ Returns the center and normal of triangles Args: tris: tensor of shape (T, 3, 3) Returns: center: tensor of shape (T, 3) normal: tensor of shape (T, 3) """ add_dim0 = False if tris.ndim == 2: tris = tris.unsqueeze(0) add_dim0 = True ctr = tris.mean(1) # (T, 3) normals = torch.zeros_like(ctr) v0, v1, v2 = tris.unbind(1) # 3 x (T, 3) # unvectorized solution T = tris.shape[0] for t in range(T): ns = torch.zeros((3, 3), device=tris.device) ns[0] = torch.cross(v0[t] - ctr[t], v1[t] - ctr[t], dim=-1) ns[1] = torch.cross(v0[t] - ctr[t], v2[t] - ctr[t], dim=-1) ns[2] = torch.cross(v1[t] - ctr[t], v2[t] - ctr[t], dim=-1) i = torch.norm(ns, dim=-1).argmax() normals[t] = ns[i] if add_dim0: ctr = ctr[0] normals = normals[0] normals = F.normalize(normals, dim=-1) return ctr, normals def get_plane_center_normal(planes: torch.Tensor) -> torch.Tensor: """ Returns the center and normal of planes Args: planes: tensor of shape (P, 4, 3) Returns: center: tensor of shape (P, 3) normal: tensor of shape (P, 3) """ add_dim0 = False if planes.ndim == 2: planes = planes.unsqueeze(0) add_dim0 = True ctr = planes.mean(1) # (P, 3) normals = torch.zeros_like(ctr) v0, v1, v2, v3 = planes.unbind(1) # 4 x (P, 3) # unvectorized solution P = planes.shape[0] for t in range(P): ns = torch.zeros((6, 3), device=planes.device) ns[0] = torch.cross(v0[t] - ctr[t], v1[t] - ctr[t], dim=-1) ns[1] = torch.cross(v0[t] - ctr[t], v2[t] - ctr[t], dim=-1) ns[2] = torch.cross(v0[t] - ctr[t], v3[t] - ctr[t], dim=-1) ns[3] = torch.cross(v1[t] - ctr[t], v2[t] - ctr[t], dim=-1) ns[4] = torch.cross(v1[t] - ctr[t], v3[t] - ctr[t], dim=-1) ns[5] = torch.cross(v2[t] - ctr[t], v3[t] - ctr[t], dim=-1) i = torch.norm(ns, dim=-1).argmax() normals[t] = ns[i] if add_dim0: ctr = ctr[0] normals = normals[0] normals = F.normalize(normals, dim=-1) return ctr, normals def box_planar_dir( box: torch.Tensor, dot_eps: float = DOT_EPS, area_eps: float = AREA_EPS ) -> torch.Tensor: """ Finds the unit vector n which is perpendicular to each plane in the box and points towards the inside of the box. The planes are defined by `_box_planes`. Since the shape is convex, we define the interior to be the direction pointing to the center of the shape. Args: box: tensor of shape (8, 3) of the vertices of the 3D box Returns: n: tensor of shape (6,) of the unit vector orthogonal to the face pointing towards the interior of the shape """ assert box.shape[0] == 8 and box.shape[1] == 3 # center point of each box box_ctr = box.mean(0).view(1, 3) # box planes plane_verts = get_plane_verts(box) # (6, 4, 3) v0, v1, v2, v3 = plane_verts.unbind(1) plane_ctr, n = get_plane_center_normal(plane_verts) # Check all verts are coplanar if ( not ( F.normalize(v3 - v0, dim=-1).unsqueeze(1).bmm(n.unsqueeze(2)).abs() < dot_eps ) .all() .item() ): msg = "Plane vertices are not coplanar" raise ValueError(msg) # Check all faces have non zero area area1 = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2 area2 = torch.cross(v3 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2 if (area1 < area_eps).any().item() or (area2 < area_eps).any().item(): msg = "Planes have zero areas" raise ValueError(msg) # We can write: `box_ctr = plane_ctr + a * e0 + b * e1 + c * n`, (1). # With = 0 and = 0, where <.,.> refers to the dot product, # since that e0 is orthogonal to n. Same for e1. """ # Below is how one would solve for (a, b, c) # Solving for (a, b) numF = verts.shape[0] A = torch.ones((numF, 2, 2), dtype=torch.float32, device=device) B = torch.ones((numF, 2), dtype=torch.float32, device=device) A[:, 0, 1] = (e0 * e1).sum(-1) A[:, 1, 0] = (e0 * e1).sum(-1) B[:, 0] = ((box_ctr - plane_ctr) * e0).sum(-1) B[:, 1] = ((box_ctr - plane_ctr) * e1).sum(-1) ab = torch.linalg.solve(A, B) # (numF, 2) a, b = ab.unbind(1) # solving for c c = ((box_ctr - plane_ctr - a.view(numF, 1) * e0 - b.view(numF, 1) * e1) * n).sum(-1) """ # Since we know that = 0 and = 0 (e0 and e1 are orthogonal to n), # the above solution is equivalent to direc = F.normalize(box_ctr - plane_ctr, dim=-1) # (6, 3) c = (direc * n).sum(-1) # If c is negative, then we revert the direction of n such that n points "inside" negc = c < 0.0 n[negc] *= -1.0 # c[negc] *= -1.0 # Now (a, b, c) is the solution to (1) return n def tri_verts_area(tri_verts: torch.Tensor) -> torch.Tensor: """ Computes the area of the triangle faces in tri_verts Args: tri_verts: tensor of shape (T, 3, 3) Returns: areas: the area of the triangles (T, 1) """ add_dim = False if tri_verts.ndim == 2: tri_verts = tri_verts.unsqueeze(0) add_dim = True v0, v1, v2 = tri_verts.unbind(1) areas = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2.0 if add_dim: areas = areas[0] return areas def box_volume(box: torch.Tensor) -> torch.Tensor: """ Computes the volume of each box in boxes. The volume of each box is the sum of all the tetrahedrons formed by the faces of the box. The face of the box is the base of that tetrahedron and the center point of the box is the apex. In other words, vol(box) = sum_i A_i * d_i / 3, where A_i is the area of the i-th face and d_i is the distance of the apex from the face. We use the equivalent dot/cross product formulation. Read https://en.wikipedia.org/wiki/Tetrahedron#Volume Args: box: tensor of shape (8, 3) containing the vertices of the 3D box Returns: vols: the volume of the box """ assert box.shape[0] == 8 and box.shape[1] == 3 # Compute the center point of each box ctr = box.mean(0).view(1, 1, 3) # Extract the coordinates of the faces for each box tri_verts = get_tri_verts(box) # Set the origin of the coordinate system to coincide # with the apex of the tetrahedron to simplify the volume calculation # See https://en.wikipedia.org/wiki/Tetrahedron#Volume tri_verts = tri_verts - ctr # Compute the volume of each box using the dot/cross product formula vols = torch.sum( tri_verts[:, 0] * torch.cross(tri_verts[:, 1], tri_verts[:, 2], dim=-1), dim=-1, ) vols = (vols.abs() / 6.0).sum() return vols def coplanar_tri_faces(tri1: torch.Tensor, tri2: torch.Tensor, eps: float = DOT_EPS): """ Determines whether two triangle faces in 3D are coplanar Args: tri1: tensor of shape (3, 3) of the vertices of the 1st triangle tri2: tensor of shape (3, 3) of the vertices of the 2nd triangle Returns: is_coplanar: bool """ tri1_ctr, tri1_n = get_tri_center_normal(tri1) tri2_ctr, tri2_n = get_tri_center_normal(tri2) check1 = tri1_n.dot(tri2_n).abs() > 1 - eps # checks if parallel dist12 = torch.norm(tri1.unsqueeze(1) - tri2.unsqueeze(0), dim=-1) dist12_argmax = dist12.argmax() i1 = dist12_argmax // 3 i2 = dist12_argmax % 3 assert dist12[i1, i2] == dist12.max() check2 = ( F.normalize(tri1[i1] - tri2[i2], dim=0).dot(tri1_n).abs() < eps ) or F.normalize(tri1[i1] - tri2[i2], dim=0).dot(tri2_n).abs() < eps return check1 and check2 def coplanar_tri_plane( tri: torch.Tensor, plane: torch.Tensor, n: torch.Tensor, eps: float = DOT_EPS ): """ Determines whether two triangle faces in 3D are coplanar Args: tri: tensor of shape (3, 3) of the vertices of the triangle plane: tensor of shape (4, 3) of the vertices of the plane n: tensor of shape (3,) of the unit "inside" direction on the plane Returns: is_coplanar: bool """ tri_ctr, tri_n = get_tri_center_normal(tri) check1 = tri_n.dot(n).abs() > 1 - eps # checks if parallel dist12 = torch.norm(tri.unsqueeze(1) - plane.unsqueeze(0), dim=-1) dist12_argmax = dist12.argmax() i1 = dist12_argmax // 4 i2 = dist12_argmax % 4 assert dist12[i1, i2] == dist12.max() check2 = F.normalize(tri[i1] - plane[i2], dim=0).dot(n).abs() < eps return check1 and check2 def is_inside( plane: torch.Tensor, n: torch.Tensor, points: torch.Tensor, return_proj: bool = True, ): """ Computes whether point is "inside" the plane. The definition of "inside" means that the point has a positive component in the direction of the plane normal defined by n. For example, plane | | . (A) |--> n | .(B) | Point (A) is "inside" the plane, while point (B) is "outside" the plane. Args: plane: tensor of shape (4,3) of vertices of a box plane n: tensor of shape (3,) of the unit "inside" direction on the plane points: tensor of shape (P, 3) of coordinates of a point return_proj: bool whether to return the projected point on the plane Returns: is_inside: bool of shape (P,) of whether point is inside p_proj: tensor of shape (P, 3) of the projected point on plane """ device = plane.device v0, v1, v2, v3 = plane.unbind(0) plane_ctr = plane.mean(0) e0 = F.normalize(v0 - plane_ctr, dim=0) e1 = F.normalize(v1 - plane_ctr, dim=0) if not torch.allclose(e0.dot(n), torch.zeros((1,), device=device), atol=1e-2): raise ValueError("Input n is not perpendicular to the plane") if not torch.allclose(e1.dot(n), torch.zeros((1,), device=device), atol=1e-2): raise ValueError("Input n is not perpendicular to the plane") add_dim = False if points.ndim == 1: points = points.unsqueeze(0) add_dim = True assert points.shape[1] == 3 # Every point p can be written as p = ctr + a e0 + b e1 + c n # If return_proj is True, we need to solve for (a, b) p_proj = None if return_proj: # solving for (a, b) A = torch.tensor( [[1.0, e0.dot(e1)], [e0.dot(e1), 1.0]], dtype=torch.float32, device=device ) B = torch.zeros((2, points.shape[0]), dtype=torch.float32, device=device) B[0, :] = torch.sum((points - plane_ctr.view(1, 3)) * e0.view(1, 3), dim=-1) B[1, :] = torch.sum((points - plane_ctr.view(1, 3)) * e1.view(1, 3), dim=-1) ab = A.inverse() @ B # (2, P) p_proj = plane_ctr.view(1, 3) + ab.transpose(0, 1) @ torch.stack( (e0, e1), dim=0 ) # solving for c # c = (point - ctr - a * e0 - b * e1).dot(n) direc = torch.sum((points - plane_ctr.view(1, 3)) * n.view(1, 3), dim=-1) ins = direc >= 0.0 if add_dim: assert p_proj.shape[0] == 1 p_proj = p_proj[0] return ins, p_proj def plane_edge_point_of_intersection(plane, n, p0, p1, eps: float = DOT_EPS): """ Finds the point of intersection between a box plane and a line segment connecting (p0, p1). The plane is assumed to be infinite long. Args: plane: tensor of shape (4, 3) of the coordinates of the vertices defining the plane n: tensor of shape (3,) of the unit direction perpendicular on the plane (Note that we could compute n but since it's computed in the main body of the function, we save time by feeding it in. For the purpose of this function, it's not important that n points "inside" the shape.) p0, p1: tensors of shape (3,), (3,) Returns: p: tensor of shape (3,) of the coordinates of the point of intersection a: scalar such that p = p0 + a*(p1-p0) """ # The point of intersection can be parametrized # p = p0 + a (p1 - p0) where a in [0, 1] # We want to find a such that p is on plane #

= 0 # if segment (p0, p1) is parallel to plane (it can only be on it) direc = F.normalize(p1 - p0, dim=0) if direc.dot(n).abs() < eps: return (p1 + p0) / 2.0, 0.5 else: ctr = plane.mean(0) a = -(p0 - ctr).dot(n) / ((p1 - p0).dot(n)) p = p0 + a * (p1 - p0) return p, a """ The three following functions support clipping a triangle face by a plane. They contain the following cases: (a) the triangle has one point "outside" the plane and (b) the triangle has two points "outside" the plane. This logic follows the logic of clipping triangles when they intersect the image plane while rendering. """ def clip_tri_by_plane_oneout( plane: torch.Tensor, n: torch.Tensor, vout: torch.Tensor, vin1: torch.Tensor, vin2: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Case (a). Clips triangle by plane when vout is outside plane, and vin1, vin2, is inside In this case, only one vertex of the triangle is outside the plane. Clip the triangle into a quadrilateral, and then split into two triangles Args: plane: tensor of shape (4, 3) of the coordinates of the vertices forming the plane n: tensor of shape (3,) of the unit "inside" direction of the plane vout, vin1, vin2: tensors of shape (3,) of the points forming the triangle, where vout is "outside" the plane and vin1, vin2 are "inside" Returns: verts: tensor of shape (4, 3) containing the new vertices formed after clipping the original intersecting triangle (vout, vin1, vin2) faces: tensor of shape (2, 3) defining the vertex indices forming the two new triangles which are "inside" the plane formed after clipping """ device = plane.device # point of intersection between plane and (vin1, vout) pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout) assert a1 >= -0.0001 and a1 <= 1.0001, a1 # point of intersection between plane and (vin2, vout) pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout) assert a2 >= -0.0001 and a2 <= 1.0001, a2 verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3 faces = torch.tensor( [[0, 1, 2], [0, 2, 3]], dtype=torch.int64, device=device ) # 2x3 return verts, faces def clip_tri_by_plane_twoout( plane: torch.Tensor, n: torch.Tensor, vout1: torch.Tensor, vout2: torch.Tensor, vin: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Case (b). Clips face by plane when vout1, vout2 are outside plane, and vin1 is inside In this case, only one vertex of the triangle is inside the plane. Args: plane: tensor of shape (4, 3) of the coordinates of the vertices forming the plane n: tensor of shape (3,) of the unit "inside" direction of the plane vout1, vout2, vin: tensors of shape (3,) of the points forming the triangle, where vin is "inside" the plane and vout1, vout2 are "outside" Returns: verts: tensor of shape (3, 3) containing the new vertices formed after clipping the original intersectiong triangle (vout, vin1, vin2) faces: tensor of shape (1, 3) defining the vertex indices forming the single new triangle which is "inside" the plane formed after clipping """ device = plane.device # point of intersection between plane and (vin, vout1) pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1) assert a1 >= -0.0001 and a1 <= 1.0001, a1 # point of intersection between plane and (vin, vout2) pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2) assert a2 >= -0.0001 and a2 <= 1.0001, a2 verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3 faces = torch.tensor( [ [0, 1, 2], ], dtype=torch.int64, device=device, ) # 1x3 return verts, faces def clip_tri_by_plane(plane, n, tri_verts) -> Union[List, torch.Tensor]: """ Clip a trianglular face defined by tri_verts with a plane of inside "direction" n. This function computes whether the triangle has one or two or none points "outside" the plane. Args: plane: tensor of shape (4, 3) of the vertex coordinates of the plane n: tensor of shape (3,) of the unit "inside" direction of the plane tri_verts: tensor of shape (3, 3) of the vertex coordiantes of the the triangle faces Returns: tri_verts: tensor of shape (K, 3, 3) of the vertex coordinates of the triangles formed after clipping. All K triangles are now "inside" the plane. """ if coplanar_tri_plane(tri_verts, plane, n): return tri_verts.view(1, 3, 3) v0, v1, v2 = tri_verts.unbind(0) isin0, _ = is_inside(plane, n, v0) isin1, _ = is_inside(plane, n, v1) isin2, _ = is_inside(plane, n, v2) if isin0 and isin1 and isin2: # all in, no clipping, keep the old triangle face return tri_verts.view(1, 3, 3) elif (not isin0) and (not isin1) and (not isin2): # all out, delete triangle return [] else: if isin0: if isin1: # (isin0, isin1, not isin2) verts, faces = clip_tri_by_plane_oneout(plane, n, v2, v0, v1) return verts[faces] elif isin2: # (isin0, not isin1, isin2) verts, faces = clip_tri_by_plane_oneout(plane, n, v1, v0, v2) return verts[faces] else: # (isin0, not isin1, not isin2) verts, faces = clip_tri_by_plane_twoout(plane, n, v1, v2, v0) return verts[faces] else: if isin1 and isin2: # (not isin0, isin1, isin2) verts, faces = clip_tri_by_plane_oneout(plane, n, v0, v1, v2) return verts[faces] elif isin1: # (not isin0, isin1, not isin2) verts, faces = clip_tri_by_plane_twoout(plane, n, v0, v2, v1) return verts[faces] elif isin2: # (not isin0, not isin1, isin2) verts, faces = clip_tri_by_plane_twoout(plane, n, v0, v1, v2) return verts[faces] # Should not be reached return [] # -------------------------------------------------- # # MAIN: BOX3D_OVERLAP # # -------------------------------------------------- # def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor): """ Computes the intersection of 3D boxes1 and boxes2. Inputs boxes1, boxes2 are tensors of shape (8, 3) containing the 8 corners of the boxes, as follows (4) +---------+. (5) | ` . | ` . | (0) +---+-----+ (1) | | | | (7) +-----+---+. (6)| ` . | ` . | (3) ` +---------+ (2) Args: box1: tensor of shape (8, 3) of the coordinates of the 1st box box2: tensor of shape (8, 3) of the coordinates of the 2nd box Returns: vol: the volume of the intersecting convex shape iou: the intersection over union which is simply `iou = vol / (vol1 + vol2 - vol)` """ device = box1.device # For boxes1 we compute the unit directions n1 corresponding to quad_faces n1 = box_planar_dir(box1) # (6, 3) # For boxes2 we compute the unit directions n2 corresponding to quad_faces n2 = box_planar_dir(box2) # We define triangle faces vol1 = box_volume(box1) vol2 = box_volume(box2) tri_verts1 = get_tri_verts(box1) # (12, 3, 3) plane_verts1 = get_plane_verts(box1) # (6, 4, 3) tri_verts2 = get_tri_verts(box2) # (12, 3, 3) plane_verts2 = get_plane_verts(box2) # (6, 4, 3) num_planes = plane_verts1.shape[0] # (=6) based on our definition of planes # Every triangle in box1 will be compared to each plane in box2. # If the triangle is fully outside or fully inside, then it will remain as is # If the triangle intersects with the (infinite) plane, it will be broken into # subtriangles such that each subtriangle is either fully inside or outside the plane. # Tris in Box1 -> Planes in Box2 for pidx in range(num_planes): plane = plane_verts2[pidx] nplane = n2[pidx] tri_verts_updated = torch.zeros((0, 3, 3), dtype=torch.float32, device=device) for i in range(tri_verts1.shape[0]): tri = clip_tri_by_plane(plane, nplane, tri_verts1[i]) if len(tri) > 0: tri_verts_updated = torch.cat((tri_verts_updated, tri), dim=0) tri_verts1 = tri_verts_updated # Tris in Box2 -> Planes in Box1 for pidx in range(num_planes): plane = plane_verts1[pidx] nplane = n1[pidx] tri_verts_updated = torch.zeros((0, 3, 3), dtype=torch.float32, device=device) for i in range(tri_verts2.shape[0]): tri = clip_tri_by_plane(plane, nplane, tri_verts2[i]) if len(tri) > 0: tri_verts_updated = torch.cat((tri_verts_updated, tri), dim=0) tri_verts2 = tri_verts_updated # remove triangles that are coplanar from the intersection as # otherwise they would be doublecounting towards the volume # this happens only if the original 3D boxes have common planes # Since the resulting shape is convex and specifically composed of planar segments, # each planar segment can belong either on box1 or box2 but not both. # Without loss of generality, we assign shared planar segments to box1 keep2 = torch.ones((tri_verts2.shape[0],), device=device, dtype=torch.bool) for i1 in range(tri_verts1.shape[0]): for i2 in range(tri_verts2.shape[0]): if ( coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2]) and tri_verts_area(tri_verts1[i1]) > AREA_EPS ): keep2[i2] = 0 keep2 = keep2.nonzero()[:, 0] tri_verts2 = tri_verts2[keep2] # intersecting shape num_faces = tri_verts1.shape[0] + tri_verts2.shape[0] num_verts = num_faces * 3 # V=F*3 overlap_faces = torch.arange(num_verts).view(num_faces, 3) # Fx3 overlap_tri_verts = torch.cat((tri_verts1, tri_verts2), dim=0) # Fx3x3 overlap_verts = overlap_tri_verts.view(num_verts, 3) # Vx3 # the volume of the convex hull defined by (overlap_verts, overlap_faces) # can be defined as the sum of all the tetrahedrons formed where for each tetrahedron # the base is the triangle and the apex is the center point of the convex hull # See the math here: https://en.wikipedia.org/wiki/Tetrahedron#Volume # we compute the center by computing the center point of each face # and then averaging the face centers ctr = overlap_tri_verts.mean(1).mean(0) tetras = overlap_tri_verts - ctr.view(1, 1, 3) vol = torch.sum( tetras[:, 0] * torch.cross(tetras[:, 1], tetras[:, 2], dim=-1), dim=-1 ) vol = (vol.abs() / 6.0).sum() iou = vol / (vol1 + vol2 - vol) if DEBUG: # save shapes tri_faces = torch.tensor(_box_triangles, device=device, dtype=torch.int64) save_obj("/tmp/output/shape1.obj", box1, tri_faces) save_obj("/tmp/output/shape2.obj", box2, tri_faces) if len(overlap_verts) > 0: save_obj("/tmp/output/inters_shape.obj", overlap_verts, overlap_faces) return vol, iou # -------------------------------------------------- # # HELPER FUNCTIONS FOR SAMPLING SOLUTION # # -------------------------------------------------- # def is_point_inside_box(box: torch.Tensor, points: torch.Tensor): """ Determines whether points are inside the boxes Args: box: tensor of shape (8, 3) of the corners of the boxes points: tensor of shape (P, 3) of the points Returns: inside: bool tensor of shape (P,) """ device = box.device P = points.shape[0] n = box_planar_dir(box) # (6, 3) box_planes = get_plane_verts(box) # (6, 4) num_planes = box_planes.shape[0] # = 6 # a point p is inside the box if it "inside" all planes of the box # so we run the checks ins = torch.zeros((P, num_planes), device=device, dtype=torch.bool) for i in range(num_planes): is_in, _ = is_inside(box_planes[i], n[i], points, return_proj=False) ins[:, i] = is_in ins = ins.all(dim=1) return ins def sample_points_within_box(box: torch.Tensor, num_samples: int = 10): """ Sample points within a box defined by its 8 coordinates Args: box: tensor of shape (8, 3) of the box coordinates num_samples: int defining the number of samples Returns: points: (num_samples, 3) of points inside the box """ assert box.shape[0] == 8 and box.shape[1] == 3 xyzmin = box.min(0).values.view(1, 3) xyzmax = box.max(0).values.view(1, 3) uvw = torch.rand((num_samples, 3), device=box.device) points = uvw * (xyzmax - xyzmin) + xyzmin # because the box is not axis aligned we need to check wether # the points are within the box num_points = 0 samples = [] while num_points < num_samples: inside = is_point_inside_box(box, points) samples.append(points[inside].view(-1, 3)) num_points += inside.sum() samples = torch.cat(samples, dim=0) return samples[1:num_samples] # -------------------------------------------------- # # MAIN: BOX3D_OVERLAP_SAMPLING # # -------------------------------------------------- # def box3d_overlap_sampling( box1: torch.Tensor, box2: torch.Tensor, num_samples: int = 10000 ): """ Computes the intersection of two boxes by sampling points """ vol1 = box_volume(box1) vol2 = box_volume(box2) points1 = sample_points_within_box(box1, num_samples=num_samples) points2 = sample_points_within_box(box2, num_samples=num_samples) isin21 = is_point_inside_box(box1, points2) num21 = isin21.sum() isin12 = is_point_inside_box(box2, points1) num12 = isin12.sum() assert num12 <= num_samples assert num21 <= num_samples inters = (vol1 * num12 + vol2 * num21) / 2.0 union = vol1 * num_samples + vol2 * num_samples - inters return inters / union