Instructions to use Synthyra/Boltz2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/Boltz2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Synthyra/Boltz2", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/Boltz2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import torch | |
| from torch import nn | |
| from torch.nn import Module | |
| from .vb_modules_encodersv2 import ( | |
| AtomEncoder, | |
| PairwiseConditioning, | |
| ) | |
class DiffusionConditioning(Module):
    """Compute conditioning tensors and per-layer attention biases.

    The module runs a ``PairwiseConditioning`` block on the trunk pair
    representation, feeds the result to an ``AtomEncoder``, and then projects
    the resulting pair tensors to one bias channel per attention head for
    every layer of three downstream stacks (atom encoder, atom decoder,
    token transformer).  Each stack gets ``depth`` independent
    ``LayerNorm -> Linear(bias=False)`` projections whose outputs are
    concatenated along the last dimension, giving ``depth * heads`` channels.
    """

    def __init__(
        self,
        token_s: int,
        token_z: int,
        atom_s: int,
        atom_z: int,
        atoms_per_window_queries: int = 32,
        atoms_per_window_keys: int = 128,
        atom_encoder_depth: int = 3,
        atom_encoder_heads: int = 4,
        token_transformer_depth: int = 24,
        token_transformer_heads: int = 8,
        atom_decoder_depth: int = 3,
        atom_decoder_heads: int = 4,
        atom_feature_dim: int = 128,
        conditioning_transition_layers: int = 2,
        use_no_atom_char: bool = False,
        use_atom_backbone_feat: bool = False,
        use_residue_feats_atoms: bool = False,
    ) -> None:
        """Build the conditioning submodules and the bias projections.

        Args:
            token_s: Forwarded to ``AtomEncoder`` as ``token_s``.
            token_z: Pair channel width; forwarded to both submodules and
                used as the input width of the token-transformer projections.
            atom_s: Forwarded to ``AtomEncoder`` as ``atom_s``.
            atom_z: Forwarded to ``AtomEncoder`` and used as the input width
                of the atom encoder/decoder bias projections.
            atoms_per_window_queries: Forwarded to ``AtomEncoder``.
            atoms_per_window_keys: Forwarded to ``AtomEncoder``.
            atom_encoder_depth: Number of atom-encoder bias projections.
            atom_encoder_heads: Output channels of each atom-encoder
                projection.
            token_transformer_depth: Number of token-transformer bias
                projections.
            token_transformer_heads: Output channels of each
                token-transformer projection.
            atom_decoder_depth: Number of atom-decoder bias projections.
            atom_decoder_heads: Output channels of each atom-decoder
                projection.
            atom_feature_dim: Forwarded to ``AtomEncoder``.
            conditioning_transition_layers: Forwarded to
                ``PairwiseConditioning`` as ``num_transitions``.
            use_no_atom_char: Forwarded to ``AtomEncoder``.
            use_atom_backbone_feat: Forwarded to ``AtomEncoder``.
            use_residue_feats_atoms: Forwarded to ``AtomEncoder``.
        """
        super().__init__()

        self.pairwise_conditioner = PairwiseConditioning(
            token_z=token_z,
            # The relative-position features share the pair channel width.
            dim_token_rel_pos_feats=token_z,
            num_transitions=conditioning_transition_layers,
        )

        self.atom_encoder = AtomEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            structure_prediction=True,
            use_no_atom_char=use_no_atom_char,
            use_atom_backbone_feat=use_atom_backbone_feat,
            use_residue_feats_atoms=use_residue_feats_atoms,
        )

        # Creation order (encoder, decoder, token transformer) is kept stable
        # so seeded weight initialisation and state-dict layout do not change.
        self.atom_enc_proj_z = self._make_bias_projections(
            atom_z, atom_encoder_heads, atom_encoder_depth
        )
        self.atom_dec_proj_z = self._make_bias_projections(
            atom_z, atom_decoder_heads, atom_decoder_depth
        )
        self.token_trans_proj_z = self._make_bias_projections(
            token_z, token_transformer_heads, token_transformer_depth
        )

    @staticmethod
    def _make_bias_projections(
        pair_dim: int, num_heads: int, depth: int
    ) -> nn.ModuleList:
        """Return ``depth`` LayerNorm + bias-free Linear(pair_dim -> num_heads) blocks."""
        return nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(pair_dim),
                    nn.Linear(pair_dim, num_heads, bias=False),
                )
                for _ in range(depth)
            ]
        )

    @staticmethod
    def _stack_head_biases(projections, pair_repr: torch.Tensor) -> torch.Tensor:
        """Apply every projection to ``pair_repr`` and concatenate on the last dim."""
        return torch.cat([proj(pair_repr) for proj in projections], dim=-1)

    def forward(
        self,
        s_trunk,  # Float['b n ts']
        z_trunk,  # Float['b n n tz']
        relative_position_encoding,  # Float['b n n tz']
        feats,
    ):
        """Produce atom-encoder outputs and the three stacked bias tensors.

        Args:
            s_trunk: Trunk single representation (annotated ``b n ts``).
            z_trunk: Trunk pair representation (annotated ``b n n tz``).
            relative_position_encoding: Relative position features
                (annotated ``b n n tz``).
            feats: Feature container passed unchanged to the atom encoder;
                its schema is defined by ``AtomEncoder`` (not visible here).

        Returns:
            A 6-tuple ``(q, c, to_keys, atom_enc_bias, atom_dec_bias,
            token_trans_bias)``.  ``q``, ``c`` and ``to_keys`` are returned
            by the atom encoder as-is.  The atom encoder's third output
            ``p`` is only used to derive the two atom bias tensors and is
            not returned.  Each bias tensor has ``depth * heads`` channels
            in its last dimension.
        """
        z = self.pairwise_conditioner(
            z_trunk,
            relative_position_encoding,
        )

        q, c, p, to_keys = self.atom_encoder(
            feats=feats,
            s_trunk=s_trunk,  # Float['b n ts']
            z=z,  # Float['b n n tz']
        )

        # Atom-level stacks are biased from the atom pair tensor ``p``;
        # the token transformer is biased from the conditioned pair tensor ``z``.
        atom_enc_bias = self._stack_head_biases(self.atom_enc_proj_z, p)
        atom_dec_bias = self._stack_head_biases(self.atom_dec_proj_z, p)
        token_trans_bias = self._stack_head_biases(self.token_trans_proj_z, z)

        return q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias