Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from dockformerpp.model.primitives import Linear | |
from dockformerpp.utils.geometry.rigid_matrix_vector import Rigid3Array | |
from dockformerpp.utils.geometry.rotation_matrix import Rot3Array | |
from dockformerpp.utils.geometry.vector import Vec3Array | |
class QuatRigid(nn.Module):
    """Predict a rigid transform (rotation + translation) from activations.

    A single ``Linear`` layer maps the input to either 7 values (quaternion
    ``qw, qx, qy, qz`` followed by a 3-component translation) or 6 values
    (quaternion ``qx, qy, qz`` with ``qw`` fixed to 1, followed by a
    3-component translation). The quaternion is normalized when it is turned
    into a rotation.
    """

    def __init__(self, c_hidden, full_quat):
        """
        Args:
            c_hidden: Input feature size, passed as the first argument to
                ``Linear`` (presumably the last dim of ``activations`` in
                ``forward`` -- confirm against ``Linear``'s signature).
            full_quat: If truthy, predict all four quaternion components;
                otherwise predict only ``qx, qy, qz`` and use ``qw = 1``.
        """
        super().__init__()
        self.full_quat = full_quat
        # 4 quaternion + 3 translation outputs, or 3 + 3 when qw is implicit.
        out_dim = 7 if full_quat else 6
        self.linear = Linear(
            c_hidden,
            out_dim,
            init="final",
            precision=torch.float32,
        )

    def forward(self, activations: torch.Tensor) -> Rigid3Array:
        """Map ``activations`` to a ``Rigid3Array``.

        Args:
            activations: Tensor fed to ``self.linear``; the resulting last
                dimension is split into quaternion and translation parts.

        Returns:
            ``Rigid3Array`` built from the normalized-quaternion rotation and
            the predicted translation vector.
        """
        # NOTE: During training, this needs to be run in higher precision
        components = torch.unbind(self.linear(activations), dim=-1)

        if self.full_quat:
            qw, qx, qy, qz, *trans_xyz = components
        else:
            qx, qy, qz, *trans_xyz = components
            # Only qx, qy, qz are predicted; fix qw to 1 before normalization.
            qw = torch.ones_like(qx)

        rotation = Rot3Array.from_quaternion(
            qw, qx, qy, qz, normalize=True,
        )
        return Rigid3Array(rotation, Vec3Array(*trans_xyz))