# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest | |
import torch | |
from pytorch3d.loss import mesh_edge_loss | |
from pytorch3d.structures import Meshes | |
from .common_testing import TestCaseMixin | |
from .test_sample_points_from_meshes import init_meshes | |
class TestMeshEdgeLoss(TestCaseMixin, unittest.TestCase):
    """Tests and a benchmark factory for ``pytorch3d.loss.mesh_edge_loss``."""

    def test_empty_meshes(self):
        """Meshes that have vertices but no faces give a zero loss.

        The returned loss must still require grad.
        """
        device = torch.device("cuda:0")
        target_length = 0
        N = 10
        V = 32
        verts_list = []
        faces_list = []
        for _ in range(N):
            # Random vertex count in [3, V); the faces tensor is left empty.
            vn = torch.randint(3, high=V, size=(1,))[0].item()
            verts = torch.rand((vn, 3), dtype=torch.float32, device=device)
            faces = torch.tensor([], dtype=torch.int64, device=device)
            verts_list.append(verts)
            faces_list.append(faces)
        mesh = Meshes(verts=verts_list, faces=faces_list)
        loss = mesh_edge_loss(mesh, target_length=target_length)

        expected = torch.tensor([0.0], dtype=torch.float32, device=device)
        self.assertClose(loss, expected)
        self.assertTrue(loss.requires_grad)

    @staticmethod
    def mesh_edge_loss_naive(meshes, target_length: float = 0.0):
        """
        Naive iterative implementation of mesh loss calculation.

        For every mesh flagged in ``meshes.valid`` this accumulates, edge by
        edge, the squared difference between the edge length and
        ``target_length``, then divides by that mesh's edge count. Meshes that
        are not valid, or that have no edges, keep a loss of zero.

        Args:
            meshes: object exposing ``edges_packed()``, ``verts_packed()``,
                ``edges_packed_to_mesh_idx()``, ``device``, ``valid`` and
                ``len()`` (a ``Meshes`` batch in the tests of this class).
            target_length: edge length that incurs zero penalty.

        Returns:
            Scalar tensor: the mean of the per-mesh losses over all
            ``len(meshes)`` meshes.
        """
        edges_packed = meshes.edges_packed()
        verts_packed = meshes.verts_packed()
        edge_to_mesh = meshes.edges_packed_to_mesh_idx()
        N = len(meshes)
        device = meshes.device
        valid = meshes.valid
        predlosses = torch.zeros((N,), dtype=torch.float32, device=device)

        for b in range(N):
            if valid[b] == 0:
                continue
            # Select the edges owned by mesh b and gather both endpoints.
            mesh_edges = edges_packed[edge_to_mesh == b]
            verts_edges = verts_packed[mesh_edges]
            num_edges = mesh_edges.size(0)
            for e in range(num_edges):
                v0, v1 = verts_edges[e, 0], verts_edges[e, 1]
                predlosses[b] += ((v0 - v1).norm(dim=0, p=2) - target_length) ** 2.0

            # Guard against dividing by zero for a mesh without edges.
            if num_edges > 0:
                predlosses[b] = predlosses[b] / num_edges

        return predlosses.mean()

    def test_mesh_edge_loss_output(self):
        """
        Check outputs of tensorized and iterative implementations are the same.
        """
        device = torch.device("cuda:0")
        target_length = 0.5
        num_meshes = 10
        num_verts = 32
        num_faces = 64

        verts_list = []
        faces_list = []
        # Randomly choose which batch entries are real meshes and which are
        # empty, so both kinds appear in the same batch.
        valid = torch.randint(2, size=(num_meshes,))

        for n in range(num_meshes):
            if valid[n]:
                vn = torch.randint(3, high=num_verts, size=(1,))[0].item()
                fn = torch.randint(vn, high=num_faces, size=(1,))[0].item()
                verts = torch.rand((vn, 3), dtype=torch.float32, device=device)
                # Face indices are drawn from [0, vn) so they index real verts.
                faces = torch.randint(
                    vn, size=(fn, 3), dtype=torch.int64, device=device
                )
            else:
                verts = torch.tensor([], dtype=torch.float32, device=device)
                faces = torch.tensor([], dtype=torch.int64, device=device)
            verts_list.append(verts)
            faces_list.append(faces)

        meshes = Meshes(verts=verts_list, faces=faces_list)
        loss = mesh_edge_loss(meshes, target_length=target_length)

        predloss = TestMeshEdgeLoss.mesh_edge_loss_naive(meshes, target_length)
        self.assertClose(loss, predloss)

    @staticmethod
    def mesh_edge_loss(num_meshes: int = 10, max_v: int = 100, max_f: int = 300):
        """Build a zero-argument callable that runs ``mesh_edge_loss`` once.

        The meshes are created up front with ``init_meshes`` on ``cuda:0`` and
        the device is synchronized, so the returned callable covers only the
        loss computation and the synchronize that follows it.

        Args:
            num_meshes: forwarded to ``init_meshes``.
            max_v: forwarded to ``init_meshes``.
            max_f: forwarded to ``init_meshes``.

        Returns:
            A callable taking no arguments and returning ``None``.
        """
        meshes = init_meshes(num_meshes, max_v, max_f, device="cuda:0")
        torch.cuda.synchronize()

        def compute_loss():
            # Class-body names are not visible from nested function scopes, so
            # this resolves to the module-level ``mesh_edge_loss`` import and
            # not to this method of the same name.
            mesh_edge_loss(meshes, target_length=0.0)
            torch.cuda.synchronize()

        return compute_loss