import torch
import torch.nn as nn

from dockformerpp.model.primitives import Linear
from dockformerpp.utils.geometry.rigid_matrix_vector import Rigid3Array
from dockformerpp.utils.geometry.rotation_matrix import Rot3Array
from dockformerpp.utils.geometry.vector import Vec3Array


class QuatRigid(nn.Module):
    """Project activations to a quaternion plus translation and build a rigid.

    A single linear layer maps the input to either 7 values (four quaternion
    components followed by three translation components) or 6 values (three
    quaternion components followed by three translation components), depending
    on ``full_quat``.
    """

    def __init__(self, c_hidden, full_quat):
        """Set up the projection layer.

        Args:
            c_hidden: Input feature size handed to ``Linear``.
            full_quat: When truthy, all four quaternion components are
                predicted (7 outputs). Otherwise only ``qx, qy, qz`` are
                predicted (6 outputs) and ``qw`` is filled with ones in
                ``forward``.
        """
        super().__init__()
        self.full_quat = full_quat

        # 4 (or 3) quaternion components + 3 translation components.
        out_features = 7 if self.full_quat else 6
        self.linear = Linear(
            c_hidden,
            out_features,
            init="final",
            precision=torch.float32,
        )

    def forward(self, activations: torch.Tensor) -> Rigid3Array:
        """Turn activations into a ``Rigid3Array``.

        Args:
            activations: Tensor fed to the linear projection
                (presumably ``[..., c_hidden]`` -- confirm against callers).

        Returns:
            A ``Rigid3Array`` made from a ``Rot3Array`` built with
            ``from_quaternion(..., normalize=True)`` and a ``Vec3Array`` built
            from the trailing translation components.
        """
        # NOTE: During training, this needs to be run in higher precision
        projected = self.linear(activations)

        # Split the last dimension into one tensor per predicted component.
        components = torch.unbind(projected, dim=-1)

        if self.full_quat:
            qw, qx, qy, qz, *shift = components
        else:
            qx, qy, qz, *shift = components
            # The real part is not predicted in this mode; use ones.
            qw = torch.ones_like(qx)

        rot = Rot3Array.from_quaternion(qw, qx, qy, qz, normalize=True)
        trans = Vec3Array(*shift)
        return Rigid3Array(rot, trans)