# Linly-Talker / pytorch3d / tests / test_mesh_edge_loss.py
# Uploaded by linxianzhong0128 using huggingface_hub (commit 7088d16, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.loss import mesh_edge_loss
from pytorch3d.structures import Meshes
from .common_testing import TestCaseMixin
from .test_sample_points_from_meshes import init_meshes
class TestMeshEdgeLoss(TestCaseMixin, unittest.TestCase):
    """Tests and a benchmark factory for ``pytorch3d.loss.mesh_edge_loss``."""

    def setUp(self) -> None:
        super().setUp()
        # Seed the global RNG so the randomly generated meshes (and any
        # failure they trigger) are reproducible from run to run.
        torch.manual_seed(1)

    @staticmethod
    def _get_device() -> torch.device:
        """
        Return the device the correctness tests run on.

        Uses ``cuda:0`` when CUDA is available and falls back to the CPU
        otherwise, so the correctness tests can still run without a GPU.
        """
        if torch.cuda.is_available():
            return torch.device("cuda:0")
        return torch.device("cpu")

    def test_empty_meshes(self):
        """
        A batch whose meshes have vertices but no faces must produce a loss
        equal to ``[0.0]`` that still requires grad.
        """
        device = self._get_device()
        target_length = 0
        N = 10
        V = 32
        verts_list = []
        faces_list = []
        for _ in range(N):
            vn = torch.randint(3, high=V, size=(1,))[0].item()
            verts = torch.rand((vn, 3), dtype=torch.float32, device=device)
            faces = torch.tensor([], dtype=torch.int64, device=device)
            verts_list.append(verts)
            faces_list.append(faces)
        mesh = Meshes(verts=verts_list, faces=faces_list)
        loss = mesh_edge_loss(mesh, target_length=target_length)
        self.assertClose(loss, torch.tensor([0.0], dtype=torch.float32, device=device))
        # The loss must stay differentiable even when there is nothing to penalize.
        self.assertTrue(loss.requires_grad)

    @staticmethod
    def mesh_edge_loss_naive(meshes, target_length: float = 0.0):
        """
        Naive iterative implementation of mesh loss calculation.

        For every valid mesh, sums ``(||v0 - v1||_2 - target_length) ** 2``
        over its edges one edge at a time and divides by the edge count (when
        non-zero). Returns the mean of these per-mesh values over all
        ``len(meshes)`` meshes; invalid meshes contribute 0 to that mean.

        Args:
            meshes: a ``Meshes`` batch.
            target_length: length each edge is compared against.

        Returns:
            A 0-dim tensor holding the batch loss.
        """
        edges_packed = meshes.edges_packed()
        verts_packed = meshes.verts_packed()
        edge_to_mesh = meshes.edges_packed_to_mesh_idx()
        N = len(meshes)
        device = meshes.device
        valid = meshes.valid
        predlosses = torch.zeros((N,), dtype=torch.float32, device=device)

        for b in range(N):
            if valid[b] == 0:
                continue
            mesh_edges = edges_packed[edge_to_mesh == b]
            verts_edges = verts_packed[mesh_edges]
            num_edges = mesh_edges.size(0)
            for e in range(num_edges):
                v0, v1 = verts_edges[e, 0], verts_edges[e, 1]
                predlosses[b] += ((v0 - v1).norm(dim=0, p=2) - target_length) ** 2.0

            if num_edges > 0:
                predlosses[b] = predlosses[b] / num_edges

        return predlosses.mean()

    def test_mesh_edge_loss_output(self):
        """
        Check outputs of tensorized and iterative implementations are the same.
        """
        device = self._get_device()
        target_length = 0.5
        num_meshes = 10
        num_verts = 32
        num_faces = 64
        verts_list = []
        faces_list = []
        valid = torch.randint(2, size=(num_meshes,))
        if not valid.any():
            # An all-empty batch makes mesh_edge_loss take its empty-input path
            # (shape (1,)) while the naive reference returns a 0-dim mean, so
            # the comparison would fail on shape alone. That case is covered by
            # test_empty_meshes; keep at least one non-empty mesh here.
            valid[torch.randint(num_meshes, size=(1,))] = 1
        for n in range(num_meshes):
            if valid[n]:
                vn = torch.randint(3, high=num_verts, size=(1,))[0].item()
                fn = torch.randint(vn, high=num_faces, size=(1,))[0].item()
                verts = torch.rand((vn, 3), dtype=torch.float32, device=device)
                faces = torch.randint(
                    vn, size=(fn, 3), dtype=torch.int64, device=device
                )
            else:
                verts = torch.tensor([], dtype=torch.float32, device=device)
                faces = torch.tensor([], dtype=torch.int64, device=device)
            verts_list.append(verts)
            faces_list.append(faces)
        meshes = Meshes(verts=verts_list, faces=faces_list)
        loss = mesh_edge_loss(meshes, target_length=target_length)

        predloss = TestMeshEdgeLoss.mesh_edge_loss_naive(meshes, target_length)
        self.assertClose(loss, predloss)

    @staticmethod
    def mesh_edge_loss(num_meshes: int = 10, max_v: int = 100, max_f: int = 300):
        """
        Build a benchmark closure for ``mesh_edge_loss`` on ``cuda:0``.

        Meshes are created once via ``init_meshes`` (presumably ``num_meshes``
        meshes bounded by ``max_v`` vertices / ``max_f`` faces — confirm in
        test_sample_points_from_meshes). The returned callable evaluates the
        loss and synchronizes CUDA so that timings include the kernel work.
        """
        meshes = init_meshes(num_meshes, max_v, max_f, device="cuda:0")
        torch.cuda.synchronize()

        def compute_loss():
            # Nested functions do not see class scope, so this name resolves
            # to the imported pytorch3d.loss.mesh_edge_loss, not this
            # staticmethod. The result is discarded on purpose (timing only).
            mesh_edge_loss(meshes, target_length=0.0)
            torch.cuda.synchronize()

        return compute_loss