Instructions to use Synthyra/Boltz2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/Boltz2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Synthyra/Boltz2", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/Boltz2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from abc import ABC, abstractmethod | |
| from typing import Optional, Dict, Any, Set, List, Union | |
| import torch | |
| import numpy as np | |
| from . import vb_const as const | |
| from .vb_potentials_schedules import ( | |
| ParameterSchedule, | |
| ExponentialInterpolation, | |
| PiecewiseStepFunction, | |
| ) | |
| from .vb_loss_diffusionv2 import weighted_rigid_align | |
class Potential(ABC):
    """Base class for steering potentials evaluated on sampled coordinates.

    A concrete potential combines three hooks (usually supplied by the mixin
    classes in this module):

    * ``compute_args(feats, parameters)`` returns a 5-tuple
      ``(index, args, com_args, ref_args, operator_args)``. ``index`` selects
      the atoms of every term; ``args`` is forwarded to ``compute_function``.
      The remaining entries are optional (``None``) groups that request
      centre-of-mass pooling, a reference structure, or negation/union
      operators.
    * ``compute_variable(coords, index, ref_coords=..., ref_mask=...,
      compute_gradient=...)`` maps coordinates to a per-term variable, and
      optionally to its gradient with respect to the atoms in ``index``.
    * ``compute_function(value, *args, negation_mask=...,
      compute_derivative=...)`` maps that variable to an energy, and
      optionally to ``dEnergy / dvalue``.

    ``parameters`` maps names to constants or ``ParameterSchedule`` objects;
    ``compute_parameters(t)`` evaluates the schedules at time ``t``.
    """

    def __init__(
        self,
        parameters: Optional[
            Dict[str, Union[ParameterSchedule, float, int, bool]]
        ] = None,
    ):
        self.parameters = parameters

    @staticmethod
    def _prepare_inputs(coords, com_args, ref_args, operator_args):
        """Apply the optional transforms requested by ``compute_args``.

        Returns ``(coords, com_index, ref_coords, ref_mask, ref_token_index,
        negation_mask, union_index)``. Every entry except ``coords`` is
        ``None`` when its argument group is ``None``.
        """
        com_index = None
        if com_args is not None:
            com_index, atom_pad_mask = com_args
            unpad_com_index = com_index[atom_pad_mask]
            unpad_coords = coords[..., atom_pad_mask, :]
            # Replace atoms by the mean position of each com_index group.
            # include_self=False keeps the zero-initialised destination out of
            # the mean. With the default (True), every centroid would be
            # divided by (count + 1) and pulled toward the origin.
            coords = torch.zeros(
                (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3),
                device=coords.device,
                dtype=unpad_coords.dtype,
            ).scatter_reduce(
                -2,
                unpad_com_index.unsqueeze(-1).expand_as(unpad_coords),
                unpad_coords,
                "mean",
                include_self=False,
            )

        ref_coords = ref_mask = ref_token_index = None
        if ref_args is not None:
            ref_coords, ref_mask, ref_atom_index, ref_token_index = ref_args
            # Keep only the atoms that the reference refers to.
            coords = coords[..., ref_atom_index, :]

        negation_mask = union_index = None
        if operator_args is not None:
            negation_mask, union_index = operator_args

        return (
            coords,
            com_index,
            ref_coords,
            ref_mask,
            ref_token_index,
            negation_mask,
            union_index,
        )

    @staticmethod
    def _union_softmax(energy, union_index, union_lambda):
        """Softmax of ``-union_lambda * energy`` within each ``union_index`` group.

        Entries whose group normaliser is zero get weight 0 instead of NaN.
        """
        neg_exp_energy = torch.exp(-1 * union_lambda * energy)
        Z = torch.zeros(
            (*energy.shape[:-1], union_index.max() + 1),
            device=union_index.device,
            dtype=neg_exp_energy.dtype,
        ).scatter_reduce(
            -1,
            union_index.expand_as(energy),
            neg_exp_energy,
            "sum",
        )
        Z_per_term = Z[..., union_index]
        softmax_energy = neg_exp_energy / Z_per_term
        softmax_energy[Z_per_term == 0] = 0
        return softmax_energy

    def compute(self, coords, feats, parameters):
        """Return the total energy for each leading (sample) entry of ``coords``.

        Without a union operator, the energies of all terms are summed over
        every dimension except the first. With a union operator, each term's
        energy is weighted by its within-group softmax weight and summed over
        the last dimension.
        """
        index, args, com_args, ref_args, operator_args = self.compute_args(
            feats, parameters
        )
        if index.shape[1] == 0:
            return torch.zeros(coords.shape[:-2], device=coords.device)

        (
            coords,
            _,
            ref_coords,
            ref_mask,
            _,
            negation_mask,
            union_index,
        ) = self._prepare_inputs(coords, com_args, ref_args, operator_args)

        value = self.compute_variable(
            coords,
            index,
            ref_coords=ref_coords,
            ref_mask=ref_mask,
            compute_gradient=False,
        )
        energy = self.compute_function(
            value, *args, negation_mask=negation_mask, compute_derivative=False
        )

        if union_index is not None:
            softmax_energy = self._union_softmax(
                energy, union_index, parameters["union_lambda"]
            )
            return (energy * softmax_energy).sum(dim=-1)
        return energy.sum(dim=tuple(range(1, energy.dim())))

    def compute_gradient(self, coords, feats, parameters):
        """Return the gradient of the energy with respect to ``coords``.

        Gradients computed on pooled (COM) or reference-indexed coordinates
        are broadcast back to atom rows via ``com_index`` /
        ``ref_token_index``. When there are no terms, zeros shaped like
        ``coords`` are returned.
        """
        index, args, com_args, ref_args, operator_args = self.compute_args(
            feats, parameters
        )
        if index.shape[1] == 0:
            return torch.zeros_like(coords)

        (
            coords,
            com_index,
            ref_coords,
            ref_mask,
            ref_token_index,
            negation_mask,
            union_index,
        ) = self._prepare_inputs(coords, com_args, ref_args, operator_args)

        value, grad_value = self.compute_variable(
            coords,
            index,
            ref_coords=ref_coords,
            ref_mask=ref_mask,
            compute_gradient=True,
        )
        energy, dEnergy = self.compute_function(
            value, *args, negation_mask=negation_mask, compute_derivative=True
        )

        if union_index is not None:
            union_lambda = parameters["union_lambda"]
            softmax_energy = self._union_softmax(energy, union_index, union_lambda)
            # f: softmax-weighted energy of each union group.
            f = torch.zeros(
                (*energy.shape[:-1], union_index.max() + 1),
                device=union_index.device,
                dtype=energy.dtype,
            ).scatter_reduce(
                -1,
                union_index.expand_as(energy),
                energy * softmax_energy,
                "sum",
            )
            # NOTE(review): the exact derivative of sum(E * softmax(-lambda*E))
            # with respect to E is softmax * (1 - lambda * (E - f)). This code
            # uses "1 + ...", so the guidance gradient does not match the energy
            # returned by `compute`. It is kept unchanged; confirm intent.
            dEnergy = (
                dEnergy
                * softmax_energy
                * (1 + union_lambda * (energy - f[..., union_index]))
            )

        # Chain rule: repeat dE/dvalue once per atom slot of each term, then
        # multiply by that slot's d(value)/d(x).
        prod = dEnergy.tile(grad_value.shape[-3]).unsqueeze(
            -1
        ) * grad_value.flatten(start_dim=-3, end_dim=-2)
        if prod.dim() > 3:
            prod = prod.sum(dim=list(range(1, prod.dim() - 2)))

        # Accumulate per-term contributions onto the atoms they touch.
        grad_atom = torch.zeros_like(coords).scatter_reduce(
            -2,
            index.flatten(start_dim=0, end_dim=1)
            .unsqueeze(-1)
            .expand((*coords.shape[:-2], -1, 3)),
            prod,
            "sum",
        )

        if com_index is not None:
            grad_atom = grad_atom[..., com_index, :]
        elif ref_token_index is not None:
            grad_atom = grad_atom[..., ref_token_index, :]
        return grad_atom

    def compute_parameters(self, t):
        """Evaluate every ``ParameterSchedule`` at ``t``; pass constants through.

        Returns ``None`` when the potential has no parameters.
        """
        if self.parameters is None:
            return None
        parameters = {
            name: parameter
            if not isinstance(parameter, ParameterSchedule)
            else parameter.compute(t)
            for name, parameter in self.parameters.items()
        }
        return parameters

    def compute_function(
        self, value, *args, negation_mask=None, compute_derivative=False
    ):
        """Map the variable to energy (and dEnergy/dvalue if requested)."""
        raise NotImplementedError

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Map coordinates to the per-term variable (and its gradient if requested)."""
        raise NotImplementedError

    def compute_args(self, feats, parameters):
        """Return ``(index, args, com_args, ref_args, operator_args)``."""
        raise NotImplementedError

    def get_reference_coords(self, feats, parameters):
        return None, None
class FlatBottomPotential(Potential):
    """Linear penalty outside an interval ``[lower_bounds, upper_bounds]``, zero inside."""

    def compute_function(
        self,
        value,
        k,
        lower_bounds,
        upper_bounds,
        negation_mask=None,
        compute_derivative=False,
    ):
        """Evaluate the flat-bottom energy of ``value``.

        Args:
            value: Per-term variable.
            k: Tensor of slopes, broadcastable to ``value``.
            lower_bounds / upper_bounds: Bounds broadcastable to ``value``;
                ``None`` means unbounded on that side.
            negation_mask: Optional mask. Entries where it is False must be
                one-sided, and their bound is mirrored so that the penalty
                applies on the opposite side.
            compute_derivative: Also return ``dEnergy / dvalue``.

        Returns:
            ``energy``, or ``(energy, dEnergy)`` when ``compute_derivative``.
        """
        lo = (
            torch.full_like(value, float("-inf"))
            if lower_bounds is None
            else lower_bounds
        )
        hi = (
            torch.full_like(value, float("inf"))
            if upper_bounds is None
            else upper_bounds
        )
        lo = lo.expand_as(value).clone()
        hi = hi.expand_as(value).clone()

        if negation_mask is not None:
            open_below = torch.isneginf(lo)
            open_above = torch.isposinf(hi)
            # Only one-sided (or unbounded) terms may be negated.
            assert torch.all((open_below + open_above) + negation_mask)
            flip_upper = ~open_above * ~negation_mask
            flip_lower = ~open_below * ~negation_mask
            # "value <= hi" becomes "value >= hi".
            lo[flip_upper] = hi[flip_upper]
            hi[flip_upper] = float("inf")
            # "value >= lo" becomes "value <= lo".
            hi[flip_lower] = lo[flip_lower]
            lo[flip_lower] = float("-inf")

        below = value < lo
        above = value > hi

        energy = torch.zeros_like(value)
        energy[below] = (k * (lo - value))[below]
        energy[above] = (k * (value - hi))[above]
        if not compute_derivative:
            return energy

        dEnergy = torch.zeros_like(value)
        dEnergy[below] = -k.expand_as(below)[below]
        dEnergy[above] = k.expand_as(above)[above]
        return energy, dEnergy
class ReferencePotential(Potential):
    """Distance between each indexed position and a rigidly aligned reference."""

    def compute_variable(
        self, coords, index, ref_coords, ref_mask, compute_gradient=False
    ):
        """Return the per-position deviation from the aligned reference.

        ``ref_coords`` is aligned onto ``coords[:, index]`` with
        ``weighted_rigid_align``, using ``ref_mask`` both as weights and as
        mask. When ``compute_gradient`` is set, the gradient is the masked
        unit displacement vector.
        """
        aligned = weighted_rigid_align(
            ref_coords.float(),
            coords[:, index].float(),
            ref_mask,
            ref_mask,
        )
        delta = coords[:, index] - aligned
        dist = torch.linalg.norm(delta, dim=-1)
        if compute_gradient:
            direction = delta / dist[..., None]
            masked = direction * ref_mask[..., None]
            return dist, masked[:, None]
        return dist
class DistancePotential(Potential):
    """Euclidean distance between atom pairs ``index[0]`` and ``index[1]``."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return pair distances and, optionally, their gradients.

        ``ref_coords`` / ``ref_mask`` are accepted for interface compatibility
        and ignored. The gradient is stacked along dim 1 as
        ``(d|r_ij|/dx_i, d|r_ij|/dx_j) = (r_hat_ij, -r_hat_ij)``.
        """
        r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1])
        r_ij_norm = torch.linalg.norm(r_ij, dim=-1)
        if not compute_gradient:
            # Energy-only path: skip building unit vectors that would be discarded.
            return r_ij_norm
        r_hat_ij = r_ij / r_ij_norm.unsqueeze(-1)
        grad = torch.stack((r_hat_ij, -1 * r_hat_ij), dim=1)
        return r_ij_norm, grad
class DihedralPotential(Potential):
    """Signed dihedral angle defined by atoms ``index[0]``..``index[3]``."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return dihedral angles and, optionally, their analytic gradients.

        The angle is ``sign * arccos(cos)``, computed from the normals of the
        (i, j, k) and (j, k, l) planes. The gradient is stacked along dim 1 in
        atom order (i, j, k, l). ``ref_coords`` / ``ref_mask`` are ignored.
        """

        def dot(u, v):
            # Batched dot product over the last axis, expressed as a matmul.
            return (u.unsqueeze(-2) @ v.unsqueeze(-1)).squeeze(-1, -2)

        x_i, x_j, x_k, x_l = (coords.index_select(-2, index[n]) for n in range(4))
        r_ij = x_i - x_j
        r_kj = x_k - x_j
        r_kl = x_k - x_l

        n_ijk = torch.cross(r_ij, r_kj, dim=-1)
        n_jkl = torch.cross(r_kj, r_kl, dim=-1)
        len_kj = torch.linalg.norm(r_kj, dim=-1)
        len_n_ijk = torch.linalg.norm(n_ijk, dim=-1)
        len_n_jkl = torch.linalg.norm(n_jkl, dim=-1)

        orientation = torch.sign(dot(r_kj, torch.cross(n_ijk, n_jkl, dim=-1)))
        cos_phi = torch.clamp(
            dot(n_ijk, n_jkl) / (len_n_ijk * len_n_jkl),
            -1 + 1e-8,
            1 - 1e-8,
        )
        phi = orientation * torch.arccos(cos_phi)
        if not compute_gradient:
            return phi

        len_kj_sq = len_kj**2
        a = (dot(r_ij, r_kj) / len_kj_sq).unsqueeze(-1)
        b = (dot(r_kl, r_kj) / len_kj_sq).unsqueeze(-1)
        grad_i = n_ijk * (len_kj / len_n_ijk**2).unsqueeze(-1)
        grad_l = -1 * n_jkl * (len_kj / len_n_jkl**2).unsqueeze(-1)
        grad_j = (a - 1) * grad_i - b * grad_l
        grad_k = (b - 1) * grad_l - a * grad_i
        return phi, torch.stack((grad_i, grad_j, grad_k, grad_l), dim=1)
class AbsDihedralPotential(DihedralPotential):
    """Absolute value of the dihedral angle computed by ``DihedralPotential``."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return ``|phi|`` and, optionally, its gradient.

        The gradient flips sign wherever the underlying signed angle is negative.
        """
        if compute_gradient:
            phi, grad = super().compute_variable(
                coords, index, compute_gradient=True
            )
            negative = (phi < 0)[..., None, :, None].expand_as(grad)
            grad = torch.where(negative, -grad, grad)
            return torch.abs(phi), grad
        signed_phi = super().compute_variable(coords, index, compute_gradient=False)
        return torch.abs(signed_phi)
class PoseBustersPotential(FlatBottomPotential, DistancePotential):
    """Flat-bottom distance restraints built from the RDKit bounds features."""

    def compute_args(self, feats, parameters):
        """Widen RDKit bounds by per-category buffers and clamp them with VdW cutoffs.

        Pair categories come from the bond and angle masks:
        bond-only, angle-only, bond-and-angle, or neither (clash). Clash pairs
        keep only a lower bound. Each pair's cutoff is ``0.35 + mean VdW
        radius``. Non-bonded pairs get a lower bound of at least the cutoff,
        and bonded pairs an upper bound of at most the cutoff.
        """
        pair_index = feats["rdkit_bounds_index"][0]
        lower_bounds = feats["rdkit_lower_bounds"][0].clone()
        upper_bounds = feats["rdkit_upper_bounds"][0].clone()
        bond_mask = feats["rdkit_bounds_bond_mask"][0]
        angle_mask = feats["rdkit_bounds_angle_mask"][0]

        bond_buffer = parameters["bond_buffer"]
        angle_buffer = parameters["angle_buffer"]
        shared_buffer = min(bond_buffer, angle_buffer)

        bond_only = bond_mask * ~angle_mask
        angle_only = ~bond_mask * angle_mask
        bond_and_angle = bond_mask * angle_mask
        neither = ~bond_mask * ~angle_mask

        lower_bounds[bond_only] *= 1.0 - bond_buffer
        upper_bounds[bond_only] *= 1.0 + bond_buffer
        lower_bounds[angle_only] *= 1.0 - angle_buffer
        upper_bounds[angle_only] *= 1.0 + angle_buffer
        lower_bounds[bond_and_angle] *= 1.0 - shared_buffer
        upper_bounds[bond_and_angle] *= 1.0 + shared_buffer
        lower_bounds[neither] *= 1.0 - parameters["clash_buffer"]
        upper_bounds[neither] = float("inf")

        device = pair_index.device
        vdw_radii = torch.zeros(const.num_elements, dtype=torch.float32, device=device)
        vdw_radii[1:119] = torch.tensor(
            const.vdw_radii, dtype=torch.float32, device=device
        )
        # Per-atom radius, presumably via a one-hot element encoding in
        # "ref_element" (TODO confirm).
        atom_vdw_radii = (
            feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)
        ).squeeze(-1)[0]
        bond_cutoffs = 0.35 + atom_vdw_radii[pair_index].mean(dim=0)

        non_bond = ~bond_mask
        lower_bounds[non_bond] = torch.max(
            lower_bounds[non_bond], bond_cutoffs[non_bond]
        )
        upper_bounds[bond_mask] = torch.min(
            upper_bounds[bond_mask], bond_cutoffs[bond_mask]
        )

        k = torch.ones_like(lower_bounds)
        return pair_index, (k, lower_bounds, upper_bounds), None, None, None
class ConnectionsPotential(FlatBottomPotential, DistancePotential):
    """Upper-bound the distance of each connected atom pair by ``parameters["buffer"]``."""

    def compute_args(self, feats, parameters):
        pair_index = feats["connected_atom_index"][0]
        num_pairs = pair_index.shape[1]
        upper_bounds = torch.full(
            (num_pairs,), parameters["buffer"], device=pair_index.device
        )
        k = torch.ones_like(upper_bounds)
        # No lower bound: only over-stretched connections are penalised.
        return pair_index, (k, None, upper_bounds), None, None, None
class VDWOverlapPotential(FlatBottomPotential, DistancePotential):
    """Penalise inter-chain atom pairs closer than their (buffered) summed VdW radii."""

    def compute_args(self, feats, parameters):
        """Build lower-bounded pair restraints over all atom pairs.

        A pair is kept only if:
        * both atoms are non-padding;
        * both atoms belong to chains with more than one non-padding atom;
        * the two chains are different and not listed in
          ``connected_chain_index``.
        """
        atom_chain_id = (
            torch.bmm(
                feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
            )
            .squeeze(-1)
            .long()
        )[0]
        device = atom_chain_id.device
        atom_pad_mask = feats["atom_pad_mask"][0].bool()
        chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask])
        # True for atoms whose chain has more than one atom (not a lone ion).
        in_multi_atom_chain = (chain_sizes > 1)[atom_chain_id]

        vdw_radii = torch.zeros(const.num_elements, dtype=torch.float32, device=device)
        vdw_radii[1:119] = torch.tensor(
            const.vdw_radii, dtype=torch.float32, device=device
        )
        atom_vdw_radii = (
            feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)
        ).squeeze(-1)[0]

        num_atoms = atom_chain_id.shape[0]
        pair_index = torch.triu_indices(num_atoms, num_atoms, 1, device=device)
        first, second = pair_index[0], pair_index[1]
        pair_pad_mask = atom_pad_mask[pair_index].all(dim=0)
        pair_ion_mask = in_multi_atom_chain[first] * in_multi_atom_chain[second]

        # Chain adjacency: self-connections plus explicitly connected chains.
        num_chains = atom_chain_id.max() + 1
        connected_chain_index = feats["connected_chain_index"][0]
        connected = torch.eye(num_chains, device=device, dtype=torch.bool)
        connected[connected_chain_index[0], connected_chain_index[1]] = True
        connected[connected_chain_index[1], connected_chain_index[0]] = True
        connected_chain_mask = connected[atom_chain_id[first], atom_chain_id[second]]

        keep = pair_pad_mask * pair_ion_mask * ~connected_chain_mask
        pair_index = pair_index[:, keep]
        lower_bounds = atom_vdw_radii[pair_index].sum(dim=0) * (
            1.0 - parameters["buffer"]
        )
        k = torch.ones_like(lower_bounds)
        return pair_index, (k, lower_bounds, None), None, None, None
class SymmetricChainCOMPotential(FlatBottomPotential, DistancePotential):
    """Keep centres of mass of symmetric chain pairs at least ``buffer`` apart."""

    def compute_args(self, feats, parameters):
        """Return chain-pair restraints plus COM pooling arguments.

        Pairs involving a chain with at most one non-padding atom are dropped.
        ``com_args`` is ``(atom_chain_id, atom_pad_mask)``, so ``Potential``
        pools atom coordinates per chain before measuring distances.
        """
        atom_chain_id = (
            torch.bmm(
                feats["atom_to_token"].float(), feats["asym_id"].unsqueeze(-1).float()
            )
            .squeeze(-1)
            .long()
        )[0]
        atom_pad_mask = feats["atom_pad_mask"][0].bool()
        chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask])
        multi_atom_chain = chain_sizes > 1

        pair_index = feats["symmetric_chain_index"][0]
        keep = multi_atom_chain[pair_index[0]] * multi_atom_chain[pair_index[1]]
        pair_index = pair_index[:, keep]

        lower_bounds = torch.full(
            (pair_index.shape[1],),
            parameters["buffer"],
            dtype=torch.float32,
            device=pair_index.device,
        )
        k = torch.ones_like(lower_bounds)
        com_args = (atom_chain_id, atom_pad_mask)
        return pair_index, (k, lower_bounds, None), com_args, None, None
class StereoBondPotential(FlatBottomPotential, AbsDihedralPotential):
    """Restrain ``|dihedral|`` of stereo bonds according to their orientation flag."""

    def compute_args(self, feats, parameters):
        """Flagged bonds need ``|phi| >= pi - buffer``; unflagged need ``|phi| <= buffer``."""
        stereo_bond_index = feats["stereo_bond_index"][0]
        orientation = feats["stereo_bond_orientations"][0].bool()
        buffer = parameters["buffer"]

        lower_bounds = torch.full(
            orientation.shape, float("-inf"), device=orientation.device
        )
        upper_bounds = torch.full(
            orientation.shape, float("inf"), device=orientation.device
        )
        lower_bounds[orientation] = torch.pi - buffer
        upper_bounds[~orientation] = buffer

        k = torch.ones_like(lower_bounds)
        return stereo_bond_index, (k, lower_bounds, upper_bounds), None, None, None
class ChiralAtomPotential(FlatBottomPotential, DihedralPotential):
    """Restrain the signed improper dihedral of chiral centres by orientation flag."""

    def compute_args(self, feats, parameters):
        """Flagged atoms need ``phi >= buffer``; unflagged need ``phi <= -buffer``."""
        chiral_atom_index = feats["chiral_atom_index"][0]
        orientation = feats["chiral_atom_orientations"][0].bool()
        buffer = parameters["buffer"]

        lower_bounds = torch.full(
            orientation.shape, float("-inf"), device=orientation.device
        )
        upper_bounds = torch.full(
            orientation.shape, float("inf"), device=orientation.device
        )
        lower_bounds[orientation] = buffer
        upper_bounds[~orientation] = -1 * buffer

        k = torch.ones_like(lower_bounds)
        return chiral_atom_index, (k, lower_bounds, upper_bounds), None, None, None
class PlanarBondPotential(FlatBottomPotential, AbsDihedralPotential):
    """Keep two improper dihedrals per planar bond within ``buffer`` of zero."""

    def compute_args(self, feats, parameters):
        """Expand each planar-bond record into two 4-atom improper dihedrals.

        The impropers use positions (1, 2, 3, 0) and (4, 5, 0, 3) of each
        planar-bond record. Layout of ``planar_bond_index`` is presumably
        (6, num_bonds) after ``[0]`` — TODO confirm.
        """
        double_bond_index = feats["planar_bond_index"][0].T
        improper_template = torch.tensor(
            [[1, 2, 3, 0], [4, 5, 0, 3]],
            device=double_bond_index.device,
        ).T
        gathered = double_bond_index[:, improper_template]
        improper_index = gathered.swapaxes(0, 1).flatten(start_dim=1)

        num_impropers = improper_index.shape[1]
        upper_bounds = torch.full(
            (num_impropers,), parameters["buffer"], device=improper_index.device
        )
        k = torch.ones_like(upper_bounds)
        return improper_index, (k, None, upper_bounds), None, None, None
class TemplateReferencePotential(FlatBottomPotential, ReferencePotential):
    """Pull representative atoms toward forced template positions within a threshold."""

    def compute_args(self, feats, parameters):
        """Return reference restraints for templates selected by ``template_force``.

        An empty ``[1, 0]`` index (no terms) is returned when template features
        are absent or no template is forced. Otherwise every masked template
        position gets upper bound ``template_force_threshold`` of its template;
        unmasked positions are unbounded.
        """
        if "template_mask_cb" not in feats or "template_force" not in feats:
            return torch.empty([1, 0]), None, None, None, None

        force = feats["template_force"]
        template_mask = feats["template_mask_cb"][force]
        if template_mask.shape[0] == 0:
            return torch.empty([1, 0]), None, None, None, None

        ref_coords = feats["template_cb"][force].clone()
        ref_mask = feats["template_mask_cb"][force].clone()

        pad_mask = feats["atom_pad_mask"]
        atom_positions = torch.arange(
            pad_mask.shape[1], device=pad_mask.device, dtype=torch.float32
        )
        # Atom index of each token's representative atom.
        ref_atom_index = (
            torch.bmm(feats["token_to_rep_atom"].float(), atom_positions[None, :, None])
            .squeeze(-1)
            .long()
        )[0]
        # Token index owning each atom (maps gradients back to atoms).
        ref_token_index = (
            torch.bmm(
                feats["atom_to_token"].float(),
                feats["token_index"].unsqueeze(-1).float(),
            )
            .squeeze(-1)
            .long()
        )[0]

        index = torch.arange(
            template_mask.shape[-1], dtype=torch.long, device=template_mask.device
        )[None]
        upper_bounds = torch.full(
            template_mask.shape, float("inf"), device=index.device, dtype=torch.float32
        )
        ref_idxs = torch.argwhere(template_mask).T
        thresholds = feats["template_force_threshold"][force]
        upper_bounds[ref_idxs.unbind()] = thresholds[ref_idxs[0]]

        k = torch.ones_like(upper_bounds)
        ref_args = (ref_coords, ref_mask, ref_atom_index, ref_token_index)
        return index, (k, None, upper_bounds), None, ref_args, None
class ContactPotentital(FlatBottomPotential, DistancePotential):
    """Contact distance restraints with negation and union (soft-OR) operators.

    The class name's spelling is kept because callers reference it.
    """

    def compute_args(self, feats, parameters):
        """Upper-bound each contact pair by its threshold.

        Also returns ``(negation_mask, union_index)`` as operator arguments.
        """
        pair_index = feats["contact_pair_index"][0]
        union_index = feats["contact_union_index"][0]
        negation_mask = feats["contact_negation_mask"][0]
        thresholds = feats["contact_thresholds"][0].clone()
        k = torch.ones_like(thresholds)
        operator_args = (negation_mask, union_index)
        return pair_index, (k, None, thresholds), None, None, operator_args
def get_potentials(steering_args, boltz2=False):
    """Instantiate the steering potentials enabled by ``steering_args``.

    Args:
        steering_args: Mapping read for ``"fk_steering"``,
            ``"physical_guidance_update"`` and, only when ``boltz2`` is true,
            ``"contact_guidance_update"``.
        boltz2: Whether the contact and template-reference potentials may be
            added.

    Returns:
        A list of potentials. The physical set is added when FK steering or
        physical guidance is on; its guidance weights are 0.0 unless physical
        guidance is on. The contact/template set follows the same rule with
        contact guidance.
    """
    fk_steering = steering_args["fk_steering"]
    physical_guidance = steering_args["physical_guidance_update"]
    potentials = []

    if fk_steering or physical_guidance:
        potentials += [
            SymmetricChainCOMPotential(
                parameters={
                    "guidance_interval": 4,
                    "guidance_weight": 0.5 if physical_guidance else 0.0,
                    "resampling_weight": 0.5,
                    "buffer": ExponentialInterpolation(
                        start=1.0, end=5.0, alpha=-2.0
                    ),
                }
            ),
            VDWOverlapPotential(
                parameters={
                    "guidance_interval": 5,
                    "guidance_weight": (
                        PiecewiseStepFunction(thresholds=[0.4], values=[0.125, 0.0])
                        if physical_guidance
                        else 0.0
                    ),
                    "resampling_weight": PiecewiseStepFunction(
                        thresholds=[0.6], values=[0.01, 0.0]
                    ),
                    "buffer": 0.225,
                }
            ),
            ConnectionsPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.15 if physical_guidance else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 2.0,
                }
            ),
            PoseBustersPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.01 if physical_guidance else 0.0,
                    "resampling_weight": 0.1,
                    "bond_buffer": 0.125,
                    "angle_buffer": 0.125,
                    "clash_buffer": 0.10,
                }
            ),
            ChiralAtomPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.1 if physical_guidance else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 0.52360,
                }
            ),
            StereoBondPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.05 if physical_guidance else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 0.52360,
                }
            ),
            PlanarBondPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.05 if physical_guidance else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 0.26180,
                }
            ),
        ]

    if boltz2:
        # Only read when boltz2 is set, so the key may be absent otherwise.
        contact_guidance = steering_args["contact_guidance_update"]
        if fk_steering or contact_guidance:
            potentials += [
                ContactPotentital(
                    parameters={
                        "guidance_interval": 4,
                        "guidance_weight": (
                            PiecewiseStepFunction(
                                thresholds=[0.25, 0.75], values=[0.0, 0.5, 1.0]
                            )
                            if contact_guidance
                            else 0.0
                        ),
                        "resampling_weight": 1.0,
                        "union_lambda": ExponentialInterpolation(
                            start=8.0, end=0.0, alpha=-2.0
                        ),
                    }
                ),
                TemplateReferencePotential(
                    parameters={
                        "guidance_interval": 2,
                        "guidance_weight": 0.1 if contact_guidance else 0.0,
                        "resampling_weight": 1.0,
                    }
                ),
            ]

    return potentials