diff --git a/rfdiffusion.egg-info/PKG-INFO b/rfdiffusion.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..f19686041fc62ab5803b5ecc08226d6b49e1c527 --- /dev/null +++ b/rfdiffusion.egg-info/PKG-INFO @@ -0,0 +1,7 @@ +Metadata-Version: 2.1 +Name: rfdiffusion +Version: 1.1.0 +Summary: RFdiffusion is an open source method for protein structure generation. +Home-page: https://github.com/RosettaCommons/RFdiffusion +Author: Rosetta Commons +License-File: LICENSE diff --git a/rfdiffusion.egg-info/SOURCES.txt b/rfdiffusion.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..96d49e02fc55433caefef8fb0180ae3c37b1065f --- /dev/null +++ b/rfdiffusion.egg-info/SOURCES.txt @@ -0,0 +1,34 @@ +LICENSE +README.md +setup.py +rfdiffusion/Attention_module.py +rfdiffusion/AuxiliaryPredictor.py +rfdiffusion/Embeddings.py +rfdiffusion/RoseTTAFoldModel.py +rfdiffusion/SE3_network.py +rfdiffusion/Track_module.py +rfdiffusion/__init__.py +rfdiffusion/chemical.py +rfdiffusion/contigs.py +rfdiffusion/coords6d.py +rfdiffusion/diffusion.py +rfdiffusion/igso3.py +rfdiffusion/kinematics.py +rfdiffusion/model_input_logger.py +rfdiffusion/scoring.py +rfdiffusion/util.py +rfdiffusion/util_module.py +rfdiffusion.egg-info/PKG-INFO +rfdiffusion.egg-info/SOURCES.txt +rfdiffusion.egg-info/dependency_links.txt +rfdiffusion.egg-info/requires.txt +rfdiffusion.egg-info/top_level.txt +rfdiffusion/inference/__init__.py +rfdiffusion/inference/model_runners.py +rfdiffusion/inference/symmetry.py +rfdiffusion/inference/utils.py +rfdiffusion/potentials/__init__.py +rfdiffusion/potentials/manager.py +rfdiffusion/potentials/potentials.py +scripts/run_inference.py +tests/test_diffusion.py \ No newline at end of file diff --git a/rfdiffusion.egg-info/dependency_links.txt b/rfdiffusion.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/rfdiffusion.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/rfdiffusion.egg-info/requires.txt b/rfdiffusion.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..2f903b33fa84cf703dd285176dd3ca5da79ade11 --- /dev/null +++ b/rfdiffusion.egg-info/requires.txt @@ -0,0 +1,2 @@ +torch +se3-transformer diff --git a/rfdiffusion.egg-info/top_level.txt b/rfdiffusion.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..185ab8e8a379d672865331e6b692b2f0dc8e6aad --- /dev/null +++ b/rfdiffusion.egg-info/top_level.txt @@ -0,0 +1 @@ +rfdiffusion diff --git a/rfdiffusion/Attention_module.py b/rfdiffusion/Attention_module.py new file mode 100644 index 0000000000000000000000000000000000000000..f8868fc27d9cbe0afb134d6d3a0c5b9cba0b1de2 --- /dev/null +++ b/rfdiffusion/Attention_module.py @@ -0,0 +1,404 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from opt_einsum import contract as einsum +from rfdiffusion.util_module import init_lecun_normal + +class FeedForwardLayer(nn.Module): + def __init__(self, d_model, r_ff, p_drop=0.1): + super(FeedForwardLayer, self).__init__() + self.norm = nn.LayerNorm(d_model) + self.linear1 = nn.Linear(d_model, d_model*r_ff) + self.dropout = nn.Dropout(p_drop) + self.linear2 = nn.Linear(d_model*r_ff, d_model) + + self.reset_parameter() + + def reset_parameter(self): + # initialize linear layer right before ReLu: He initializer (kaiming normal) + nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu') + nn.init.zeros_(self.linear1.bias) + + # initialize linear layer right before residual connection: zero initialize + nn.init.zeros_(self.linear2.weight) + nn.init.zeros_(self.linear2.bias) + + def forward(self, src): + src = self.norm(src) + src = self.linear2(self.dropout(F.relu_(self.linear1(src)))) + return src + +class Attention(nn.Module): + # calculate multi-head attention + def __init__(self, d_query, d_key, n_head, d_hidden, d_out): + super(Attention, self).__init__() + self.h = n_head + self.dim = d_hidden + # + self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False) + self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False) + # + self.to_out = nn.Linear(n_head*d_hidden, d_out) + self.scaling = 1/math.sqrt(d_hidden) + # + # initialize all parameters properly + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, query, key, value): + B, Q = query.shape[:2] + B, K = key.shape[:2] + # + query = self.to_q(query).reshape(B, Q, self.h, self.dim) + key = self.to_k(key).reshape(B, K, self.h, self.dim) + value = self.to_v(value).reshape(B, K, self.h, self.dim) + # + query = query * self.scaling + attn = einsum('bqhd,bkhd->bhqk', query, key) + attn = F.softmax(attn, dim=-1) + # + out = einsum('bhqk,bkhd->bqhd', attn, value) + out = out.reshape(B, Q, self.h*self.dim) + # + out = self.to_out(out) + + return out + +class AttentionWithBias(nn.Module): + def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32): + super(AttentionWithBias, self).__init__() + self.norm_in = nn.LayerNorm(d_in) + self.norm_bias = nn.LayerNorm(d_bias) + # + self.to_q = nn.Linear(d_in, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_in, n_head*d_hidden, bias=False) + self.to_v = nn.Linear(d_in, n_head*d_hidden, bias=False) + self.to_b = nn.Linear(d_bias, n_head, bias=False) + self.to_g = nn.Linear(d_in, n_head*d_hidden) + self.to_out = nn.Linear(n_head*d_hidden, d_in) + + self.scaling = 1/math.sqrt(d_hidden) + self.h = n_head + self.dim = d_hidden + + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # bias: normal distribution + self.to_b = init_lecun_normal(self.to_b) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.to_g.weight) + nn.init.ones_(self.to_g.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, x, bias): + B, L = x.shape[:2] + # + x = self.norm_in(x) + bias = self.norm_bias(bias) + # + query = self.to_q(x).reshape(B, L, self.h, self.dim) + key = self.to_k(x).reshape(B, L, self.h, self.dim) + value = self.to_v(x).reshape(B, L, self.h, self.dim) + bias = self.to_b(bias) # (B, L, L, h) + gate = torch.sigmoid(self.to_g(x)) + # + key = key * self.scaling + attn = einsum('bqhd,bkhd->bqkh', query, key) + attn = attn + bias + attn = F.softmax(attn, dim=-2) + # + out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1) + out = gate * out + # + out = self.to_out(out) + return out + +# MSA Attention (row/column) from AlphaFold architecture +class SequenceWeight(nn.Module): + def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1): + super(SequenceWeight, self).__init__() + self.h = n_head + self.dim = d_hidden + self.scale = 1.0 / math.sqrt(self.dim) + + self.to_query = nn.Linear(d_msa, n_head*d_hidden) + self.to_key = nn.Linear(d_msa, n_head*d_hidden) + self.dropout = nn.Dropout(p_drop) + + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_query.weight) + nn.init.xavier_uniform_(self.to_key.weight) + + def forward(self, msa): + B, N, L = msa.shape[:3] + + tar_seq = msa[:,0] + + q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim) + k = self.to_key(msa).view(B, N, L, self.h, self.dim) + + q = q * self.scale + attn = einsum('bqihd,bkihd->bkihq', q, k) + attn = F.softmax(attn, dim=1) + return self.dropout(attn) + +class MSARowAttentionWithBias(nn.Module): + def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32): + super(MSARowAttentionWithBias, self).__init__() + self.norm_msa = nn.LayerNorm(d_msa) + self.norm_pair = nn.LayerNorm(d_pair) + # + self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1) + self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_b = nn.Linear(d_pair, n_head, bias=False) + self.to_g = nn.Linear(d_msa, n_head*d_hidden) + self.to_out = nn.Linear(n_head*d_hidden, d_msa) + + self.scaling = 1/math.sqrt(d_hidden) + self.h = n_head + self.dim = d_hidden + + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # bias: normal distribution + self.to_b = init_lecun_normal(self.to_b) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.to_g.weight) + nn.init.ones_(self.to_g.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, msa, pair): # TODO: make this as tied-attention + B, N, L = msa.shape[:3] + # + msa = self.norm_msa(msa) + pair = self.norm_pair(pair) + # + seq_weight = self.seq_weight(msa) # (B, N, L, h, 1) + query = self.to_q(msa).reshape(B, N, L, self.h, self.dim) + key = self.to_k(msa).reshape(B, N, L, self.h, self.dim) + value = self.to_v(msa).reshape(B, N, L, self.h, self.dim) + bias = self.to_b(pair) # (B, L, L, h) + gate = torch.sigmoid(self.to_g(msa)) + # + query = query * seq_weight.expand(-1, -1, -1, -1, self.dim) + key = key * self.scaling + attn = einsum('bsqhd,bskhd->bqkh', query, key) + attn = attn + bias + attn = F.softmax(attn, dim=-2) + # + out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1) + out = gate * out + # + out = self.to_out(out) + return out + +class MSAColAttention(nn.Module): + def __init__(self, d_msa=256, n_head=8, d_hidden=32): + super(MSAColAttention, self).__init__() + self.norm_msa = nn.LayerNorm(d_msa) + # + self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_g = nn.Linear(d_msa, n_head*d_hidden) + self.to_out = nn.Linear(n_head*d_hidden, d_msa) + + self.scaling = 1/math.sqrt(d_hidden) + self.h = n_head + self.dim = d_hidden + + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.to_g.weight) + nn.init.ones_(self.to_g.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, msa): + B, N, L = msa.shape[:3] + # + msa = self.norm_msa(msa) + # + query = self.to_q(msa).reshape(B, N, L, self.h, self.dim) + key = self.to_k(msa).reshape(B, N, L, self.h, self.dim) + value = self.to_v(msa).reshape(B, N, L, self.h, self.dim) + gate = torch.sigmoid(self.to_g(msa)) + # + query = query * self.scaling + attn = einsum('bqihd,bkihd->bihqk', query, key) + attn = F.softmax(attn, dim=-1) + # + out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1) + out = gate * out + # + out = self.to_out(out) + return out + +class MSAColGlobalAttention(nn.Module): + def __init__(self, d_msa=64, n_head=8, d_hidden=8): + super(MSAColGlobalAttention, self).__init__() + self.norm_msa = nn.LayerNorm(d_msa) + # + self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_msa, d_hidden, bias=False) + self.to_v = nn.Linear(d_msa, d_hidden, bias=False) + self.to_g = nn.Linear(d_msa, n_head*d_hidden) + self.to_out = nn.Linear(n_head*d_hidden, d_msa) + + self.scaling = 1/math.sqrt(d_hidden) + self.h = n_head + self.dim = d_hidden + + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.to_g.weight) + nn.init.ones_(self.to_g.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, msa): + B, N, L = msa.shape[:3] + # + msa = self.norm_msa(msa) + # + query = self.to_q(msa).reshape(B, N, L, self.h, self.dim) + query = query.mean(dim=1) # (B, L, h, dim) + key = self.to_k(msa) # (B, N, L, dim) + value = self.to_v(msa) # (B, N, L, dim) + gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim) + # + query = query * self.scaling + attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N) + attn = F.softmax(attn, dim=-1) + # + out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim) + out = gate * out # (B, N, L, h*dim) + # + out = self.to_out(out) + return out + +# Instead of triangle attention, use Tied axail attention with bias from coordinates..? +class BiasedAxialAttention(nn.Module): + def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True): + super(BiasedAxialAttention, self).__init__() + # + self.is_row = is_row + self.norm_pair = nn.LayerNorm(d_pair) + self.norm_bias = nn.LayerNorm(d_bias) + + self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False) + self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False) + self.to_b = nn.Linear(d_bias, n_head, bias=False) + self.to_g = nn.Linear(d_pair, n_head*d_hidden) + self.to_out = nn.Linear(n_head*d_hidden, d_pair) + + self.scaling = 1/math.sqrt(d_hidden) + self.h = n_head + self.dim = d_hidden + + # initialize all parameters properly + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # bias: normal distribution + self.to_b = init_lecun_normal(self.to_b) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.to_g.weight) + nn.init.ones_(self.to_g.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, pair, bias): + # pair: (B, L, L, d_pair) + B, L = pair.shape[:2] + + if self.is_row: + pair = pair.permute(0,2,1,3) + bias = bias.permute(0,2,1,3) + + pair = self.norm_pair(pair) + bias = self.norm_bias(bias) + + query = self.to_q(pair).reshape(B, L, L, self.h, self.dim) + key = self.to_k(pair).reshape(B, L, L, self.h, self.dim) + value = self.to_v(pair).reshape(B, L, L, self.h, self.dim) + bias = self.to_b(bias) # (B, L, L, h) + gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim) + + query = query * self.scaling + key = key / math.sqrt(L) # normalize for tied attention + attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention + attn = attn + bias # apply bias + attn = F.softmax(attn, dim=-2) # (B, L, L, h) + + out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1) + out = gate * out + + out = self.to_out(out) + if self.is_row: + out = out.permute(0,2,1,3) + return out + diff --git a/rfdiffusion/AuxiliaryPredictor.py b/rfdiffusion/AuxiliaryPredictor.py new file mode 100644 index 0000000000000000000000000000000000000000..dd246e193cbe54bdec383aa46c575f6e2de3d1d7 --- /dev/null +++ b/rfdiffusion/AuxiliaryPredictor.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn + +class DistanceNetwork(nn.Module): + def __init__(self, n_feat, p_drop=0.1): + super(DistanceNetwork, self).__init__() + # + self.proj_symm = nn.Linear(n_feat, 37*2) + self.proj_asymm = nn.Linear(n_feat, 37+19) + + self.reset_parameter() + + def reset_parameter(self): + # initialize linear layer for final logit prediction + nn.init.zeros_(self.proj_symm.weight) + nn.init.zeros_(self.proj_asymm.weight) + nn.init.zeros_(self.proj_symm.bias) + nn.init.zeros_(self.proj_asymm.bias) + + def forward(self, x): + # input: pair info (B, L, L, C) + + # predict theta, phi (non-symmetric) + logits_asymm = self.proj_asymm(x) + logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2) + logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2) + + # predict dist, omega + logits_symm = self.proj_symm(x) + logits_symm = logits_symm + logits_symm.permute(0,2,1,3) + logits_dist = logits_symm[:,:,:,:37].permute(0,3,1,2) + logits_omega = logits_symm[:,:,:,37:].permute(0,3,1,2) + + return logits_dist, logits_omega, logits_theta, logits_phi + +class MaskedTokenNetwork(nn.Module): + def __init__(self, n_feat): + super(MaskedTokenNetwork, self).__init__() + self.proj = nn.Linear(n_feat, 21) + + self.reset_parameter() + + def reset_parameter(self): + nn.init.zeros_(self.proj.weight) + nn.init.zeros_(self.proj.bias) + + def forward(self, x): + B, N, L = x.shape[:3] + logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L) + + return logits + +class LDDTNetwork(nn.Module): + def __init__(self, n_feat, n_bin_lddt=50): + super(LDDTNetwork, self).__init__() + self.proj = nn.Linear(n_feat, n_bin_lddt) + + self.reset_parameter() + + def reset_parameter(self): + nn.init.zeros_(self.proj.weight) + nn.init.zeros_(self.proj.bias) + + def forward(self, x): + logits = self.proj(x) # (B, L, 50) + + return logits.permute(0,2,1) + +class ExpResolvedNetwork(nn.Module): + def __init__(self, d_msa, d_state, p_drop=0.1): + super(ExpResolvedNetwork, self).__init__() + self.norm_msa = nn.LayerNorm(d_msa) + self.norm_state = nn.LayerNorm(d_state) + self.proj = nn.Linear(d_msa+d_state, 1) + + self.reset_parameter() + + def reset_parameter(self): + nn.init.zeros_(self.proj.weight) + nn.init.zeros_(self.proj.bias) + + def forward(self, seq, state): + B, L = seq.shape[:2] + + seq = self.norm_msa(seq) + state = self.norm_state(state) + feat = torch.cat((seq, state), dim=-1) + logits = self.proj(feat) + return logits.reshape(B, L) + + + diff --git a/rfdiffusion/Embeddings.py b/rfdiffusion/Embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..a052c9dbafcb3d33b39b06dcaaf7c31d7eb3ecf8 --- /dev/null +++ b/rfdiffusion/Embeddings.py @@ -0,0 +1,303 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from opt_einsum import contract as einsum +import torch.utils.checkpoint as checkpoint +from rfdiffusion.util import get_tips +from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal +from rfdiffusion.Attention_module import Attention, FeedForwardLayer, AttentionWithBias +from rfdiffusion.Track_module import PairStr2Pair +import math + +# Module contains classes and functions to generate initial embeddings + +class PositionalEncoding2D(nn.Module): + # Add relative positional encoding to pair features + def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1): + super(PositionalEncoding2D, self).__init__() + self.minpos = minpos + self.maxpos = maxpos + self.nbin = abs(minpos)+maxpos+1 + self.emb = nn.Embedding(self.nbin, d_model) + self.drop = nn.Dropout(p_drop) + + def forward(self, x, idx): + bins = torch.arange(self.minpos, self.maxpos, device=x.device) + seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L) + # + ib = torch.bucketize(seqsep, bins).long() # (B, L, L) + emb = self.emb(ib) #(B, L, L, d_model) + x = x + emb # add relative positional encoding + return self.drop(x) + +class MSA_emb(nn.Module): + # Get initial seed MSA embedding + def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=22+22+2+2, + minpos=-32, maxpos=32, p_drop=0.1, input_seq_onehot=False): + super(MSA_emb, self).__init__() + self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA + self.emb_q = nn.Embedding(22, d_msa) # embedding for query sequence -- used for MSA embedding + self.emb_left = nn.Embedding(22, d_pair) # embedding for query sequence -- used for pair embedding + self.emb_right = nn.Embedding(22, d_pair) # embedding for query sequence -- used for pair embedding + self.emb_state = nn.Embedding(22, d_state) + self.drop = nn.Dropout(p_drop) + self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos, p_drop=p_drop) + + self.input_seq_onehot=input_seq_onehot + + self.reset_parameter() + + def reset_parameter(self): + self.emb = init_lecun_normal(self.emb) + self.emb_q = init_lecun_normal(self.emb_q) + self.emb_left = init_lecun_normal(self.emb_left) + self.emb_right = init_lecun_normal(self.emb_right) + self.emb_state = init_lecun_normal(self.emb_state) + + nn.init.zeros_(self.emb.bias) + + def forward(self, msa, seq, idx): + # Inputs: + # - msa: Input MSA (B, N, L, d_init) + # - seq: Input Sequence (B, L) + # - idx: Residue index + # Outputs: + # - msa: Initial MSA embedding (B, N, L, d_msa) + # - pair: Initial Pair embedding (B, L, L, d_pair) + + N = msa.shape[1] # number of sequenes in MSA + + # msa embedding + msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding + + # Sergey's one hot trick + tmp = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding + + msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA + msa = self.drop(msa) + + # pair embedding + # Sergey's one hot trick + left = (seq @ self.emb_left.weight)[:,None] # (B, 1, L, d_pair) + right = (seq @ self.emb_right.weight)[:,:,None] # (B, L, 1, d_pair) + + pair = left + right # (B, L, L, d_pair) + pair = self.pos(pair, idx) # add relative position + + # state embedding + # Sergey's one hot trick + state = self.drop(seq @ self.emb_state.weight) + return msa, pair, state + +class Extra_emb(nn.Module): + # Get initial seed MSA embedding + def __init__(self, d_msa=256, d_init=22+1+2, p_drop=0.1, input_seq_onehot=False): + super(Extra_emb, self).__init__() + self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA + self.emb_q = nn.Embedding(22, d_msa) # embedding for query sequence + self.drop = nn.Dropout(p_drop) + + self.input_seq_onehot=input_seq_onehot + + self.reset_parameter() + + def reset_parameter(self): + self.emb = init_lecun_normal(self.emb) + nn.init.zeros_(self.emb.bias) + + def forward(self, msa, seq, idx): + # Inputs: + # - msa: Input MSA (B, N, L, d_init) + # - seq: Input Sequence (B, L) + # - idx: Residue index + # Outputs: + # - msa: Initial MSA embedding (B, N, L, d_msa) + N = msa.shape[1] # number of sequenes in MSA + msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding + + # Sergey's one hot trick + seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding + msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA + return self.drop(msa) + +class TemplatePairStack(nn.Module): + # process template pairwise features + # use structure-biased attention + def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=16, p_drop=0.25): + super(TemplatePairStack, self).__init__() + self.n_block = n_block + proc_s = [PairStr2Pair(d_pair=d_templ, n_head=n_head, d_hidden=d_hidden, p_drop=p_drop) for i in range(n_block)] + self.block = nn.ModuleList(proc_s) + self.norm = nn.LayerNorm(d_templ) + def forward(self, templ, rbf_feat, use_checkpoint=False): + B, T, L = templ.shape[:3] + templ = templ.reshape(B*T, L, L, -1) + + for i_block in range(self.n_block): + if use_checkpoint: + templ = checkpoint.checkpoint(create_custom_forward(self.block[i_block]), templ, rbf_feat) + else: + templ = self.block[i_block](templ, rbf_feat) + return self.norm(templ).reshape(B, T, L, L, -1) + +class TemplateTorsionStack(nn.Module): + def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=16, p_drop=0.15): + super(TemplateTorsionStack, self).__init__() + self.n_block=n_block + self.proj_pair = nn.Linear(d_templ+36, d_templ) + proc_s = [AttentionWithBias(d_in=d_templ, d_bias=d_templ, + n_head=n_head, d_hidden=d_hidden) for i in range(n_block)] + self.row_attn = nn.ModuleList(proc_s) + proc_s = [FeedForwardLayer(d_templ, 4, p_drop=p_drop) for i in range(n_block)] + self.ff = nn.ModuleList(proc_s) + self.norm = nn.LayerNorm(d_templ) + + def reset_parameter(self): + self.proj_pair = init_lecun_normal(self.proj_pair) + nn.init.zeros_(self.proj_pair.bias) + + def forward(self, tors, pair, rbf_feat, use_checkpoint=False): + B, T, L = tors.shape[:3] + tors = tors.reshape(B*T, L, -1) + pair = pair.reshape(B*T, L, L, -1) + pair = torch.cat((pair, rbf_feat), dim=-1) + pair = self.proj_pair(pair) + + for i_block in range(self.n_block): + if use_checkpoint: + tors = tors + checkpoint.checkpoint(create_custom_forward(self.row_attn[i_block]), tors, pair) + else: + tors = tors + self.row_attn[i_block](tors, pair) + tors = tors + self.ff[i_block](tors) + return self.norm(tors).reshape(B, T, L, -1) + +class Templ_emb(nn.Module): + # Get template embedding + # Features are + # t2d: + # - 37 distogram bins + 6 orientations (43) + # - Mask (missing/unaligned) (1) + # t1d: + # - tiled AA sequence (20 standard aa + gap) + # - confidence (1) + # - contacting or note (1). NB this is added for diffusion model. Used only in complex training examples - 1 signifies that a residue in the non-diffused chain\ + # i.e. the context, is in contact with the diffused chain. + # + #Added extra t1d dimension for contacting or not + def __init__(self, d_t1d=21+1+1, d_t2d=43+1, d_tor=30, d_pair=128, d_state=32, + n_block=2, d_templ=64, + n_head=4, d_hidden=16, p_drop=0.25): + super(Templ_emb, self).__init__() + # process 2D features + self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ) + self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head, + d_hidden=d_hidden, p_drop=p_drop) + + self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair) + + # process torsion angles + self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ) + self.proj_t1d = nn.Linear(d_templ, d_templ) + #self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head, + # d_hidden=d_hidden, p_drop=p_drop) + self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state) + + self.reset_parameter() + + def reset_parameter(self): + self.emb = init_lecun_normal(self.emb) + nn.init.zeros_(self.emb.bias) + + nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu') + nn.init.zeros_(self.emb_t1d.bias) + + self.proj_t1d = init_lecun_normal(self.proj_t1d) + nn.init.zeros_(self.proj_t1d.bias) + + def forward(self, t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=False): + # Input + # - t1d: 1D template info (B, T, L, 23) + # - t2d: 2D template info (B, T, L, L, 44) + B, T, L, _ = t1d.shape + + # Prepare 2D template features + left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1) + right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1) + # + templ = torch.cat((t2d, left, right), -1) # (B, T, L, L, 90) + templ = self.emb(templ) # Template templures (B, T, L, L, d_templ) + # process each template features + xyz_t = xyz_t.reshape(B*T, L, -1, 3) + rbf_feat = rbf(torch.cdist(xyz_t[:,:,1], xyz_t[:,:,1])) + templ = self.templ_stack(templ, rbf_feat, use_checkpoint=use_checkpoint) # (B, T, L,L, d_templ) + + # Prepare 1D template torsion angle features + t1d = torch.cat((t1d, alpha_t), dim=-1) # (B, T, L, 23+30) + + # process each template features + t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d))) + + # mixing query state features to template state features + state = state.reshape(B*L, 1, -1) + t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1) + if use_checkpoint: + out = checkpoint.checkpoint(create_custom_forward(self.attn_tor), state, t1d, t1d) + out = out.reshape(B, L, -1) + else: + out = self.attn_tor(state, t1d, t1d).reshape(B, L, -1) + state = state.reshape(B, L, -1) + state = state + out + + # mixing query pair features to template information (Template pointwise attention) + pair = pair.reshape(B*L*L, 1, -1) + templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1) + if use_checkpoint: + out = checkpoint.checkpoint(create_custom_forward(self.attn), pair, templ, templ) + out = out.reshape(B, L, L, -1) + else: + out = self.attn(pair, templ, templ).reshape(B, L, L, -1) + # + pair = pair.reshape(B, L, L, -1) + pair = pair + out + + return pair, state + +class Recycling(nn.Module): + def __init__(self, d_msa=256, d_pair=128, d_state=32): + super(Recycling, self).__init__() + self.proj_dist = nn.Linear(36+d_state*2, d_pair) + self.norm_state = nn.LayerNorm(d_state) + self.norm_pair = nn.LayerNorm(d_pair) + self.norm_msa = nn.LayerNorm(d_msa) + + self.reset_parameter() + + def reset_parameter(self): + self.proj_dist = init_lecun_normal(self.proj_dist) + nn.init.zeros_(self.proj_dist.bias) + + def forward(self, seq, msa, pair, xyz, state): + B, L = pair.shape[:2] + state = self.norm_state(state) + # + left = state.unsqueeze(2).expand(-1,-1,L,-1) + right = state.unsqueeze(1).expand(-1,L,-1,-1) + + # three anchor atoms + N = xyz[:,:,0] + Ca = xyz[:,:,1] + C = xyz[:,:,2] + + # recreate Cb given N,Ca,C + b = Ca - N + c = C - Ca + a = torch.cross(b, c, dim=-1) + Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca + + dist = rbf(torch.cdist(Cb, Cb)) + dist = torch.cat((dist, left, right), dim=-1) + dist = self.proj_dist(dist) + pair = dist + self.norm_pair(pair) + msa = self.norm_msa(msa) + return msa, pair, state + diff --git a/rfdiffusion/RoseTTAFoldModel.py b/rfdiffusion/RoseTTAFoldModel.py new file mode 100644 index 0000000000000000000000000000000000000000..84fbac437a08b009644e3620bdac3998ef971969 --- /dev/null +++ b/rfdiffusion/RoseTTAFoldModel.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling +from rfdiffusion.Track_module import IterativeSimulator +from rfdiffusion.AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork +from opt_einsum import contract as einsum + +class RoseTTAFoldModule(nn.Module): + def __init__(self, + n_extra_block, + n_main_block, + n_ref_block, + d_msa, + d_msa_full, + d_pair, + d_templ, + n_head_msa, + n_head_pair, + n_head_templ, + d_hidden, + d_hidden_templ, + p_drop, + d_t1d, + d_t2d, + T, # total timesteps (used in timestep emb + use_motif_timestep, # Whether to have a distinct emb for motif + freeze_track_motif, # Whether to freeze updates to motif in track + SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, + SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, + input_seq_onehot=False, # For continuous vs. discrete sequence + ): + + super(RoseTTAFoldModule, self).__init__() + + self.freeze_track_motif = freeze_track_motif + + # Input Embeddings + d_state = SE3_param_topk['l0_out_features'] + self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state, + p_drop=p_drop, input_seq_onehot=input_seq_onehot) # Allowed to take onehotseq + self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=25, + p_drop=p_drop, input_seq_onehot=input_seq_onehot) # Allowed to take onehotseq + self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state, + n_head=n_head_templ, + d_hidden=d_hidden_templ, p_drop=0.25, d_t1d=d_t1d, d_t2d=d_t2d) + + + # Update inputs with outputs from previous round + self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state) + # + self.simulator = IterativeSimulator(n_extra_block=n_extra_block, + n_main_block=n_main_block, + n_ref_block=n_ref_block, + d_msa=d_msa, d_msa_full=d_msa_full, + d_pair=d_pair, d_hidden=d_hidden, + n_head_msa=n_head_msa, + n_head_pair=n_head_pair, + SE3_param_full=SE3_param_full, + SE3_param_topk=SE3_param_topk, + p_drop=p_drop) + ## + self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop) + self.aa_pred = MaskedTokenNetwork(d_msa) + self.lddt_pred = LDDTNetwork(d_state) + + self.exp_pred = ExpResolvedNetwork(d_msa, d_state) + + def forward(self, msa_latent, msa_full, seq, xyz, idx, t, + t1d=None, t2d=None, xyz_t=None, alpha_t=None, + msa_prev=None, pair_prev=None, state_prev=None, + return_raw=False, return_full=False, return_infer=False, + use_checkpoint=False, motif_mask=None, i_cycle=None, n_cycle=None): + + B, N, L = msa_latent.shape[:3] + # Get embeddings + msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx) + msa_full = self.full_emb(msa_full, seq, idx) + + # Do recycling + if msa_prev == None: + msa_prev = torch.zeros_like(msa_latent[:,0]) + pair_prev = torch.zeros_like(pair) + state_prev = torch.zeros_like(state) + msa_recycle, pair_recycle, state_recycle = self.recycle(seq, msa_prev, pair_prev, xyz, state_prev) + msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1) + pair = pair + pair_recycle + state = state + state_recycle + + + # Get timestep embedding (if using) + if hasattr(self, 'timestep_embedder'): + assert t is not None + time_emb = self.timestep_embedder(L,t,motif_mask) + n_tmpl = t1d.shape[1] + t1d = torch.cat([t1d, time_emb[None,None,...].repeat(1,n_tmpl,1,1)], dim=-1) + + # add template embedding + pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=use_checkpoint) + + # Predict coordinates from given inputs + is_frozen_residue = motif_mask if self.freeze_track_motif else torch.zeros_like(motif_mask).bool() + msa, pair, R, T, alpha_s, state = self.simulator(seq, msa_latent, msa_full, pair, xyz[:,:,:3], + state, idx, use_checkpoint=use_checkpoint, + motif_mask=is_frozen_residue) + + if return_raw: + # get last structure + xyz = einsum('bnij,bnaj->bnai', R[-1], xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T[-1].unsqueeze(-2) + return msa[:,0], pair, xyz, state, alpha_s[-1] + + # predict masked amino acids + logits_aa = self.aa_pred(msa) + + # Predict LDDT + lddt = self.lddt_pred(state) + + if return_infer: + # get last structure + xyz = einsum('bnij,bnaj->bnai', R[-1], xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T[-1].unsqueeze(-2) + + # get scalar plddt + nbin = lddt.shape[1] + bin_step = 1.0 / nbin + lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=lddt.dtype, device=lddt.device) + pred_lddt = nn.Softmax(dim=1)(lddt) + pred_lddt = torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1) + + return msa[:,0], pair, xyz, state, alpha_s[-1], logits_aa.permute(0,2,1), pred_lddt + + # + # predict distogram & orientograms + logits = self.c6d_pred(pair) + + # predict experimentally resolved or not + logits_exp = self.exp_pred(msa[:,0], state) + + # get all intermediate bb structures + xyz = einsum('rbnij,bnaj->rbnai', R, xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T.unsqueeze(-2) + + return logits, logits_aa, logits_exp, xyz, alpha_s, lddt diff --git a/rfdiffusion/SE3_network.py b/rfdiffusion/SE3_network.py new file mode 100644 index 0000000000000000000000000000000000000000..394f85bb49a07d259ba405a87243ca365f2ae21b --- /dev/null +++ b/rfdiffusion/SE3_network.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn + +#from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias +#from equivariant_attention.modules import GConvSE3, GNormSE3 +#from equivariant_attention.fibers import Fiber + +from rfdiffusion.util_module import init_lecun_normal_param +from se3_transformer.model import SE3Transformer +from se3_transformer.model.fiber import Fiber + +class SE3TransformerWrapper(nn.Module): + """SE(3) equivariant GCN with attention""" + def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4, + l0_in_features=32, l0_out_features=32, + l1_in_features=3, l1_out_features=2, + num_edge_features=32): + super().__init__() + # Build the network + self.l1_in = l1_in_features + # + fiber_edge = Fiber({0: num_edge_features}) + if l1_out_features > 0: + if l1_in_features > 0: + fiber_in = Fiber({0: l0_in_features, 1: l1_in_features}) + fiber_hidden = Fiber.create(num_degrees, num_channels) + fiber_out = Fiber({0: l0_out_features, 1: l1_out_features}) + else: + fiber_in = Fiber({0: l0_in_features}) + fiber_hidden = Fiber.create(num_degrees, num_channels) + fiber_out = Fiber({0: l0_out_features, 1: l1_out_features}) + else: + if l1_in_features > 0: + fiber_in = Fiber({0: l0_in_features, 1: l1_in_features}) + fiber_hidden = Fiber.create(num_degrees, num_channels) + fiber_out = Fiber({0: l0_out_features}) + else: + fiber_in = Fiber({0: l0_in_features}) + fiber_hidden = Fiber.create(num_degrees, num_channels) + fiber_out = Fiber({0: l0_out_features}) + + self.se3 = SE3Transformer(num_layers=num_layers, + fiber_in=fiber_in, + fiber_hidden=fiber_hidden, + fiber_out = fiber_out, + num_heads=n_heads, + channels_div=div, + fiber_edge=fiber_edge, + use_layer_norm=True) + #use_layer_norm=False) + + self.reset_parameter() + + def reset_parameter(self): + + # make sure linear layer before ReLu are initialized with kaiming_normal_ + for n, p in self.se3.named_parameters(): + if "bias" in n: + nn.init.zeros_(p) + elif len(p.shape) == 1: + continue + else: + if "radial_func" not in n: + p = init_lecun_normal_param(p) + else: + if "net.6" in n: + nn.init.zeros_(p) + else: + nn.init.kaiming_normal_(p, nonlinearity='relu') + + # make last layers to be zero-initialized + #self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0']) + #self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1']) + nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0']) + nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1']) + + def forward(self, G, type_0_features, type_1_features=None, edge_features=None): + if self.l1_in > 0: + node_features = {'0': type_0_features, '1': type_1_features} + else: + node_features = {'0': type_0_features} + edge_features = {'0': edge_features} + return self.se3(G, node_features, edge_features) diff --git a/rfdiffusion/Track_module.py b/rfdiffusion/Track_module.py new file mode 100644 index 0000000000000000000000000000000000000000..12c0863d117dbc44de5852e9a90524d1c234c7a6 --- /dev/null +++ b/rfdiffusion/Track_module.py @@ -0,0 +1,474 @@ +import torch.utils.checkpoint as checkpoint +from rfdiffusion.util_module import * +from rfdiffusion.Attention_module import * +from rfdiffusion.SE3_network import SE3TransformerWrapper + +# Components for three-track blocks +# 1. MSA -> MSA update (biased attention. bias from pair & structure) +# 2. Pair -> Pair update (biased attention. bias from structure) +# 3. MSA -> Pair update (extract coevolution signal) +# 4. Str -> Str update (node from MSA, edge from Pair) + +# Update MSA with biased self-attention. bias from Pair & Str +class MSAPairStr2MSA(nn.Module): + def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16, + d_hidden=32, p_drop=0.15, use_global_attn=False): + super(MSAPairStr2MSA, self).__init__() + self.norm_pair = nn.LayerNorm(d_pair) + self.proj_pair = nn.Linear(d_pair+36, d_pair) + self.norm_state = nn.LayerNorm(d_state) + self.proj_state = nn.Linear(d_state, d_msa) + self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop) + self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair, + n_head=n_head, d_hidden=d_hidden) + if use_global_attn: + self.col_attn = MSAColGlobalAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden) + else: + self.col_attn = MSAColAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden) + self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop) + + # Do proper initialization + self.reset_parameter() + + def reset_parameter(self): + # initialize weights to normal distrib + self.proj_pair = init_lecun_normal(self.proj_pair) + self.proj_state = init_lecun_normal(self.proj_state) + + # initialize bias to zeros + nn.init.zeros_(self.proj_pair.bias) + nn.init.zeros_(self.proj_state.bias) + + def forward(self, msa, pair, rbf_feat, state): + ''' + Inputs: + - msa: MSA feature (B, N, L, d_msa) + - pair: Pair feature (B, L, L, d_pair) + - rbf_feat: Ca-Ca distance feature calculated from xyz coordinates (B, L, L, 36) + - xyz: xyz coordinates (B, L, n_atom, 3) + - state: updated node features after SE(3)-Transformer layer (B, L, d_state) + Output: + - msa: Updated MSA feature (B, N, L, d_msa) + ''' + B, N, L = msa.shape[:3] + + # prepare input bias feature by combining pair & coordinate info + pair = self.norm_pair(pair) + pair = torch.cat((pair, rbf_feat), dim=-1) + pair = self.proj_pair(pair) # (B, L, L, d_pair) + # + # update query sequence feature (first sequence in the MSA) with feedbacks (state) from SE3 + state = self.norm_state(state) + state = self.proj_state(state).reshape(B, 1, L, -1) + msa = msa.index_add(1, torch.tensor([0,], device=state.device), state) + # + # Apply row/column attention to msa & transform + msa = msa + self.drop_row(self.row_attn(msa, pair)) + msa = msa + self.col_attn(msa) + msa = msa + self.ff(msa) + + return msa + +class PairStr2Pair(nn.Module): + def __init__(self, d_pair=128, n_head=4, d_hidden=32, d_rbf=36, p_drop=0.15): + super(PairStr2Pair, self).__init__() + + self.emb_rbf = nn.Linear(d_rbf, d_hidden) + self.proj_rbf = nn.Linear(d_hidden, d_pair) + + self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop) + self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop) + + self.row_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=True) + self.col_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=False) + + self.ff = FeedForwardLayer(d_pair, 2) + + self.reset_parameter() + + def reset_parameter(self): + nn.init.kaiming_normal_(self.emb_rbf.weight, nonlinearity='relu') + nn.init.zeros_(self.emb_rbf.bias) + + self.proj_rbf = init_lecun_normal(self.proj_rbf) + nn.init.zeros_(self.proj_rbf.bias) + + def forward(self, pair, rbf_feat): + B, L = pair.shape[:2] + + rbf_feat = self.proj_rbf(F.relu_(self.emb_rbf(rbf_feat))) + + pair = pair + self.drop_row(self.row_attn(pair, rbf_feat)) + pair = pair + self.drop_col(self.col_attn(pair, rbf_feat)) + pair = pair + self.ff(pair) + return pair + +class MSA2Pair(nn.Module): + def __init__(self, d_msa=256, d_pair=128, d_hidden=32, p_drop=0.15): + super(MSA2Pair, self).__init__() + self.norm = nn.LayerNorm(d_msa) + self.proj_left = nn.Linear(d_msa, d_hidden) + self.proj_right = nn.Linear(d_msa, d_hidden) + self.proj_out = nn.Linear(d_hidden*d_hidden, d_pair) + + self.reset_parameter() + + def reset_parameter(self): + # normal initialization + self.proj_left = init_lecun_normal(self.proj_left) + self.proj_right = init_lecun_normal(self.proj_right) + nn.init.zeros_(self.proj_left.bias) + nn.init.zeros_(self.proj_right.bias) + + # zero initialize output + nn.init.zeros_(self.proj_out.weight) + nn.init.zeros_(self.proj_out.bias) + + def forward(self, msa, pair): + B, N, L = msa.shape[:3] + msa = self.norm(msa) + left = self.proj_left(msa) + right = self.proj_right(msa) + right = right / float(N) + out = einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1) + out = self.proj_out(out) + + pair = pair + out + + return pair + +class SCPred(nn.Module): + def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15): + super(SCPred, self).__init__() + self.norm_s0 = nn.LayerNorm(d_msa) + self.norm_si = nn.LayerNorm(d_state) + self.linear_s0 = nn.Linear(d_msa, d_hidden) + self.linear_si = nn.Linear(d_state, d_hidden) + + # ResNet layers + self.linear_1 = nn.Linear(d_hidden, d_hidden) + self.linear_2 = nn.Linear(d_hidden, d_hidden) + self.linear_3 = nn.Linear(d_hidden, d_hidden) + self.linear_4 = nn.Linear(d_hidden, d_hidden) + + # Final outputs + self.linear_out = nn.Linear(d_hidden, 20) + + self.reset_parameter() + + def reset_parameter(self): + # normal initialization + self.linear_s0 = init_lecun_normal(self.linear_s0) + self.linear_si = init_lecun_normal(self.linear_si) + self.linear_out = init_lecun_normal(self.linear_out) + nn.init.zeros_(self.linear_s0.bias) + nn.init.zeros_(self.linear_si.bias) + nn.init.zeros_(self.linear_out.bias) + + # right before relu activation: He initializer (kaiming normal) + nn.init.kaiming_normal_(self.linear_1.weight, nonlinearity='relu') + nn.init.zeros_(self.linear_1.bias) + nn.init.kaiming_normal_(self.linear_3.weight, nonlinearity='relu') + nn.init.zeros_(self.linear_3.bias) + + # right before residual connection: zero initialize + nn.init.zeros_(self.linear_2.weight) + nn.init.zeros_(self.linear_2.bias) + nn.init.zeros_(self.linear_4.weight) + nn.init.zeros_(self.linear_4.bias) + + def forward(self, seq, state): + ''' + Predict side-chain torsion angles along with backbone torsions + Inputs: + - seq: hidden embeddings corresponding to query sequence (B, L, d_msa) + - state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state) + Outputs: + - si: predicted torsion angles (phi, psi, omega, chi1~4 with cos/sin, Cb bend, Cb twist, CG) (B, L, 10, 2) + ''' + B, L = seq.shape[:2] + seq = self.norm_s0(seq) + state = self.norm_si(state) + si = self.linear_s0(seq) + self.linear_si(state) + + si = si + self.linear_2(F.relu_(self.linear_1(F.relu_(si)))) + si = si + self.linear_4(F.relu_(self.linear_3(F.relu_(si)))) + + si = self.linear_out(F.relu_(si)) + return si.view(B, L, 10, 2) + + +class Str2Str(nn.Module): + def __init__(self, d_msa=256, d_pair=128, d_state=16, + SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.1): + super(Str2Str, self).__init__() + + # initial node & pair feature process + self.norm_msa = nn.LayerNorm(d_msa) + self.norm_pair = nn.LayerNorm(d_pair) + self.norm_state = nn.LayerNorm(d_state) + + self.embed_x = nn.Linear(d_msa+d_state, SE3_param['l0_in_features']) + self.embed_e1 = nn.Linear(d_pair, SE3_param['num_edge_features']) + self.embed_e2 = nn.Linear(SE3_param['num_edge_features']+36+1, SE3_param['num_edge_features']) + + self.norm_node = nn.LayerNorm(SE3_param['l0_in_features']) + self.norm_edge1 = nn.LayerNorm(SE3_param['num_edge_features']) + self.norm_edge2 = nn.LayerNorm(SE3_param['num_edge_features']) + + self.se3 = SE3TransformerWrapper(**SE3_param) + self.sc_predictor = SCPred(d_msa=d_msa, d_state=SE3_param['l0_out_features'], + p_drop=p_drop) + + self.reset_parameter() + + def reset_parameter(self): + # initialize weights to normal distribution + self.embed_x = init_lecun_normal(self.embed_x) + self.embed_e1 = init_lecun_normal(self.embed_e1) + self.embed_e2 = init_lecun_normal(self.embed_e2) + + # initialize bias to zeros + nn.init.zeros_(self.embed_x.bias) + nn.init.zeros_(self.embed_e1.bias) + nn.init.zeros_(self.embed_e2.bias) + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, top_k=64, eps=1e-5): + B, N, L = msa.shape[:3] + + if motif_mask is None: + motif_mask = torch.zeros(L).bool() + + # process msa & pair features + node = self.norm_msa(msa[:,0]) + pair = self.norm_pair(pair) + state = self.norm_state(state) + + node = torch.cat((node, state), dim=-1) + node = self.norm_node(self.embed_x(node)) + pair = self.norm_edge1(self.embed_e1(pair)) + + neighbor = get_seqsep(idx) + rbf_feat = rbf(torch.cdist(xyz[:,:,1], xyz[:,:,1])) + pair = torch.cat((pair, rbf_feat, neighbor), dim=-1) + pair = self.norm_edge2(self.embed_e2(pair)) + + # define graph + if top_k != 0: + G, edge_feats = make_topk_graph(xyz[:,:,1,:], pair, idx, top_k=top_k) + else: + G, edge_feats = make_full_graph(xyz[:,:,1,:], pair, idx, top_k=top_k) + l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2) + l1_feats = l1_feats.reshape(B*L, -1, 3) + + # apply SE(3) Transformer & update coordinates + shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats) + + state = shift['0'].reshape(B, L, -1) # (B, L, C) + + offset = shift['1'].reshape(B, L, 2, 3) + offset[:,motif_mask,...] = 0 # NOTE: motif mask is all zeros if not freeezing the motif + + delTi = offset[:,:,0,:] / 10.0 # translation + R = offset[:,:,1,:] / 100.0 # rotation + + Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) ) + qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm + + delRi = torch.zeros((B,L,3,3), device=xyz.device) + delRi[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD + delRi[:,:,0,1] = 2*qB*qC - 2*qA*qD + delRi[:,:,0,2] = 2*qB*qD + 2*qA*qC + delRi[:,:,1,0] = 2*qB*qC + 2*qA*qD + delRi[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD + delRi[:,:,1,2] = 2*qC*qD - 2*qA*qB + delRi[:,:,2,0] = 2*qB*qD - 2*qA*qC + delRi[:,:,2,1] = 2*qC*qD + 2*qA*qB + delRi[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD + + Ri = einsum('bnij,bnjk->bnik', delRi, R_in) + Ti = delTi + T_in #einsum('bnij,bnj->bni', delRi, T_in) + delTi + + alpha = self.sc_predictor(msa[:,0], state) + return Ri, Ti, state, alpha + +class IterBlock(nn.Module): + def __init__(self, d_msa=256, d_pair=128, + n_head_msa=8, n_head_pair=4, + use_global_attn=False, + d_hidden=32, d_hidden_msa=None, p_drop=0.15, + SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}): + super(IterBlock, self).__init__() + if d_hidden_msa == None: + d_hidden_msa = d_hidden + + self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair, + n_head=n_head_msa, + d_state=SE3_param['l0_out_features'], + use_global_attn=use_global_attn, + d_hidden=d_hidden_msa, p_drop=p_drop) + self.msa2pair = MSA2Pair(d_msa=d_msa, d_pair=d_pair, + d_hidden=d_hidden//2, p_drop=p_drop) + #d_hidden=d_hidden, p_drop=p_drop) + self.pair2pair = PairStr2Pair(d_pair=d_pair, n_head=n_head_pair, + d_hidden=d_hidden, p_drop=p_drop) + self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair, + d_state=SE3_param['l0_out_features'], + SE3_param=SE3_param, + p_drop=p_drop) + + def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, use_checkpoint=False): + rbf_feat = rbf(torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:])) + if use_checkpoint: + msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state) + pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair) + pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat) + R, T, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=0), msa, pair, R_in, T_in, xyz, state, idx, motif_mask) + else: + msa = self.msa2msa(msa, pair, rbf_feat, state) + pair = self.msa2pair(msa, pair) + pair = self.pair2pair(pair, rbf_feat) + R, T, state, alpha = self.str2str(msa, pair, R_in, T_in, xyz, state, idx, motif_mask=motif_mask, top_k=0) + + return msa, pair, R, T, state, alpha + +class IterativeSimulator(nn.Module): + def __init__(self, n_extra_block=4, n_main_block=12, n_ref_block=4, + d_msa=256, d_msa_full=64, d_pair=128, d_hidden=32, + n_head_msa=8, n_head_pair=4, + SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, + SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, + p_drop=0.15): + super(IterativeSimulator, self).__init__() + self.n_extra_block = n_extra_block + self.n_main_block = n_main_block + self.n_ref_block = n_ref_block + + self.proj_state = nn.Linear(SE3_param_topk['l0_out_features'], SE3_param_full['l0_out_features']) + # Update with extra sequences + if n_extra_block > 0: + self.extra_block = nn.ModuleList([IterBlock(d_msa=d_msa_full, d_pair=d_pair, + n_head_msa=n_head_msa, + n_head_pair=n_head_pair, + d_hidden_msa=8, + d_hidden=d_hidden, + p_drop=p_drop, + use_global_attn=True, + SE3_param=SE3_param_full) + for i in range(n_extra_block)]) + + # Update with seed sequences + if n_main_block > 0: + self.main_block = nn.ModuleList([IterBlock(d_msa=d_msa, d_pair=d_pair, + n_head_msa=n_head_msa, + n_head_pair=n_head_pair, + d_hidden=d_hidden, + p_drop=p_drop, + use_global_attn=False, + SE3_param=SE3_param_full) + for i in range(n_main_block)]) + + self.proj_state2 = nn.Linear(SE3_param_full['l0_out_features'], SE3_param_topk['l0_out_features']) + # Final SE(3) refinement + if n_ref_block > 0: + self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair, + d_state=SE3_param_topk['l0_out_features'], + SE3_param=SE3_param_topk, + p_drop=p_drop) + + self.reset_parameter() + def reset_parameter(self): + self.proj_state = init_lecun_normal(self.proj_state) + nn.init.zeros_(self.proj_state.bias) + self.proj_state2 = init_lecun_normal(self.proj_state2) + nn.init.zeros_(self.proj_state2.bias) + + def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=False, motif_mask=None): + """ + input: + seq: query sequence (B, L) + msa: seed MSA embeddings (B, N, L, d_msa) + msa_full: extra MSA embeddings (B, N, L, d_msa_full) + pair: initial residue pair embeddings (B, L, L, d_pair) + xyz_in: initial BB coordinates (B, L, n_atom, 3) + state: initial state features containing mixture of query seq, sidechain, accuracy info (B, L, d_state) + idx: residue index + motif_mask: bool tensor, True if motif position that is frozen, else False(L,) + """ + + B, L = pair.shape[:2] + + if motif_mask is None: + motif_mask = torch.zeros(L).bool() + + R_in = torch.eye(3, device=xyz_in.device).reshape(1,1,3,3).expand(B, L, -1, -1) + T_in = xyz_in[:,:,1].clone() + xyz_in = xyz_in - T_in.unsqueeze(-2) + + state = self.proj_state(state) + + R_s = list() + T_s = list() + alpha_s = list() + for i_m in range(self.n_extra_block): + R_in = R_in.detach() # detach rotation (for stability) + T_in = T_in.detach() + # Get current BB structure + xyz = einsum('bnij,bnaj->bnai', R_in, xyz_in) + T_in.unsqueeze(-2) + + msa_full, pair, R_in, T_in, state, alpha = self.extra_block[i_m](msa_full, + pair, + R_in, + T_in, + xyz, + state, + idx, + motif_mask=motif_mask, + use_checkpoint=use_checkpoint) + R_s.append(R_in) + T_s.append(T_in) + alpha_s.append(alpha) + + for i_m in range(self.n_main_block): + R_in = R_in.detach() + T_in = T_in.detach() + # Get current BB structure + xyz = einsum('bnij,bnaj->bnai', R_in, xyz_in) + T_in.unsqueeze(-2) + + msa, pair, R_in, T_in, state, alpha = self.main_block[i_m](msa, + pair, + R_in, + T_in, + xyz, + state, + idx, + motif_mask=motif_mask, + use_checkpoint=use_checkpoint) + R_s.append(R_in) + T_s.append(T_in) + alpha_s.append(alpha) + + state = self.proj_state2(state) + for i_m in range(self.n_ref_block): + R_in = R_in.detach() + T_in = T_in.detach() + xyz = einsum('bnij,bnaj->bnai', R_in, xyz_in) + T_in.unsqueeze(-2) + R_in, T_in, state, alpha = self.str_refiner(msa, + pair, + R_in, + T_in, + xyz, + state, + idx, + top_k=64, + motif_mask=motif_mask) + R_s.append(R_in) + T_s.append(T_in) + alpha_s.append(alpha) + + R_s = torch.stack(R_s, dim=0) + T_s = torch.stack(T_s, dim=0) + alpha_s = torch.stack(alpha_s, dim=0) + + return msa, pair, R_s, T_s, alpha_s, state diff --git a/rfdiffusion/__init__.py b/rfdiffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rfdiffusion/__pycache__/Attention_module.cpython-310.pyc b/rfdiffusion/__pycache__/Attention_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6f7ad81ca67e6074b962cf858b77a4e8bf07315 Binary files /dev/null and b/rfdiffusion/__pycache__/Attention_module.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/Attention_module.cpython-311.pyc b/rfdiffusion/__pycache__/Attention_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd0bb3c4e718b32be4222c5b14bdbc4bd804e779 Binary files /dev/null and b/rfdiffusion/__pycache__/Attention_module.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/Attention_module.cpython-39.pyc b/rfdiffusion/__pycache__/Attention_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51efc67c2bb5ce641da6fd8e0435f8c37676aced Binary files /dev/null and b/rfdiffusion/__pycache__/Attention_module.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-310.pyc b/rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb52ac65960016dab46c7a48dc607b306bb4685e Binary files /dev/null and b/rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-311.pyc b/rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19ef527060ecef2c8a669f476fc86474869a6154 Binary files /dev/null and b/rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-39.pyc b/rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..217a4e86ce6c90111bd33876572fb35bb1004625 Binary files /dev/null and b/rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/Embeddings.cpython-310.pyc b/rfdiffusion/__pycache__/Embeddings.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b077ec060ac8a58dedb393a78561773eabfae964 Binary files /dev/null and b/rfdiffusion/__pycache__/Embeddings.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/Embeddings.cpython-311.pyc b/rfdiffusion/__pycache__/Embeddings.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..377061eec2954e95d2e8596c681d32502682464c Binary files /dev/null and b/rfdiffusion/__pycache__/Embeddings.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/Embeddings.cpython-39.pyc b/rfdiffusion/__pycache__/Embeddings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6566db703030cd0b68887ac5e49fffc403651e0f Binary files /dev/null and b/rfdiffusion/__pycache__/Embeddings.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-310.pyc b/rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3693ce73ba5953de8e97724c26a894fff91afe9f Binary files /dev/null and b/rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-311.pyc b/rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e0f5353e02191dfe86afa3abb6281fc168389a3 Binary files /dev/null and b/rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-39.pyc b/rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcd5a3721210c324b5f826f3d543ceee904d4995 Binary files /dev/null and b/rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/SE3_network.cpython-310.pyc b/rfdiffusion/__pycache__/SE3_network.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70774a5a80e4826eaceb64bcaabd0c15450bd911 Binary files /dev/null and b/rfdiffusion/__pycache__/SE3_network.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/SE3_network.cpython-311.pyc b/rfdiffusion/__pycache__/SE3_network.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbad742b5217db3a42b0823cd96eb0ddd1abf3d5 Binary files /dev/null and b/rfdiffusion/__pycache__/SE3_network.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/SE3_network.cpython-39.pyc b/rfdiffusion/__pycache__/SE3_network.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..729d9be88f56c0d6fd8470da9d7067e1484d7ca2 Binary files /dev/null and b/rfdiffusion/__pycache__/SE3_network.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/Track_module.cpython-310.pyc b/rfdiffusion/__pycache__/Track_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92824bac1f7df749d3471acb8419bd540ad95ecc Binary files /dev/null and b/rfdiffusion/__pycache__/Track_module.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/Track_module.cpython-311.pyc b/rfdiffusion/__pycache__/Track_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb8073fc0539014e2326e90a09f957c4f1f6bc25 Binary files /dev/null and b/rfdiffusion/__pycache__/Track_module.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/Track_module.cpython-39.pyc b/rfdiffusion/__pycache__/Track_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4aef2a70d953e046431e6e59cb80a78cdd02609 Binary files /dev/null and b/rfdiffusion/__pycache__/Track_module.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/__init__.cpython-310.pyc b/rfdiffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78d1a1a3dfd501b2188bd9dd10008edff25a928b Binary files /dev/null and b/rfdiffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/__init__.cpython-311.pyc b/rfdiffusion/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abaa4d5fe15da4d6fe71cbe1c43a00e8b08e5c6d Binary files /dev/null and b/rfdiffusion/__pycache__/__init__.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/__init__.cpython-39.pyc b/rfdiffusion/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae4912cb9ccdc00fda14f5fa751c54a882b9a903 Binary files /dev/null and b/rfdiffusion/__pycache__/__init__.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/chemical.cpython-310.pyc b/rfdiffusion/__pycache__/chemical.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..677d99a9dabcc7af7a203ecdf73d8d70ce11e78e Binary files /dev/null and b/rfdiffusion/__pycache__/chemical.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/chemical.cpython-311.pyc b/rfdiffusion/__pycache__/chemical.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c683d9eab75764f930f67a927db77a446e078ed2 Binary files /dev/null and b/rfdiffusion/__pycache__/chemical.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/chemical.cpython-39.pyc b/rfdiffusion/__pycache__/chemical.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d56a0665cfa36c3d312bb1263475b54777953b3 Binary files /dev/null and b/rfdiffusion/__pycache__/chemical.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/contigs.cpython-310.pyc b/rfdiffusion/__pycache__/contigs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13f56a47c0d0de773e7f9d2bf444fb20828c37c6 Binary files /dev/null and b/rfdiffusion/__pycache__/contigs.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/contigs.cpython-311.pyc b/rfdiffusion/__pycache__/contigs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8152f3a1102d88e2cc8c1d47202c5506a4574b40 Binary files /dev/null and b/rfdiffusion/__pycache__/contigs.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/contigs.cpython-39.pyc b/rfdiffusion/__pycache__/contigs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2768a06dbbb9399504df3d86cb81b34f8e8a86c3 Binary files /dev/null and b/rfdiffusion/__pycache__/contigs.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/diffusion.cpython-310.pyc b/rfdiffusion/__pycache__/diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8744f8a24faad61659533849b72cec89a1a002fa Binary files /dev/null and b/rfdiffusion/__pycache__/diffusion.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/diffusion.cpython-311.pyc b/rfdiffusion/__pycache__/diffusion.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb3e3bff931286445e853900139391f1163dbaf7 Binary files /dev/null and b/rfdiffusion/__pycache__/diffusion.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/diffusion.cpython-39.pyc b/rfdiffusion/__pycache__/diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afa03846086e5191b16da3260c4a317e52be33db Binary files /dev/null and b/rfdiffusion/__pycache__/diffusion.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/igso3.cpython-310.pyc b/rfdiffusion/__pycache__/igso3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..288abdf6c9a0dbc7636284f213f4b88c6bde8f18 Binary files /dev/null and b/rfdiffusion/__pycache__/igso3.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/igso3.cpython-311.pyc b/rfdiffusion/__pycache__/igso3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00509b0b575aaebded85d8b8a91f9ebc89327063 Binary files /dev/null and b/rfdiffusion/__pycache__/igso3.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/igso3.cpython-39.pyc b/rfdiffusion/__pycache__/igso3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37e2b2921d21db0bb1d1331a8fccddf936f4bda9 Binary files /dev/null and b/rfdiffusion/__pycache__/igso3.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/kinematics.cpython-310.pyc b/rfdiffusion/__pycache__/kinematics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c16e203a9ada86a895a44f872a6cd74ea0e9f1ee Binary files /dev/null and b/rfdiffusion/__pycache__/kinematics.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/kinematics.cpython-311.pyc b/rfdiffusion/__pycache__/kinematics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1116588ddbff9f42a829237826081f99171a0451 Binary files /dev/null and b/rfdiffusion/__pycache__/kinematics.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/kinematics.cpython-39.pyc b/rfdiffusion/__pycache__/kinematics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..704b462966c73af50fd7f3bcfe2d07eabee2aae3 Binary files /dev/null and b/rfdiffusion/__pycache__/kinematics.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/model_input_logger.cpython-311.pyc b/rfdiffusion/__pycache__/model_input_logger.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af00f6e9eed5ae16ca594e0e1c0477ba2de8e9f6 Binary files /dev/null and b/rfdiffusion/__pycache__/model_input_logger.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/model_input_logger.cpython-39.pyc b/rfdiffusion/__pycache__/model_input_logger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..041331099ad20861843d054c19516b2e3ca566a4 Binary files /dev/null and b/rfdiffusion/__pycache__/model_input_logger.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/scoring.cpython-310.pyc b/rfdiffusion/__pycache__/scoring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a28cf9219d1fbf518718dd3f132102be0916d964 Binary files /dev/null and b/rfdiffusion/__pycache__/scoring.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/scoring.cpython-311.pyc b/rfdiffusion/__pycache__/scoring.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70b9f872f6b9900aff8f478ed1ad25cf529a500f Binary files /dev/null and b/rfdiffusion/__pycache__/scoring.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/scoring.cpython-39.pyc b/rfdiffusion/__pycache__/scoring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..395a375143de21d5c2e971914ff72bdaf5fe1649 Binary files /dev/null and b/rfdiffusion/__pycache__/scoring.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/util.cpython-310.pyc b/rfdiffusion/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..300249b2bd2e0f3c6b56d21d08ab66181640ad68 Binary files /dev/null and b/rfdiffusion/__pycache__/util.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/util.cpython-311.pyc b/rfdiffusion/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fedb8cf0777593c340171a34d67320b4a779f50 Binary files /dev/null and b/rfdiffusion/__pycache__/util.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/util.cpython-39.pyc b/rfdiffusion/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e67861024e7da8437fa3756136c4099f19aa6bc2 Binary files /dev/null and b/rfdiffusion/__pycache__/util.cpython-39.pyc differ diff --git a/rfdiffusion/__pycache__/util_module.cpython-310.pyc b/rfdiffusion/__pycache__/util_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6802c01b533b1eee58c85c7f6e5b3d8f885ee968 Binary files /dev/null and b/rfdiffusion/__pycache__/util_module.cpython-310.pyc differ diff --git a/rfdiffusion/__pycache__/util_module.cpython-311.pyc b/rfdiffusion/__pycache__/util_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa5b728ff9c80cb66fb56daec4b659a8cb0e2971 Binary files /dev/null and b/rfdiffusion/__pycache__/util_module.cpython-311.pyc differ diff --git a/rfdiffusion/__pycache__/util_module.cpython-39.pyc b/rfdiffusion/__pycache__/util_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c72a2db9c80d8d01390408dfe7d004df7cee4b08 Binary files /dev/null and b/rfdiffusion/__pycache__/util_module.cpython-39.pyc differ diff --git a/rfdiffusion/chemical.py b/rfdiffusion/chemical.py new file mode 100644 index 0000000000000000000000000000000000000000..dc36d156a0aa3d81c101add49604f4208ca269ce --- /dev/null +++ b/rfdiffusion/chemical.py @@ -0,0 +1,585 @@ +import torch +import numpy as np + +num2aa=[ + 'ALA','ARG','ASN','ASP','CYS', + 'GLN','GLU','GLY','HIS','ILE', + 'LEU','LYS','MET','PHE','PRO', + 'SER','THR','TRP','TYR','VAL', + 'UNK','MAS', + ] + +# Mapping 3 letter AA to 1 letter AA (e.g. ALA to A) +one_letter = ["A", "R", "N", "D", "C", \ + "Q", "E", "G", "H", "I", \ + "L", "K", "M", "F", "P", \ + "S", "T", "W", "Y", "V", "?", "-"] + +aa2num= {x:i for i,x in enumerate(num2aa)} + +aa_321 = {a:b for a,b in zip(num2aa,one_letter)} +aa_123 = {val:key for key,val in aa_321.items()} + + +# create single letter code string from parsed integer sequence +def seq2chars(seq): + out = ''.join([aa_321[num2aa[a]] for a in seq]) + return out + +# full sc atom representation (Nx14) +aa2long=[ + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), # ala + (" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD "," HE ","1HH1","2HH1","1HH2","2HH2"), # arg + (" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HD2","2HD2", None, None, None, None, None, None, None), # asn + (" N "," CA "," C "," O "," CB "," CG "," OD1"," OD2", None, None, None, None, None, None," H "," HA ","1HB ","2HB ", None, None, None, None, None, None, None, None, None), # asp + (" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB "," HG ", None, None, None, None, None, None, None, None), # cys + (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HE2","2HE2", None, None, None, None, None), # gln + (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," OE2", None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ", None, None, None, None, None, None, None), # glu + (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None," H ","1HA ","2HA ", None, None, None, None, None, None, None, None, None, None), # gly + (" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None," H "," HA ","1HB ","2HB "," HD2"," HE1"," HE2", None, None, None, None, None, None), # his + (" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None," H "," HA "," HB ","1HG2","2HG2","3HG2","1HG1","2HG1","1HD1","2HD1","3HD1", None, None), # ile + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None," H "," HA ","1HB ","2HB "," HG ","1HD1","2HD1","3HD1","1HD2","2HD2","3HD2", None, None), # leu + (" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ","1HE ","2HE ","1HZ ","2HZ ","3HZ "), # lys + (" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HE ","2HE ","3HE ", None, None, None, None), # met + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None," H "," HA ","1HB ","2HB "," HD1"," HD2"," HE1"," HE2"," HZ ", None, None, None, None), # phe + (" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ", None, None, None, None, None, None), # pro + (" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None," H "," HG "," HA ","1HB ","2HB ", None, None, None, None, None, None, None, None), # ser + (" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None," H "," HG1"," HA "," HB ","1HG2","2HG2","3HG2", None, None, None, None, None, None), # thr + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2"," H "," HA ","1HB ","2HB "," HD1"," HE1"," HZ2"," HH2"," HZ3"," HE3", None, None, None), # trp + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None," H "," HA ","1HB ","2HB "," HD1"," HE1"," HE2"," HD2"," HH ", None, None, None, None), # tyr + (" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None," H "," HA "," HB ","1HG1","2HG1","3HG1","1HG2","2HG2","3HG2", None, None, None, None), # val + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), # unk + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), # mask +] + +# build the "alternate" sc mapping +aa2longalt=[ + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), # ala + (" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD "," HE ","1HH1","2HH1","1HH2","2HH2"), # arg + (" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HD2","2HD2", None, None, None, None, None, None, None), # asn + (" N "," CA "," C "," O "," CB "," CG "," OD2"," OD1", None, None, None, None, None, None," H "," HA ","1HB ","2HB ", None, None, None, None, None, None, None, None, None), # asp + (" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB "," HG ", None, None, None, None, None, None, None, None), # cys + (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HE2","2HE2", None, None, None, None, None), # gln + (" N "," CA "," C "," O "," CB "," CG "," CD "," OE2"," OE1", None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ", None, None, None, None, None, None, None), # glu + (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None," H ","1HA ","2HA ", None, None, None, None, None, None, None, None, None, None), # gly + (" N "," CA "," C "," O "," CB "," CG "," NE2"," CD2"," CE1"," ND1", None, None, None, None," H "," HA ","1HB ","2HB "," HD2"," HE1"," HE2", None, None, None, None, None, None), # his + (" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None," H "," HA "," HB ","1HG2","2HG2","3HG2","1HG1","2HG1","1HD1","2HD1","3HD1", None, None), # ile + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None," H "," HA ","1HB ","2HB "," HG ","1HD1","2HD1","3HD1","1HD2","2HD2","3HD2", None, None), # leu + (" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ","1HE ","2HE ","1HZ ","2HZ ","3HZ "), # lys + (" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HE ","2HE ","3HE ", None, None, None, None), # met + (" N "," CA "," C "," O "," CB "," CG "," CD2"," CD1"," CE2"," CE1"," CZ ", None, None, None," H "," HD2"," HE2"," HZ "," HE1"," HD1"," HA ","1HB ","2HB ", None, None, None, None), # phe + (" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ", None, None, None, None, None, None), # pro + (" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None," H "," HG "," HA ","1HB ","2HB ", None, None, None, None, None, None, None, None), # ser + (" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None," H "," HG1"," HA "," HB ","1HG2","2HG2","3HG2", None, None, None, None, None, None), # thr + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2"," H "," HA ","1HB ","2HB "," HD1"," HE1"," HZ2"," HH2"," HZ3"," HE3", None, None, None), # trp + (" N "," CA "," C "," O "," CB "," CG "," CD2"," CD1"," CE2"," CE1"," CZ "," OH ", None, None," H "," HA ","1HB ","2HB "," HD2"," HE2"," HE1"," HD1"," HH ", None, None, None, None), # tyr + (" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None," H "," HA "," HB ","1HG1","2HG1","3HG1","1HG2","2HG2","3HG2", None, None, None, None), # val + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), # unk + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), # mask +] + +aabonds=[ + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB ","1HB "),(" CB ","2HB "),(" CB ","3HB ")) , # ala + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD "),(" CG ","1HG "),(" CG ","2HG "),(" CD "," NE "),(" CD ","1HD "),(" CD ","2HD "),(" NE "," CZ "),(" NE "," HE "),(" CZ "," NH1"),(" CZ "," NH2"),(" NH1","1HH1"),(" NH1","2HH1"),(" NH2","1HH2"),(" NH2","2HH2")) , # arg + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," OD1"),(" CG "," ND2"),(" ND2","1HD2"),(" ND2","2HD2")) , # asn + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," OD1"),(" CG "," OD2")) , # asp + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," SG "),(" CB ","1HB "),(" CB ","2HB "),(" SG "," HG ")) , # cys + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD "),(" CG ","1HG "),(" CG ","2HG "),(" CD "," OE1"),(" CD "," NE2"),(" NE2","1HE2"),(" NE2","2HE2")) , # gln + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD "),(" CG ","1HG "),(" CG ","2HG "),(" CD "," OE1"),(" CD "," OE2")) , # glu + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA ","1HA "),(" CA ","2HA "),(" C "," O ")) , # gly + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," ND1"),(" CG "," CD2"),(" ND1"," CE1"),(" CD2"," NE2"),(" CD2"," HD2"),(" CE1"," NE2"),(" CE1"," HE1"),(" NE2"," HE2")) , # his + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG1"),(" CB "," CG2"),(" CB "," HB "),(" CG1"," CD1"),(" CG1","1HG1"),(" CG1","2HG1"),(" CG2","1HG2"),(" CG2","2HG2"),(" CG2","3HG2"),(" CD1","1HD1"),(" CD1","2HD1"),(" CD1","3HD1")) , # ile + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD1"),(" CG "," CD2"),(" CG "," HG "),(" CD1","1HD1"),(" CD1","2HD1"),(" CD1","3HD1"),(" CD2","1HD2"),(" CD2","2HD2"),(" CD2","3HD2")) , # leu + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD "),(" CG ","1HG "),(" CG ","2HG "),(" CD "," CE "),(" CD ","1HD "),(" CD ","2HD "),(" CE "," NZ "),(" CE ","1HE "),(" CE ","2HE "),(" NZ ","1HZ "),(" NZ ","2HZ "),(" NZ ","3HZ ")) , # lys + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," SD "),(" CG ","1HG "),(" CG ","2HG "),(" SD "," CE "),(" CE ","1HE "),(" CE ","2HE "),(" CE ","3HE ")) , # met + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD1"),(" CG "," CD2"),(" CD1"," CE1"),(" CD1"," HD1"),(" CD2"," CE2"),(" CD2"," HD2"),(" CE1"," CZ "),(" CE1"," HE1"),(" CE2"," CZ "),(" CE2"," HE2"),(" CZ "," HZ ")) , # phe + ((" N "," CA "),(" N "," CD "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD "),(" CG ","1HG "),(" CG ","2HG "),(" CD ","1HD "),(" CD ","2HD ")) , # pro + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," OG "),(" CB ","1HB "),(" CB ","2HB "),(" OG "," HG ")) , # ser + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," OG1"),(" CB "," CG2"),(" CB "," HB "),(" OG1"," HG1"),(" CG2","1HG2"),(" CG2","2HG2"),(" CG2","3HG2")) , # thr + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD1"),(" CG "," CD2"),(" CD1"," NE1"),(" CD1"," HD1"),(" CD2"," CE2"),(" CD2"," CE3"),(" NE1"," CE2"),(" NE1"," HE1"),(" CE2"," CZ2"),(" CE3"," CZ3"),(" CE3"," HE3"),(" CZ2"," CH2"),(" CZ2"," HZ2"),(" CZ3"," CH2"),(" CZ3"," HZ3"),(" CH2"," HH2")) , # trp + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD1"),(" CG "," CD2"),(" CD1"," CE1"),(" CD1"," HD1"),(" CD2"," CE2"),(" CD2"," HD2"),(" CE1"," CZ "),(" CE1"," HE1"),(" CE2"," CZ "),(" CE2"," HE2"),(" CZ "," OH "),(" OH "," HH ")) , # tyr + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG1"),(" CB "," CG2"),(" CB "," HB "),(" CG1","1HG1"),(" CG1","2HG1"),(" CG1","3HG1"),(" CG2","1HG2"),(" CG2","2HG2"),(" CG2","3HG2")), # val + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB ","1HB "),(" CB ","2HB "),(" CB ","3HB ")) , # unk + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB ","1HB "),(" CB ","2HB "),(" CB ","3HB ")) , # mask +] + +aa2type = [ + ("Nbb", "CAbb","CObb","OCbb","CH3", None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None, None, None), # ala + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH2", "CH2", "NtrR","aroC","Narg","Narg", None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hpol","Hpol","Hpol","Hpol"), # arg + ("Nbb", "CAbb","CObb","OCbb","CH2", "CNH2","ONH2","NH2O", None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hpol","Hpol", None, None, None, None, None, None, None), # asn + ("Nbb", "CAbb","CObb","OCbb","CH2", "COO", "OOC", "OOC", None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo", None, None, None, None, None, None, None, None, None), # asp + ("Nbb", "CAbb","CObb","OCbb","CH2", "SH1", None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","HS", None, None, None, None, None, None, None, None), # cys + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH2", "CNH2","ONH2","NH2O", None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hpol", None, None, None, None, None), # gln + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH2", "COO", "OOC", "OOC", None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None, None), # glu + ("Nbb", "CAbb","CObb","OCbb", None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo", None, None, None, None, None, None, None, None, None, None), # gly + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH0", "Nhis","aroC","aroC","Ntrp", None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hpol","Hapo","Hapo", None, None, None, None, None, None), # his + ("Nbb", "CAbb","CObb","OCbb","CH1", "CH2", "CH3", "CH3", None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None), # ile + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH1", "CH3", "CH3", None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None), # leu + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH2", "CH2", "CH2", "Nlys", None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hpol","Hpol"), # lys + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH2", "S", "CH3", None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None, None, None), # met + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH0", "aroC","aroC","aroC","aroC","aroC", None, None, None,"HNbb","Hapo","Hapo","Hapo","Haro","Haro","Haro","Haro","Haro", None, None, None, None), # phe + ("Npro","CAbb","CObb","OCbb","CH2", "CH2", "CH2", None, None, None, None, None, None, None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None), # pro + ("Nbb", "CAbb","CObb","OCbb","CH2", "OH", None, None, None, None, None, None, None, None,"HNbb","Hpol","Hapo","Hapo","Hapo", None, None, None, None, None, None, None, None), # ser + ("Nbb", "CAbb","CObb","OCbb","CH1", "OH", "CH3", None, None, None, None, None, None, None,"HNbb","Hpol","Hapo","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None), # thr + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH0", "aroC","CH0", "Ntrp","CH0", "aroC","aroC","aroC","aroC","HNbb","Haro","Hapo","Hapo","Hapo","Hpol","Haro","Haro","Haro","Haro", None, None, None), # trp + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH0", "aroC","aroC","aroC","aroC","CH0", "OHY", None, None,"HNbb","Haro","Haro","Haro","Haro","Hapo","Hapo","Hapo","Hpol", None, None, None, None), # tyr + ("Nbb", "CAbb","CObb","OCbb","CH1", "CH3", "CH3", None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None, None, None), # val + ("Nbb", "CAbb","CObb","OCbb","CH3", None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None, None, None), # unk + ("Nbb", "CAbb","CObb","OCbb","CH3", None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None, None, None), # mask +] + +# tip atom +aa2tip = [ + " CB ", # ala + " CZ ", # arg + " ND2", # asn + " CG ", # asp + " SG ", # cys + " NE2", # gln + " CD ", # glu + " CA ", # gly + " NE2", # his + " CD1", # ile + " CG ", # leu + " NZ ", # lys + " SD ", # met + " CZ ", # phe + " CG ", # pro + " OG ", # ser + " OG1", # thr + " CH2", # trp + " OH ", # tyr + " CB ", # val + " CB ", # unknown (gap etc) + " CB " # masked + ] + + +torsions=[ + [ None, None, None, None ], # ala + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD "], [" CB "," CG "," CD "," NE "], [" CG "," CD "," NE "," CZ "] ], # arg + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," OD1"], None, None ], # asn + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," OD1"], None, None ], # asp + [ [" N "," CA "," CB "," SG "], [" CA "," CB "," SG "," HG "], None, None ], # cys + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD "], [" CB "," CG "," CD "," OE1"], None ], # gln + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD "], [" CB "," CG "," CD "," OE1"], None ], # glu + [ None, None, None, None ], # gly + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," ND1"], [" CD2"," CE1"," HE1"," NE2"], None ], # his (protonation handled as a pseudo-torsion) + [ [" N "," CA "," CB "," CG1"], [" CA "," CB "," CG1"," CD1"], None, None ], # ile + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD1"], None, None ], # leu + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD "], [" CB "," CG "," CD "," CE "], [" CG "," CD "," CE "," NZ "] ], # lys + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," SD "], [" CB "," CG "," SD "," CE "], None ], # met + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD1"], None, None ], # phe + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD "], [" CB "," CG "," CD ","1HD "], None ], # pro + [ [" N "," CA "," CB "," OG "], [" CA "," CB "," OG "," HG "], None, None ], # ser + [ [" N "," CA "," CB "," OG1"], [" CA "," CB "," OG1"," HG1"], None, None ], # thr + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD1"], None, None ], # trp + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD1"], [" CE1"," CZ "," OH "," HH "], None ], # tyr + [ [" N "," CA "," CB "," CG1"], None, None, None ], # val + [ None, None, None, None ], # unk + [ None, None, None, None ], # mask +] + +# ideal N, CA, C initial coordinates +init_N = torch.tensor([-0.5272, 1.3593, 0.000]).float() +init_CA = torch.zeros_like(init_N) +init_C = torch.tensor([1.5233, 0.000, 0.000]).float() +INIT_CRDS = torch.full((27, 3), np.nan) +INIT_CRDS[:3] = torch.stack((init_N, init_CA, init_C), dim=0) # (3,3) + +norm_N = init_N / (torch.norm(init_N, dim=-1, keepdim=True) + 1e-5) +norm_C = init_C / (torch.norm(init_C, dim=-1, keepdim=True) + 1e-5) +cos_ideal_NCAC = torch.sum(norm_N*norm_C, dim=-1) # cosine of ideal N-CA-C bond angle + +#fd Rosetta ideal coords +#fd - uses same "frame-building" as AF2 +ideal_coords = [ + [ # 0 ala + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3341, -0.4928, 0.9132)], + [' CB ', 8, (-0.5289,-0.7734,-1.1991)], + ['1HB ', 8, (-0.1265, -1.7863, -1.1851)], + ['2HB ', 8, (-1.6173, -0.8147, -1.1541)], + ['3HB ', 8, (-0.2229, -0.2744, -2.1172)], + ], + [ # 1 arg + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3467, -0.5055, 0.9018)], + [' CB ', 8, (-0.5042,-0.7698,-1.2118)], + ['1HB ', 4, ( 0.3635, -0.5318, 0.8781)], + ['2HB ', 4, ( 0.3639, -0.5323, -0.8789)], + [' CG ', 4, (0.6396,1.3794, 0.000)], + ['1HG ', 5, (0.3639, -0.5139, 0.8900)], + ['2HG ', 5, (0.3641, -0.5140, -0.8903)], + [' CD ', 5, (0.5492,1.3801, 0.000)], + ['1HD ', 6, (0.3637, -0.5135, 0.8895)], + ['2HD ', 6, (0.3636, -0.5134, -0.8893)], + [' NE ', 6, (0.5423,1.3491, 0.000)], + [' NH1', 7, (0.2012,2.2965, 0.000)], + [' NH2', 7, (2.0824,1.0030, 0.000)], + [' CZ ', 7, (0.7650,1.1090, 0.000)], + [' HE ', 7, (0.4701,-0.8955, 0.000)], + ['1HH1', 7, (-0.8059,2.3776, 0.000)], + ['1HH2', 7, (2.5160,0.0898, 0.000)], + ['2HH1', 7, (0.7745,3.1277, 0.000)], + ['2HH2', 7, (2.6554,1.8336, 0.000)], + ], + [ # 2 asn + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3233, -0.4967, 0.9162)], + [' CB ', 8, (-0.5341,-0.7799,-1.1874)], + ['1HB ', 4, ( 0.3641, -0.5327, 0.8795)], + ['2HB ', 4, ( 0.3639, -0.5323, -0.8789)], + [' CG ', 4, (0.5778,1.3881, 0.000)], + [' ND2', 5, (0.5839,-1.1711, 0.000)], + [' OD1', 5, (0.6331,1.0620, 0.000)], + ['1HD2', 5, (1.5825, -1.2322, 0.000)], + ['2HD2', 5, (0.0323, -2.0046, 0.000)], + ], + [ # 3 asp + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3233, -0.4967, 0.9162)], + [' CB ', 8, (-0.5162,-0.7757,-1.2144)], + ['1HB ', 4, ( 0.3639, -0.5324, 0.8791)], + ['2HB ', 4, ( 0.3640, -0.5325, -0.8792)], + [' CG ', 4, (0.5926,1.4028, 0.000)], + [' OD1', 5, (0.5746,1.0629, 0.000)], + [' OD2', 5, (0.5738,-1.0627, 0.000)], + ], + [ # 4 cys + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3481, -0.5059, 0.9006)], + [' CB ', 8, (-0.5046,-0.7727,-1.2189)], + ['1HB ', 4, ( 0.3639, -0.5324, 0.8791)], + ['2HB ', 4, ( 0.3638, -0.5322, -0.8787)], + [' SG ', 4, (0.7386,1.6511, 0.000)], + [' HG ', 5, (0.1387,1.3221, 0.000)], + ], + [ # 5 gln + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3363, -0.5013, 0.9074)], + [' CB ', 8, (-0.5226,-0.7776,-1.2109)], + ['1HB ', 4, ( 0.3638, -0.5323, 0.8789)], + ['2HB ', 4, ( 0.3638, -0.5322, -0.8788)], + [' CG ', 4, (0.6225,1.3857, 0.000)], + ['1HG ', 5, ( 0.3531, -0.5156, 0.8931)], + ['2HG ', 5, ( 0.3531, -0.5156, -0.8931)], + [' CD ', 5, (0.5788,1.4021, 0.000)], + [' NE2', 6, (0.5908,-1.1895, 0.000)], + [' OE1', 6, (0.6347,1.0584, 0.000)], + ['1HE2', 6, (1.5825, -1.2525, 0.000)], + ['2HE2', 6, (0.0380, -2.0229, 0.000)], + ], + [ # 6 glu + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3363, -0.5013, 0.9074)], + [' CB ', 8, (-0.5197,-0.7737,-1.2137)], + ['1HB ', 4, ( 0.3638, -0.5323, 0.8789)], + ['2HB ', 4, ( 0.3638, -0.5322, -0.8788)], + [' CG ', 4, (0.6287,1.3862, 0.000)], + ['1HG ', 5, ( 0.3531, -0.5156, 0.8931)], + ['2HG ', 5, ( 0.3531, -0.5156, -0.8931)], + [' CD ', 5, (0.5850,1.3849, 0.000)], + [' OE1', 6, (0.5752,1.0618, 0.000)], + [' OE2', 6, (0.5741,-1.0635, 0.000)], + ], + [ # 7 gly + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + ['1HA ', 0, ( -0.3676, -0.5329, 0.8771)], + ['2HA ', 0, ( -0.3674, -0.5325, -0.8765)], + ], + [ # 8 his + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3299, -0.5180, 0.9001)], + [' CB ', 8, (-0.5163,-0.7809,-1.2129)], + ['1HB ', 4, ( 0.3640, -0.5325, 0.8793)], + ['2HB ', 4, ( 0.3637, -0.5321, -0.8786)], + [' CG ', 4, (0.6016,1.3710, 0.000)], + [' CD2', 5, (0.8918,-1.0184, 0.000)], + [' CE1', 5, (2.0299,0.8564, 0.000)], + [' HE1', 5, (2.8542, 1.5693, 0.000)], + [' HD2', 5, ( 0.6584, -2.0835, 0.000) ], + [' ND1', 6, (-1.8631, -1.0722, 0.000)], + [' NE2', 6, (-1.8625, 1.0707, 0.000)], + [' HE2', 6, (-1.5439, 2.0292, 0.000)], + ], + [ # 9 ile + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3405, -0.5028, 0.9044)], + [' CB ', 8, (-0.5140,-0.7885,-1.2184)], + [' HB ', 4, (0.3637, -0.4714, 0.9125)], + [' CG1', 4, (0.5339,1.4348,0.000)], + [' CG2', 4, (0.5319,-0.7693,-1.1994)], + ['1HG2', 4, (1.6215, -0.7588, -1.1842)], + ['2HG2', 4, (0.1785, -1.7986, -1.1569)], + ['3HG2', 4, (0.1773, -0.3016, -2.1180)], + [' CD1', 5, (0.6106,1.3829, 0.000)], + ['1HG1', 5, (0.3637, -0.5338, 0.8774)], + ['2HG1', 5, (0.3640, -0.5322, -0.8793)], + ['1HD1', 5, (1.6978, 1.3006, 0.000)], + ['2HD1', 5, (0.2873, 1.9236, -0.8902)], + ['3HD1', 5, (0.2888, 1.9224, 0.8896)], + ], + [ # 10 leu + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.525, -0.000, -0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3435, -0.5040, 0.9027)], + [' CB ', 8, (-0.5175,-0.7692,-1.2220)], + ['1HB ', 4, ( 0.3473, -0.5346, 0.8827)], + ['2HB ', 4, ( 0.3476, -0.5351, -0.8836)], + [' CG ', 4, (0.6652,1.3823, 0.000)], + [' CD1', 5, (0.5083,1.4353, 0.000)], + [' CD2', 5, (0.5079,-0.7600,1.2163)], + [' HG ', 5, (0.3640, -0.4825, -0.9075)], + ['1HD1', 5, (1.5984, 1.4353, 0.000)], + ['2HD1', 5, (0.1462, 1.9496, -0.8903)], + ['3HD1', 5, (0.1459, 1.9494, 0.8895)], + ['1HD2', 5, (1.5983, -0.7606, 1.2158)], + ['2HD2', 5, (0.1456, -0.2774, 2.1243)], + ['3HD2', 5, (0.1444, -1.7871, 1.1815)], + ], + [ # 11 lys + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3335, -0.5005, 0.9097)], + ['1HB ', 4, ( 0.3640, -0.5324, 0.8791)], + ['2HB ', 4, ( 0.3639, -0.5324, -0.8790)], + [' CB ', 8, (-0.5259,-0.7785,-1.2069)], + ['1HG ', 5, (0.3641, -0.5229, 0.8852)], + ['2HG ', 5, (0.3637, -0.5227, -0.8841)], + [' CG ', 4, (0.6291,1.3869, 0.000)], + [' CD ', 5, (0.5526,1.4174, 0.000)], + ['1HD ', 6, (0.3641, -0.5239, 0.8848)], + ['2HD ', 6, (0.3638, -0.5219, -0.8850)], + [' CE ', 6, (0.5544,1.4170, 0.000)], + [' NZ ', 7, (0.5566,1.3801, 0.000)], + ['1HE ', 7, (0.4199, -0.4638, 0.9482)], + ['2HE ', 7, (0.4202, -0.4631, -0.8172)], + ['1HZ ', 7, (1.6223, 1.3980, 0.0658)], + ['2HZ ', 7, (0.2970, 1.9326, -0.7584)], + ['3HZ ', 7, (0.2981, 1.9319, 0.8909)], + ], + [ # 12 met + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3303, -0.4990, 0.9108)], + ['1HB ', 4, ( 0.3635, -0.5318, 0.8781)], + ['2HB ', 4, ( 0.3641, -0.5326, -0.8795)], + [' CB ', 8, (-0.5331,-0.7727,-1.2048)], + ['1HG ', 5, (0.3637, -0.5256, 0.8823)], + ['2HG ', 5, (0.3638, -0.5249, -0.8831)], + [' CG ', 4, (0.6298,1.3858,0.000)], + [' SD ', 5, (0.6953,1.6645,0.000)], + [' CE ', 6, (0.3383,1.7581,0.000)], + ['1HE ', 6, (1.7054, 2.0532, -0.0063)], + ['2HE ', 6, (0.1906, 2.3099, -0.9072)], + ['3HE ', 6, (0.1917, 2.3792, 0.8720)], + ], + [ # 13 phe + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3303, -0.4990, 0.9108)], + ['1HB ', 4, ( 0.3635, -0.5318, 0.8781)], + ['2HB ', 4, ( 0.3641, -0.5326, -0.8795)], + [' CB ', 8, (-0.5150,-0.7729,-1.2156)], + [' CG ', 4, (0.6060,1.3746, 0.000)], + [' CD1', 5, (0.7078,1.1928, 0.000)], + [' CD2', 5, (0.7084,-1.1920, 0.000)], + [' CE1', 5, (2.0900,1.1940, 0.000)], + [' CE2', 5, (2.0897,-1.1939, 0.000)], + [' CZ ', 5, (2.7809, 0.000, 0.000)], + [' HD1', 5, (0.1613, 2.1362, 0.000)], + [' HD2', 5, (0.1621, -2.1360, 0.000)], + [' HE1', 5, (2.6335, 2.1384, 0.000)], + [' HE2', 5, (2.6344, -2.1378, 0.000)], + [' HZ ', 5, (3.8700, 0.000, 0.000)], + ], + [ # 14 pro + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' HA ', 0, (-0.3868, -0.5380, 0.8781)], + ['1HB ', 4, ( 0.3762, -0.5355, 0.8842)], + ['2HB ', 4, ( 0.3762, -0.5355, -0.8842)], + [' CB ', 8, (-0.5649,-0.5888,-1.2966)], + [' CG ', 4, (0.3657,1.4451,0.0000)], + [' CD ', 5, (0.3744,1.4582, 0.0)], + ['1HG ', 5, (0.3798, -0.5348, 0.8830)], + ['2HG ', 5, (0.3798, -0.5348, -0.8830)], + ['1HD ', 6, (0.3798, -0.5348, 0.8830)], + ['2HD ', 6, (0.3798, -0.5348, -0.8830)], + ], + [ # 15 ser + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3425, -0.5041, 0.9048)], + ['1HB ', 4, ( 0.3637, -0.5321, 0.8786)], + ['2HB ', 4, ( 0.3636, -0.5319, -0.8782)], + [' CB ', 8, (-0.5146,-0.7595,-1.2073)], + [' OG ', 4, (0.5021,1.3081, 0.000)], + [' HG ', 5, (0.2647, 0.9230, 0.000)], + ], + [ # 16 thr + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3364, -0.5015, 0.9078)], + [' HB ', 4, ( 0.3638, -0.5006, 0.8971)], + ['1HG2', 4, ( 1.6231, -0.7142, -1.2097)], + ['2HG2', 4, ( 0.1792, -1.7546, -1.2237)], + ['3HG2', 4, ( 0.1808, -0.2222, -2.1269)], + [' CB ', 8, (-0.5172,-0.7952,-1.2130)], + [' CG2', 4, (0.5334,-0.7239,-1.2267)], + [' OG1', 4, (0.4804,1.3506,0.000)], + [' HG1', 5, (0.3194, 0.9056, 0.000)], + ], + [ # 17 trp + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3436, -0.5042, 0.9031)], + ['1HB ', 4, ( 0.3639, -0.5323, 0.8790)], + ['2HB ', 4, ( 0.3638, -0.5322, -0.8787)], + [' CB ', 8, (-0.5136,-0.7712,-1.2173)], + [' CG ', 4, (0.5984,1.3741, 0.000)], + [' CD1', 5, (0.8151,1.0921, 0.000)], + [' CD2', 5, (0.8753,-1.1538, 0.000)], + [' CE2', 5, (2.1865,-0.6707, 0.000)], + [' CE3', 5, (0.6541,-2.5366, 0.000)], + [' NE1', 5, (2.1309,0.7003, 0.000)], + [' CH2', 5, (3.0315,-2.8930, 0.000)], + [' CZ2', 5, (3.2813,-1.5205, 0.000)], + [' CZ3', 5, (1.7521,-3.3888, 0.000)], + [' HD1', 5, (0.4722, 2.1252, 0.000)], + [' HE1', 5, ( 2.9291, 1.3191, 0.000)], + [' HE3', 5, (-0.3597, -2.9356, 0.000)], + [' HZ2', 5, (4.3053, -1.1462, 0.000)], + [' HZ3', 5, ( 1.5712, -4.4640, 0.000)], + [' HH2', 5, ( 3.8700, -3.5898, 0.000)], + ], + [ # 18 tyr + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3305, -0.4992, 0.9112)], + ['1HB ', 4, ( 0.3642, -0.5327, 0.8797)], + ['2HB ', 4, ( 0.3637, -0.5321, -0.8785)], + [' CB ', 8, (-0.5305,-0.7799,-1.2051)], + [' CG ', 4, (0.6104,1.3840, 0.000)], + [' CD1', 5, (0.6936,1.2013, 0.000)], + [' CD2', 5, (0.6934,-1.2011, 0.000)], + [' CE1', 5, (2.0751,1.2013, 0.000)], + [' CE2', 5, (2.0748,-1.2011, 0.000)], + [' OH ', 5, (4.1408, 0.000, 0.000)], + [' CZ ', 5, (2.7648, 0.000, 0.000)], + [' HD1', 5, (0.1485, 2.1455, 0.000)], + [' HD2', 5, (0.1484, -2.1451, 0.000)], + [' HE1', 5, (2.6200, 2.1450, 0.000)], + [' HE2', 5, (2.6199, -2.1453, 0.000)], + [' HH ', 6, (0.3190, 0.9057, 0.000)], + ], + [ # 19 val + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3497, -0.5068, 0.9002)], + [' CB ', 8, (-0.5105,-0.7712,-1.2317)], + [' CG1', 4, (0.5326,1.4252, 0.000)], + [' CG2', 4, (0.5177,-0.7693,1.2057)], + [' HB ', 4, (0.3541, -0.4754, -0.9148)], + ['1HG1', 4, (1.6228, 1.4063, 0.000)], + ['2HG1', 4, (0.1790, 1.9457, -0.8898)], + ['3HG1', 4, (0.1798, 1.9453, 0.8903)], + ['1HG2', 4, (1.6073, -0.7659, 1.1989)], + ['2HG2', 4, (0.1586, -0.2971, 2.1203)], + ['3HG2', 4, (0.1582, -1.7976, 1.1631)], + ], + [ # 20 unk + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3341, -0.4928, 0.9132)], + [' CB ', 8, (-0.5289,-0.7734,-1.1991)], + ['1HB ', 8, (-0.1265, -1.7863, -1.1851)], + ['2HB ', 8, (-1.6173, -0.8147, -1.1541)], + ['3HB ', 8, (-0.2229, -0.2744, -2.1172)], + ], + [ # 21 mask + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3341, -0.4928, 0.9132)], + [' CB ', 8, (-0.5289,-0.7734,-1.1991)], + ['1HB ', 8, (-0.1265, -1.7863, -1.1851)], + ['2HB ', 8, (-1.6173, -0.8147, -1.1541)], + ['3HB ', 8, (-0.2229, -0.2744, -2.1172)], + ], +] diff --git a/rfdiffusion/contigs.py b/rfdiffusion/contigs.py new file mode 100644 index 0000000000000000000000000000000000000000..8a63e95e2dd062529cfa31398999be876a5154a8 --- /dev/null +++ b/rfdiffusion/contigs.py @@ -0,0 +1,396 @@ +import sys +import numpy as np +import random + + +class ContigMap: + """ + Class for doing mapping. + Inherited from Inpainting. To update at some point. + Supports multichain or multiple crops from a single receptor chain. + Also supports indexing jump (+200) or not, based on contig input. + Default chain outputs are inpainted chains as A (and B, C etc if multiple chains), and all fragments of receptor chain on the next one (generally B) + Output chains can be specified. Sequence must be the same number of elements as in contig string + """ + + def __init__( + self, + parsed_pdb, + contigs=None, + inpaint_seq=None, + inpaint_str=None, + length=None, + ref_idx=None, + hal_idx=None, + idx_rf=None, + inpaint_seq_tensor=None, + inpaint_str_tensor=None, + topo=False, + provide_seq=None, + ): + # sanity checks + if contigs is None and ref_idx is None: + sys.exit("Must either specify a contig string or precise mapping") + if idx_rf is not None or hal_idx is not None or ref_idx is not None: + if idx_rf is None or hal_idx is None or ref_idx is None: + sys.exit( + "If you're specifying specific contig mappings, the reference and output positions must be specified, AND the indexing for RoseTTAFold (idx_rf)" + ) + + self.chain_order = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + if length is not None: + if "-" not in length: + self.length = [int(length), int(length) + 1] + else: + self.length = [int(length.split("-")[0]), int(length.split("-")[1]) + 1] + else: + self.length = None + self.ref_idx = ref_idx + self.hal_idx = hal_idx + self.idx_rf = idx_rf + self.inpaint_seq = ( + "/".join(inpaint_seq).split("/") if inpaint_seq is not None else None + ) + self.inpaint_str = ( + "/".join(inpaint_str).split("/") if inpaint_str is not None else None + ) + self.inpaint_seq_tensor = inpaint_seq_tensor + self.inpaint_str_tensor = inpaint_str_tensor + self.parsed_pdb = parsed_pdb + self.topo = topo + if ref_idx is None: + # using default contig generation, which outputs in rosetta-like format + self.contigs = contigs + ( + self.sampled_mask, + self.contig_length, + self.n_inpaint_chains, + ) = self.get_sampled_mask() + self.receptor_chain = self.chain_order[self.n_inpaint_chains] + ( + self.receptor, + self.receptor_hal, + self.receptor_rf, + self.inpaint, + self.inpaint_hal, + self.inpaint_rf, + ) = self.expand_sampled_mask() + self.ref = self.inpaint + self.receptor + self.hal = self.inpaint_hal + self.receptor_hal + self.rf = self.inpaint_rf + self.receptor_rf + else: + # specifying precise mappings + self.ref = ref_idx + self.hal = hal_idx + self.rf = idx_rf + self.mask_1d = [False if i == ("_", "_") else True for i in self.ref] + # take care of sequence and structure masking + if self.inpaint_seq_tensor is None: + if self.inpaint_seq is not None: + self.inpaint_seq = self.get_inpaint_seq_str(self.inpaint_seq) + else: + self.inpaint_seq = np.array( + [True if i != ("_", "_") else False for i in self.ref] + ) + else: + self.inpaint_seq = self.inpaint_seq_tensor + + if self.inpaint_str_tensor is None: + if self.inpaint_str is not None: + self.inpaint_str = self.get_inpaint_seq_str(self.inpaint_str) + else: + self.inpaint_str = np.array( + [True if i != ("_", "_") else False for i in self.ref] + ) + else: + self.inpaint_str = self.inpaint_str_tensor + # get 0-indexed input/output (for trb file) + ( + self.ref_idx0, + self.hal_idx0, + self.ref_idx0_inpaint, + self.hal_idx0_inpaint, + self.ref_idx0_receptor, + self.hal_idx0_receptor, + ) = self.get_idx0() + self.con_ref_pdb_idx = [i for i in self.ref if i != ("_", "_")] + + # Handle provide seq. This is zero-indexed, and used only for partial diffusion + if provide_seq is not None: + for i in provide_seq: + if "-" in i: + self.inpaint_seq[ + int(i.split("-")[0]) : int(i.split("-")[1]) + 1 + ] = True + else: + self.inpaint_seq[int(i)] = True + + def get_sampled_mask(self): + """ + Function to get a sampled mask from a contig. + """ + length_compatible = False + count = 0 + while length_compatible is False: + inpaint_chains = 0 + contig_list = self.contigs[0].strip().split() + sampled_mask = [] + sampled_mask_length = 0 + # allow receptor chain to be last in contig string + if all([i[0].isalpha() for i in contig_list[-1].split("/")]): + contig_list[-1] = f"{contig_list[-1]}/0" + for con in contig_list: + if ( + all([i[0].isalpha() for i in con.split("/")[:-1]]) + and con.split("/")[-1] == "0" + ) or self.topo is True: + # receptor chain + sampled_mask.append(con) + else: + inpaint_chains += 1 + # chain to be inpainted. These are the only chains that count towards the length of the contig + subcons = con.split("/") + subcon_out = [] + for subcon in subcons: + if subcon[0].isalpha(): + subcon_out.append(subcon) + if "-" in subcon: + sampled_mask_length += ( + int(subcon.split("-")[1]) + - int(subcon.split("-")[0][1:]) + + 1 + ) + else: + sampled_mask_length += 1 + + else: + if "-" in subcon: + length_inpaint = random.randint( + int(subcon.split("-")[0]), int(subcon.split("-")[1]) + ) + subcon_out.append(f"{length_inpaint}-{length_inpaint}") + sampled_mask_length += length_inpaint + elif subcon == "0": + subcon_out.append("0") + else: + length_inpaint = int(subcon) + subcon_out.append(f"{length_inpaint}-{length_inpaint}") + sampled_mask_length += int(subcon) + sampled_mask.append("/".join(subcon_out)) + # check length is compatible + if self.length is not None: + if ( + sampled_mask_length >= self.length[0] + and sampled_mask_length < self.length[1] + ): + length_compatible = True + else: + length_compatible = True + count += 1 + if count == 100000: # contig string incompatible with this length + sys.exit("Contig string incompatible with --length range") + return sampled_mask, sampled_mask_length, inpaint_chains + + def expand_sampled_mask(self): + chain_order = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + receptor = [] + inpaint = [] + receptor_hal = [] + inpaint_hal = [] + receptor_idx = 1 + inpaint_idx = 1 + inpaint_chain_idx = -1 + receptor_chain_break = [] + inpaint_chain_break = [] + for con in self.sampled_mask: + if ( + all([i[0].isalpha() for i in con.split("/")[:-1]]) + and con.split("/")[-1] == "0" + ) or self.topo is True: + # receptor chain + subcons = con.split("/")[:-1] + assert all( + [i[0] == subcons[0][0] for i in subcons] + ), "If specifying fragmented receptor in a single block of the contig string, they MUST derive from the same chain" + assert all( + int(subcons[i].split("-")[0][1:]) + < int(subcons[i + 1].split("-")[0][1:]) + for i in range(len(subcons) - 1) + ), "If specifying multiple fragments from the same chain, pdb indices must be in ascending order!" + for idx, subcon in enumerate(subcons): + ref_to_add = [ + (subcon[0], i) + for i in np.arange( + int(subcon.split("-")[0][1:]), int(subcon.split("-")[1]) + 1 + ) + ] + receptor.extend(ref_to_add) + receptor_hal.extend( + [ + (self.receptor_chain, i) + for i in np.arange( + receptor_idx, receptor_idx + len(ref_to_add) + ) + ] + ) + receptor_idx += len(ref_to_add) + if idx != len(subcons) - 1: + idx_jump = ( + int(subcons[idx + 1].split("-")[0][1:]) + - int(subcon.split("-")[1]) + - 1 + ) + receptor_chain_break.append( + (receptor_idx - 1, idx_jump) + ) # actual chain break in pdb chain + else: + receptor_chain_break.append( + (receptor_idx - 1, 200) + ) # 200 aa chain break + else: + inpaint_chain_idx += 1 + for subcon in con.split("/"): + if subcon[0].isalpha(): + ref_to_add = [ + (subcon[0], i) + for i in np.arange( + int(subcon.split("-")[0][1:]), + int(subcon.split("-")[1]) + 1, + ) + ] + inpaint.extend(ref_to_add) + inpaint_hal.extend( + [ + (chain_order[inpaint_chain_idx], i) + for i in np.arange( + inpaint_idx, inpaint_idx + len(ref_to_add) + ) + ] + ) + inpaint_idx += len(ref_to_add) + + else: + inpaint.extend([("_", "_")] * int(subcon.split("-")[0])) + inpaint_hal.extend( + [ + (chain_order[inpaint_chain_idx], i) + for i in np.arange( + inpaint_idx, inpaint_idx + int(subcon.split("-")[0]) + ) + ] + ) + inpaint_idx += int(subcon.split("-")[0]) + inpaint_chain_break.append((inpaint_idx - 1, 200)) + + if self.topo is True or inpaint_hal == []: + receptor_hal = [(i[0], i[1]) for i in receptor_hal] + else: + receptor_hal = [ + (i[0], i[1] + inpaint_hal[-1][1]) for i in receptor_hal + ] # rosetta-like numbering + # get rf indexes, with chain breaks + inpaint_rf = np.arange(0, len(inpaint)) + receptor_rf = np.arange(len(inpaint) + 200, len(inpaint) + len(receptor) + 200) + for ch_break in inpaint_chain_break[:-1]: + receptor_rf[:] += 200 + inpaint_rf[ch_break[0] :] += ch_break[1] + for ch_break in receptor_chain_break[:-1]: + receptor_rf[ch_break[0] :] += ch_break[1] + + return ( + receptor, + receptor_hal, + receptor_rf.tolist(), + inpaint, + inpaint_hal, + inpaint_rf.tolist(), + ) + + def get_inpaint_seq_str(self, inpaint_s): + """ + function to generate inpaint_str or inpaint_seq masks specific to this contig + """ + s_mask = np.copy(self.mask_1d) + inpaint_s_list = [] + for i in inpaint_s: + if "-" in i: + inpaint_s_list.extend( + [ + (i[0], p) + for p in range( + int(i.split("-")[0][1:]), int(i.split("-")[1]) + 1 + ) + ] + ) + else: + inpaint_s_list.append((i[0], int(i[1:]))) + for res in inpaint_s_list: + if res in self.ref: + s_mask[self.ref.index(res)] = False # mask this residue + + return np.array(s_mask) + + def get_idx0(self): + ref_idx0 = [] + hal_idx0 = [] + ref_idx0_inpaint = [] + hal_idx0_inpaint = [] + ref_idx0_receptor = [] + hal_idx0_receptor = [] + for idx, val in enumerate(self.ref): + if val != ("_", "_"): + assert val in self.parsed_pdb["pdb_idx"], f"{val} is not in pdb file!" + hal_idx0.append(idx) + ref_idx0.append(self.parsed_pdb["pdb_idx"].index(val)) + for idx, val in enumerate(self.inpaint): + if val != ("_", "_"): + hal_idx0_inpaint.append(idx) + ref_idx0_inpaint.append(self.parsed_pdb["pdb_idx"].index(val)) + for idx, val in enumerate(self.receptor): + if val != ("_", "_"): + hal_idx0_receptor.append(idx) + ref_idx0_receptor.append(self.parsed_pdb["pdb_idx"].index(val)) + + return ( + ref_idx0, + hal_idx0, + ref_idx0_inpaint, + hal_idx0_inpaint, + ref_idx0_receptor, + hal_idx0_receptor, + ) + + def get_mappings(self): + mappings = {} + mappings["con_ref_pdb_idx"] = [i for i in self.inpaint if i != ("_", "_")] + mappings["con_hal_pdb_idx"] = [ + self.inpaint_hal[i] + for i in range(len(self.inpaint_hal)) + if self.inpaint[i] != ("_", "_") + ] + mappings["con_ref_idx0"] = np.array(self.ref_idx0_inpaint) + mappings["con_hal_idx0"] = np.array(self.hal_idx0_inpaint) + if self.inpaint != self.ref: + mappings["complex_con_ref_pdb_idx"] = [ + i for i in self.ref if i != ("_", "_") + ] + mappings["complex_con_hal_pdb_idx"] = [ + self.hal[i] for i in range(len(self.hal)) if self.ref[i] != ("_", "_") + ] + mappings["receptor_con_ref_pdb_idx"] = [ + i for i in self.receptor if i != ("_", "_") + ] + mappings["receptor_con_hal_pdb_idx"] = [ + self.receptor_hal[i] + for i in range(len(self.receptor_hal)) + if self.receptor[i] != ("_", "_") + ] + mappings["complex_con_ref_idx0"] = np.array(self.ref_idx0) + mappings["complex_con_hal_idx0"] = np.array(self.hal_idx0) + mappings["receptor_con_ref_idx0"] = np.array(self.ref_idx0_receptor) + mappings["receptor_con_hal_idx0"] = np.array(self.hal_idx0_receptor) + mappings["inpaint_str"] = self.inpaint_str + mappings["inpaint_seq"] = self.inpaint_seq + mappings["sampled_mask"] = self.sampled_mask + mappings["mask_1d"] = self.mask_1d + return mappings diff --git a/rfdiffusion/coords6d.py b/rfdiffusion/coords6d.py new file mode 100644 index 0000000000000000000000000000000000000000..d32245439f97a2b4f1b266f47900355b25cfbee1 --- /dev/null +++ b/rfdiffusion/coords6d.py @@ -0,0 +1,63 @@ +import numpy as np +import scipy +import scipy.spatial +from rfdiffusion.kinematics import get_dih + +# calculate planar angles defined by 3 sets of points +def get_angles(a, b, c): + + v = a - b + v /= np.linalg.norm(v, axis=-1)[:,None] + + w = c - b + w /= np.linalg.norm(w, axis=-1)[:,None] + + x = np.sum(v*w, axis=1) + + #return np.arccos(x) + return np.arccos(np.clip(x, -1.0, 1.0)) + +# get 6d coordinates from x,y,z coords of N,Ca,C atoms +def get_coords6d(xyz, dmax): + + nres = xyz.shape[1] + + # three anchor atoms + N = xyz[0] + Ca = xyz[1] + C = xyz[2] + + # recreate Cb given N,Ca,C + b = Ca - N + c = C - Ca + a = np.cross(b, c) + Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca + + # fast neighbors search to collect all + # Cb-Cb pairs within dmax + kdCb = scipy.spatial.cKDTree(Cb) + indices = kdCb.query_ball_tree(kdCb, dmax) + + # indices of contacting residues + idx = np.array([[i,j] for i in range(len(indices)) for j in indices[i] if i != j]).T + idx0 = idx[0] + idx1 = idx[1] + + # Cb-Cb distance matrix + dist6d = np.full((nres, nres),999.9, dtype=np.float32) + dist6d[idx0,idx1] = np.linalg.norm(Cb[idx1]-Cb[idx0], axis=-1) + + # matrix of Ca-Cb-Cb-Ca dihedrals + omega6d = np.zeros((nres, nres), dtype=np.float32) + omega6d[idx0,idx1] = get_dih(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1]) + # matrix of polar coord theta + theta6d = np.zeros((nres, nres), dtype=np.float32) + theta6d[idx0,idx1] = get_dih(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1]) + + # matrix of polar coord phi + phi6d = np.zeros((nres, nres), dtype=np.float32) + phi6d[idx0,idx1] = get_angles(Ca[idx0], Cb[idx0], Cb[idx1]) + + mask = np.zeros((nres, nres), dtype=np.float32) + mask[idx0, idx1] = 1.0 + return dist6d, omega6d, theta6d, phi6d, mask diff --git a/rfdiffusion/diffusion.py b/rfdiffusion/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..a67e5794ff6e32c594d0b0abfa789c27bc907193 --- /dev/null +++ b/rfdiffusion/diffusion.py @@ -0,0 +1,695 @@ +# script for diffusion protocols +import torch +import pickle +import numpy as np +import os +import logging + +from scipy.spatial.transform import Rotation as scipy_R + +from rfdiffusion.util import rigid_from_3_points + +from rfdiffusion.util_module import ComputeAllAtomCoords + +from rfdiffusion import igso3 +import time + +torch.set_printoptions(sci_mode=False) + + +def get_beta_schedule(T, b0, bT, schedule_type, schedule_params={}, inference=False): + """ + Given a noise schedule type, create the beta schedule + """ + assert schedule_type in ["linear"] + + # Adjust b0 and bT if T is not 200 + # This is a good approximation, with the beta correction below, unless T is very small + assert T >= 15, "With discrete time and T < 15, the schedule is badly approximated" + b0 *= 200 / T + bT *= 200 / T + + # linear noise schedule + if schedule_type == "linear": + schedule = torch.linspace(b0, bT, T) + + else: + raise NotImplementedError(f"Schedule of type {schedule_type} not implemented.") + + # get alphabar_t for convenience + alpha_schedule = 1 - schedule + alphabar_t_schedule = torch.cumprod(alpha_schedule, dim=0) + + if inference: + print( + f"With this beta schedule ({schedule_type} schedule, beta_0 = {round(b0, 3)}, beta_T = {round(bT,3)}), alpha_bar_T = {alphabar_t_schedule[-1]}" + ) + + return schedule, alpha_schedule, alphabar_t_schedule + + +class EuclideanDiffuser: + # class for diffusing points in 3D + + def __init__( + self, + T, + b_0, + b_T, + schedule_type="linear", + schedule_kwargs={}, + ): + self.T = T + + # make noise/beta schedule + ( + self.beta_schedule, + self.alpha_schedule, + self.alphabar_schedule, + ) = get_beta_schedule(T, b_0, b_T, schedule_type, **schedule_kwargs) + + def diffuse_translations(self, xyz, diffusion_mask=None, var_scale=1): + return self.apply_kernel_recursive(xyz, diffusion_mask, var_scale) + + def apply_kernel(self, x, t, diffusion_mask=None, var_scale=1): + """ + Applies a noising kernel to the points in x + + Parameters: + x (torch.tensor, required): (N,3,3) set of backbone coordinates + + t (int, required): Which timestep + + noise_scale (float, required): scale for noise + """ + t_idx = t - 1 # bring from 1-indexed to 0-indexed + + assert len(x.shape) == 3 + L, _, _ = x.shape + + # c-alpha crds + ca_xyz = x[:, 1, :] + + b_t = self.beta_schedule[t_idx] + + # get the noise at timestep t + mean = torch.sqrt(1 - b_t) * ca_xyz + var = torch.ones(L, 3) * (b_t) * var_scale + + sampled_crds = torch.normal(mean, torch.sqrt(var)) + delta = sampled_crds - ca_xyz + + if not diffusion_mask is None: + delta[diffusion_mask, ...] = 0 + + out_crds = x + delta[:, None, :] + + return out_crds, delta + + def apply_kernel_recursive(self, xyz, diffusion_mask=None, var_scale=1): + """ + Repeatedly apply self.apply_kernel T times and return all crds + """ + bb_stack = [] + T_stack = [] + + cur_xyz = torch.clone(xyz) + + for t in range(1, self.T + 1): + cur_xyz, cur_T = self.apply_kernel( + cur_xyz, t, var_scale=var_scale, diffusion_mask=diffusion_mask + ) + bb_stack.append(cur_xyz) + T_stack.append(cur_T) + + return torch.stack(bb_stack).transpose(0, 1), torch.stack(T_stack).transpose( + 0, 1 + ) + + +def write_pkl(save_path: str, pkl_data): + """Serialize data into a pickle file.""" + with open(save_path, "wb") as handle: + pickle.dump(pkl_data, handle, protocol=pickle.HIGHEST_PROTOCOL) + + +def read_pkl(read_path: str, verbose=False): + """Read data from a pickle file.""" + with open(read_path, "rb") as handle: + try: + return pickle.load(handle) + except Exception as e: + if verbose: + print(f"Failed to read {read_path}") + raise (e) + + +class IGSO3: + """ + Class for taking in a set of backbone crds and performing IGSO3 diffusion + on all of them. + + Unlike the diffusion on translations, much of this class is written for a + scaling between an initial time t=0 and final time t=1. + """ + + def __init__( + self, + *, + T, + min_sigma, + max_sigma, + min_b, + max_b, + cache_dir, + num_omega=1000, + schedule="linear", + L=2000, + ): + """ + + Args: + T: total number of time steps + min_sigma: smallest allowed scale parameter, should be at least 0.01 to maintain numerical stability. Recommended value is 0.05. + max_sigma: for exponential schedule, the largest scale parameter. Ignored for recommeded linear schedule + min_b: lower value of beta in Ho schedule analogue + max_b: upper value of beta in Ho schedule analogue + num_omega: discretization level in the angles across [0, pi] + schedule: currently only linear and exponential are supported. The exponential schedule may be noising too slowly. + L: truncation level + """ + self._log = logging.getLogger(__name__) + + self.T = T + + self.schedule = schedule + self.cache_dir = cache_dir + self.min_sigma = min_sigma + self.max_sigma = max_sigma + + if self.schedule == "linear": + self.min_b = min_b + self.max_b = max_b + self.max_sigma = self.sigma(1.0) + self.num_omega = num_omega + self.num_sigma = 500 + # Calculate igso3 values. + self.L = L # truncation level + self.igso3_vals = self._calc_igso3_vals(L=L) + self.step_size = 1 / self.T + + def _calc_igso3_vals(self, L=2000): + """_calc_igso3_vals computes numerical approximations to the + relevant analytically intractable functionals of the igso3 + distribution. + + The calculated values are cached, or loaded from cache if they already + exist. + + Args: + L: truncation level for power series expansion of the pdf. + """ + replace_period = lambda x: str(x).replace(".", "_") + if self.schedule == "linear": + cache_fname = os.path.join( + self.cache_dir, + f"T_{self.T}_omega_{self.num_omega}_min_sigma_{replace_period(self.min_sigma)}" + + f"_min_b_{replace_period(self.min_b)}_max_b_{replace_period(self.max_b)}_schedule_{self.schedule}.pkl", + ) + elif self.schedule == "exponential": + cache_fname = os.path.join( + self.cache_dir, + f"T_{self.T}_omega_{self.num_omega}_min_sigma_{replace_period(self.min_sigma)}" + f"_max_sigma_{replace_period(self.max_sigma)}_schedule_{self.schedule}", + ) + else: + raise ValueError(f"Unrecognize schedule {self.schedule}") + + if not os.path.isdir(self.cache_dir): + os.makedirs(self.cache_dir) + + if os.path.exists(cache_fname): + self._log.info("Using cached IGSO3.") + igso3_vals = read_pkl(cache_fname) + else: + self._log.info("Calculating IGSO3.") + igso3_vals = igso3.calculate_igso3( + num_sigma=self.num_sigma, + min_sigma=self.min_sigma, + max_sigma=self.max_sigma, + num_omega=self.num_omega + ) + write_pkl(cache_fname, igso3_vals) + + return igso3_vals + + @property + def discrete_sigma(self): + return self.igso3_vals["discrete_sigma"] + + def sigma_idx(self, sigma: np.ndarray): + """ + Calculates the index for discretized sigma during IGSO(3) initialization.""" + return np.digitize(sigma, self.discrete_sigma) - 1 + + def t_to_idx(self, t: np.ndarray): + """ + Helper function to go from discrete time index t to corresponding sigma_idx. + + Args: + t: time index (integer between 1 and 200) + """ + continuous_t = t / self.T + return self.sigma_idx(self.sigma(continuous_t)) + + def sigma(self, t: torch.tensor): + """ + Extract \sigma(t) corresponding to chosen sigma schedule. + + Args: + t: torch tensor with time between 0 and 1 + """ + if not type(t) == torch.Tensor: + t = torch.tensor(t) + if torch.any(t < 0) or torch.any(t > 1): + raise ValueError(f"Invalid t={t}") + if self.schedule == "exponential": + sigma = t * np.log10(self.max_sigma) + (1 - t) * np.log10(self.min_sigma) + return 10**sigma + elif self.schedule == "linear": # Variance exploding analogue of Ho schedule + # add self.min_sigma for stability + return ( + self.min_sigma + + t * self.min_b + + (1 / 2) * (t**2) * (self.max_b - self.min_b) + ) + else: + raise ValueError(f"Unrecognize schedule {self.schedule}") + + def g(self, t): + """ + g returns the drift coefficient at time t + + since + sigma(t)^2 := \int_0^t g(s)^2 ds, + for arbitrary sigma(t) we invert this relationship to compute + g(t) = sqrt(d/dt sigma(t)^2). + + Args: + t: scalar time between 0 and 1 + + Returns: + drift cooeficient as a scalar. + """ + t = torch.tensor(t, requires_grad=True) + sigma_sqr = self.sigma(t) ** 2 + grads = torch.autograd.grad(sigma_sqr.sum(), t)[0] + return torch.sqrt(grads) + + def sample(self, ts, n_samples=1): + """ + sample uses the inverse cdf to sample an angle of rotation from + IGSO(3) + + Args: + ts: array of integer time steps to sample from. + n_samples: number of samples to draw. + Returns: + sampled angles of rotation. [len(ts), N] + """ + assert sum(ts == 0) == 0, "assumes one-indexed, not zero indexed" + all_samples = [] + for t in ts: + sigma_idx = self.t_to_idx(t) + sample_i = np.interp( + np.random.rand(n_samples), + self.igso3_vals["cdf"][sigma_idx], + self.igso3_vals["discrete_omega"], + ) # [N, 1] + all_samples.append(sample_i) + return np.stack(all_samples, axis=0) + + def sample_vec(self, ts, n_samples=1): + """sample_vec generates a rotation vector(s) from IGSO(3) at time steps + ts. + + Return: + Sampled vector of shape [len(ts), N, 3] + """ + x = np.random.randn(len(ts), n_samples, 3) + x /= np.linalg.norm(x, axis=-1, keepdims=True) + return x * self.sample(ts, n_samples=n_samples)[..., None] + + def score_norm(self, t, omega): + """ + score_norm computes the score norm based on the time step and angle + Args: + t: integer time step + omega: angles (scalar or shape [N]) + Return: + score_norm with same shape as omega + """ + sigma_idx = self.t_to_idx(t) + score_norm_t = np.interp( + omega, + self.igso3_vals["discrete_omega"], + self.igso3_vals["score_norm"][sigma_idx], + ) + return score_norm_t + + def score_vec(self, ts, vec): + """score_vec computes the score of the IGSO(3) density as a rotation + vector. This score vector is in the direction of the sampled vector, + and has magnitude given by score_norms. + + In particular, Rt @ hat(score_vec(ts, vec)) is what is referred to as + the score approximation in Algorithm 1 + + + Args: + ts: times of shape [T] + vec: where to compute the score of shape [T, N, 3] + Returns: + score vectors of shape [T, N, 3] + """ + omega = np.linalg.norm(vec, axis=-1) + all_score_norm = [] + for i, t in enumerate(ts): + omega_t = omega[i] + t_idx = t - 1 + sigma_idx = self.t_to_idx(t) + score_norm_t = np.interp( + omega_t, + self.igso3_vals["discrete_omega"], + self.igso3_vals["score_norm"][sigma_idx], + )[:, None] + all_score_norm.append(score_norm_t) + score_norm = np.stack(all_score_norm, axis=0) + return score_norm * vec / omega[..., None] + + def exp_score_norm(self, ts): + """exp_score_norm returns the expected value of norm of the score for + IGSO(3) with time parameter ts of shape [T]. + """ + sigma_idcs = [self.t_to_idx(t) for t in ts] + return self.igso3_vals["exp_score_norms"][sigma_idcs] + + def diffuse_frames(self, xyz, t_list, diffusion_mask=None): + """diffuse_frames samples from the IGSO(3) distribution to noise frames + + Parameters: + xyz (np.array or torch.tensor, required): (L,3,3) set of backbone coordinates + mask (np.array or torch.tensor, required): (L,) set of bools. True/1 is NOT diffused, False/0 IS diffused + Returns: + np.array : N/CA/C coordinates for each residue + (T,L,3,3), where T is num timesteps + """ + + if torch.is_tensor(xyz): + xyz = xyz.numpy() + + t = np.arange(self.T) + 1 # 1-indexed!! + num_res = len(xyz) + + N = torch.from_numpy(xyz[None, :, 0, :]) + Ca = torch.from_numpy(xyz[None, :, 1, :]) # [1, num_res, 3, 3] + C = torch.from_numpy(xyz[None, :, 2, :]) + + # scipy rotation object for true coordinates + R_true, Ca = rigid_from_3_points(N, Ca, C) + R_true = R_true[0] + Ca = Ca[0] + + # Sample rotations and scores from IGSO3 + sampled_rots = self.sample_vec(t, n_samples=num_res) # [T, N, 3] + + if diffusion_mask is not None: + non_diffusion_mask = 1 - diffusion_mask[None, :, None] + sampled_rots = sampled_rots * non_diffusion_mask + + # Apply sampled rot. + R_sampled = ( + scipy_R.from_rotvec(sampled_rots.reshape(-1, 3)) + .as_matrix() + .reshape(self.T, num_res, 3, 3) + ) + R_perturbed = np.einsum("tnij,njk->tnik", R_sampled, R_true) + perturbed_crds = ( + np.einsum( + "tnij,naj->tnai", R_sampled, xyz[:, :3, :] - Ca[:, None, ...].numpy() + ) + + Ca[None, :, None].numpy() + ) + + if t_list != None: + idx = [i - 1 for i in t_list] + perturbed_crds = perturbed_crds[idx] + R_perturbed = R_perturbed[idx] + + return ( + perturbed_crds.transpose(1, 0, 2, 3), # [L, T, 3, 3] + R_perturbed.transpose(1, 0, 2, 3), + ) + + def reverse_sample_vectorized( + self, R_t, R_0, t, noise_level, mask=None, return_perturb=False + ): + """reverse_sample uses an approximation to the IGSO3 score to sample + a rotation at the previous time step. + + Roughly - this update follows the reverse time SDE for Reimannian + manifolds proposed by de Bortoli et al. Theorem 1 [1]. But with an + approximation to the score based on the prediction of R0. + Unlike in reference [1], this diffusion on SO(3) relies on geometric + variance schedule. Specifically we follow [2] (appendix C) and assume + sigma_t = sigma_min * (sigma_max / sigma_min)^{t/T}, + for time step t. When we view this as a discretization of the SDE + from time 0 to 1 with step size (1/T). Following Eq. 5 and Eq. 6, + this maps on to the forward time SDEs + dx = g(t) dBt [FORWARD] + and + dx = g(t)^2 score(xt, t)dt + g(t) B't, [REVERSE] + where g(t) = sigma_t * sqrt(2 * log(sigma_max/ sigma_min)), and Bt and + B't are Brownian motions. The formula for g(t) obtains from equation 9 + of [2], from which this sampling function may be generalized to + alternative noising schedules. + Args: + R_t: noisy rotation of shape [N, 3, 3] + R_0: prediction of un-noised rotation + t: integer time step + noise_level: scaling on the noise added when obtaining sample + (preliminary performance seems empirically better with noise + level=0.5) + mask: whether the residue is to be updated. A value of 1 means the + rotation is not updated from r_t. A value of 0 means the + rotation is updated. + Return: + sampled rotation matrix for time t-1 of shape [3, 3] + Reference: + [1] De Bortoli, V., Mathieu, E., Hutchinson, M., Thornton, J., Teh, Y. + W., & Doucet, A. (2022). Riemannian score-based generative modeling. + arXiv preprint arXiv:2202.02763. + [2] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., + & Poole, B. (2020). Score-based generative modeling through stochastic + differential equations. arXiv preprint arXiv:2011.13456. + """ + # compute rotation vector corresponding to prediction of how r_t goes to r_0 + R_0, R_t = torch.tensor(R_0), torch.tensor(R_t) + R_0t = torch.einsum("...ij,...kj->...ik", R_t, R_0) + R_0t_rotvec = torch.tensor( + scipy_R.from_matrix(R_0t.cpu().numpy()).as_rotvec() + ).to(R_0.device) + + # Approximate the score based on the prediction of R0. + # R_t @ hat(Score_approx) is the score approximation in the Lie algebra + # SO(3) (i.e. the output of Algorithm 1) + Omega = torch.linalg.norm(R_0t_rotvec, axis=-1).numpy() + Score_approx = R_0t_rotvec * (self.score_norm(t, Omega) / Omega)[:, None] + + # Compute scaling for score and sampled noise (following Eq 6 of [2]) + continuous_t = t / self.T + rot_g = self.g(continuous_t).to(Score_approx.device) + + # Sample and scale noise to add to the rotation perturbation in the + # SO(3) tangent space. Since IG-SO(3) is the Brownian motion on SO(3) + # (up to a deceleration of time by a factor of two), for small enough + # time-steps, this is equivalent to perturbing r_t with IG-SO(3) noise. + # See e.g. Algorithm 1 of De Bortoli et al. + Z = np.random.normal(size=(R_0.shape[0], 3)) + Z = torch.from_numpy(Z).to(Score_approx.device) + Z *= noise_level + + Delta_r = (rot_g**2) * self.step_size * Score_approx + + # Sample perturbation from discretized SDE (following eq. 6 of [2]), + # This approximate sampling from IGSO3(* ; Delta_r, rot_g^2 * + # self.step_size) with tangent Gaussian. + Perturb_tangent = Delta_r + rot_g * np.sqrt(self.step_size) * Z + if mask is not None: + Perturb_tangent *= (1 - mask.long())[:, None, None] + Perturb = igso3.Exp(Perturb_tangent) + + if return_perturb: + return Perturb + + Interp_rot = torch.einsum("...ij,...jk->...ik", Perturb, R_t) + + return Interp_rot + + +class Diffuser: + # wrapper for yielding diffused coordinates + + def __init__( + self, + T, + b_0, + b_T, + min_sigma, + max_sigma, + min_b, + max_b, + schedule_type, + so3_schedule_type, + so3_type, + crd_scale, + schedule_kwargs={}, + var_scale=1.0, + cache_dir=".", + partial_T=None, + truncation_level=2000, + ): + """ + Parameters: + + T (int, required): Number of steps in the schedule + + b_0 (float, required): Starting variance for Euclidean schedule + + b_T (float, required): Ending variance for Euclidean schedule + + """ + self.T = T + self.b_0 = b_0 + self.b_T = b_T + self.min_sigma = min_sigma + self.max_sigma = max_sigma + self.crd_scale = crd_scale + self.var_scale = var_scale + self.cache_dir = cache_dir + + # get backbone frame diffuser + self.so3_diffuser = IGSO3( + T=self.T, + min_sigma=self.min_sigma, + max_sigma=self.max_sigma, + schedule=so3_schedule_type, + min_b=min_b, + max_b=max_b, + cache_dir=self.cache_dir, + L=truncation_level, + ) + + # get backbone translation diffuser + self.eucl_diffuser = EuclideanDiffuser( + self.T, b_0, b_T, schedule_type=schedule_type, **schedule_kwargs + ) + + print("Successful diffuser __init__") + + def diffuse_pose( + self, + xyz, + seq, + atom_mask, + include_motif_sidechains=True, + diffusion_mask=None, + t_list=None, + ): + """ + Given full atom xyz, sequence and atom mask, diffuse the protein frame + translations and rotations + + Parameters: + + xyz (L,14/27,3) set of coordinates + + seq (L,) integer sequence + + atom_mask: mask describing presence/absence of an atom in pdb + + diffusion_mask (torch.tensor, optional): Tensor of bools, True means NOT diffused at this residue, False means diffused + + t_list (list, optional): If present, only return the diffused coordinates at timesteps t within the list + + + """ + + if diffusion_mask is None: + diffusion_mask = torch.zeros(len(xyz.squeeze())).to(dtype=bool) + + get_allatom = ComputeAllAtomCoords().to(device=xyz.device) + L = len(xyz) + + # bring to origin and scale + # check if any BB atoms are nan before centering + nan_mask = ~torch.isnan(xyz.squeeze()[:, :3]).any(dim=-1).any(dim=-1) + assert torch.sum(~nan_mask) == 0 + + # Centre unmasked structure at origin, as in training (to prevent information leak) + if torch.sum(diffusion_mask) != 0: + self.motif_com = xyz[diffusion_mask, 1, :].mean( + dim=0 + ) # This is needed for one of the potentials + xyz = xyz - self.motif_com + elif torch.sum(diffusion_mask) == 0: + xyz = xyz - xyz[:, 1, :].mean(dim=0) + + xyz_true = torch.clone(xyz) + xyz = xyz * self.crd_scale + + # 1 get translations + tick = time.time() + diffused_T, deltas = self.eucl_diffuser.diffuse_translations( + xyz[:, :3, :].clone(), diffusion_mask=diffusion_mask + ) + # print('Time to diffuse coordinates: ',time.time()-tick) + diffused_T /= self.crd_scale + deltas /= self.crd_scale + + # 2 get frames + tick = time.time() + diffused_frame_crds, diffused_frames = self.so3_diffuser.diffuse_frames( + xyz[:, :3, :].clone(), diffusion_mask=diffusion_mask.numpy(), t_list=None + ) + diffused_frame_crds /= self.crd_scale + # print('Time to diffuse frames: ',time.time()-tick) + + ##### Now combine all the diffused quantities to make full atom diffused poses + tick = time.time() + cum_delta = deltas.cumsum(dim=1) + # The coordinates of the translated AND rotated frames + diffused_BB = ( + torch.from_numpy(diffused_frame_crds) + cum_delta[:, :, None, :] + ).transpose( + 0, 1 + ) # [n,L,3,3] + # diffused_BB = torch.from_numpy(diffused_frame_crds).transpose(0,1) + + # diffused_BB is [t_steps,L,3,3] + t_steps, L = diffused_BB.shape[:2] + + diffused_fa = torch.zeros(t_steps, L, 27, 3) + diffused_fa[:, :, :3, :] = diffused_BB + + # Add in sidechains from motif + if include_motif_sidechains: + diffused_fa[:, diffusion_mask, :14, :] = xyz_true[None, diffusion_mask, :14] + + if t_list is None: + fa_stack = diffused_fa + else: + t_idx_list = [t - 1 for t in t_list] + fa_stack = diffused_fa[t_idx_list] + + return fa_stack, xyz_true diff --git a/rfdiffusion/igso3.py b/rfdiffusion/igso3.py new file mode 100644 index 0000000000000000000000000000000000000000..6d90bdb21bbef0cfa443f2b6641031fef076cb22 --- /dev/null +++ b/rfdiffusion/igso3.py @@ -0,0 +1,118 @@ +"""SO(3) diffusion methods.""" +import numpy as np +import os +from functools import cached_property +import torch +from scipy.spatial.transform import Rotation +import scipy.linalg + + +### First define geometric operations on the SO3 manifold + +# hat map from vector space R^3 to Lie algebra so(3) +def hat(v): + hat_v = torch.zeros([v.shape[0], 3, 3]) + hat_v[:, 0, 1], hat_v[:, 0, 2], hat_v[:, 1, 2] = -v[:, 2], v[:, 1], -v[:, 0] + return hat_v + -hat_v.transpose(2, 1) + +# Logarithmic map from SO(3) to R^3 (i.e. rotation vector) +def Log(R): return torch.tensor(Rotation.from_matrix(R.numpy()).as_rotvec()) + +# logarithmic map from SO(3) to so(3), this is the matrix logarithm +def log(R): return hat(Log(R)) + +# Exponential map from vector space of so(3) to SO(3), this is the matrix +# exponential combined with the "hat" map +def Exp(A): return torch.tensor(Rotation.from_rotvec(A.numpy()).as_matrix()) + +# Angle of rotation SO(3) to R^+ +def Omega(R): return np.linalg.norm(log(R), axis=[-2, -1])/np.sqrt(2.) + +L_default = 2000 +def f_igso3(omega, t, L=L_default): + """Truncated sum of IGSO(3) distribution. + + This function approximates the power series in equation 5 of + "DENOISING DIFFUSION PROBABILISTIC MODELS ON SO(3) FOR ROTATIONAL + ALIGNMENT" + Leach et al. 2022 + + This expression diverges from the expression in Leach in that here, sigma = + sqrt(2) * eps, if eps_leach were the scale parameter of the IGSO(3). + + With this reparameterization, IGSO(3) agrees with the Brownian motion on + SO(3) with t=sigma^2 when defined for the canonical inner product on SO3, + _SO3 = Trace(u v^T)/2 + + Args: + omega: i.e. the angle of rotation associated with rotation matrix + t: variance parameter of IGSO(3), maps onto time in Brownian motion + L: Truncation level + """ + ls = torch.arange(L)[None] # of shape [1, L] + return ((2*ls + 1) * torch.exp(-ls*(ls+1)*t/2) * + torch.sin(omega[:, None]*(ls+1/2)) / torch.sin(omega[:, None]/2)).sum(dim=-1) + +def d_logf_d_omega(omega, t, L=L_default): + omega = torch.tensor(omega, requires_grad=True) + log_f = torch.log(f_igso3(omega, t, L)) + return torch.autograd.grad(log_f.sum(), omega)[0].numpy() + +# IGSO3 density with respect to the volume form on SO(3) +def igso3_density(Rt, t, L=L_default): + return f_igso3(torch.tensor(Omega(Rt)), t, L).numpy() + +def igso3_density_angle(omega, t, L=L_default): + return f_igso3(torch.tensor(omega), t, L).numpy()*(1-np.cos(omega))/np.pi + +# grad_R log IGSO3(R; I_3, t) +def igso3_score(R, t, L=L_default): + omega = Omega(R) + unit_vector = np.einsum('Nij,Njk->Nik', R, log(R))/omega[:, None, None] + return unit_vector * d_logf_d_omega(omega, t, L)[:, None, None] + +def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma): + """calculate_igso3 pre-computes numerical approximations to the IGSO3 cdfs + and score norms and expected squared score norms. + + Args: + num_sigma: number of different sigmas for which to compute igso3 + quantities. + num_omega: number of point in the discretization in the angle of + rotation. + min_sigma, max_sigma: the upper and lower ranges for the angle of + rotation on which to consider the IGSO3 distribution. This cannot + be too low or it will create numerical instability. + """ + # Discretize omegas for calculating CDFs. Skip omega=0. + discrete_omega = np.linspace(0, np.pi, num_omega+1)[1:] + + # Exponential noise schedule. This choice is closely tied to the + # scalings used when simulating the reverse time SDE. For each step n, + # discrete_sigma[n] = min_eps^(1-n/num_eps) * max_eps^(n/num_eps) + discrete_sigma = 10 ** np.linspace(np.log10(min_sigma), np.log10(max_sigma), num_sigma + 1)[1:] + + # Compute the pdf and cdf values for the marginal distribution of the angle + # of rotation (which is needed for sampling) + pdf_vals = np.asarray( + [igso3_density_angle(discrete_omega, sigma**2) for sigma in discrete_sigma]) + cdf_vals = np.asarray( + [pdf.cumsum() / num_omega * np.pi for pdf in pdf_vals]) + + # Compute the norms of the scores. This are used to scale the rotation axis when + # computing the score as a vector. + score_norm = np.asarray( + [d_logf_d_omega(discrete_omega, sigma**2) for sigma in discrete_sigma]) + + # Compute the standard deviation of the score norm for each sigma + exp_score_norms = np.sqrt( + np.sum( + score_norm**2 * pdf_vals, axis=1) / np.sum( + pdf_vals, axis=1)) + return { + 'cdf': cdf_vals, + 'score_norm': score_norm, + 'exp_score_norms': exp_score_norms, + 'discrete_omega': discrete_omega, + 'discrete_sigma': discrete_sigma, + } diff --git a/rfdiffusion/inference/__init__.py b/rfdiffusion/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rfdiffusion/inference/__pycache__/__init__.cpython-310.pyc b/rfdiffusion/inference/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e32b8c3246fa8b7970d93f72a181d0f14d2c84e0 Binary files /dev/null and b/rfdiffusion/inference/__pycache__/__init__.cpython-310.pyc differ diff --git a/rfdiffusion/inference/__pycache__/__init__.cpython-311.pyc b/rfdiffusion/inference/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5cf6275ad13142ce0ca329bf3d14302d3853c05 Binary files /dev/null and b/rfdiffusion/inference/__pycache__/__init__.cpython-311.pyc differ diff --git a/rfdiffusion/inference/__pycache__/__init__.cpython-39.pyc b/rfdiffusion/inference/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8e800fe675b9b033d7c7bde8a32295c059c5474 Binary files /dev/null and b/rfdiffusion/inference/__pycache__/__init__.cpython-39.pyc differ diff --git a/rfdiffusion/inference/__pycache__/model_runners.cpython-310.pyc b/rfdiffusion/inference/__pycache__/model_runners.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b5dccad8542f0587cdd286cb738845dffe252ba Binary files /dev/null and b/rfdiffusion/inference/__pycache__/model_runners.cpython-310.pyc differ diff --git a/rfdiffusion/inference/__pycache__/model_runners.cpython-311.pyc b/rfdiffusion/inference/__pycache__/model_runners.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08d1394ae3ef4aab698c264c3c46bb59fb66d366 Binary files /dev/null and b/rfdiffusion/inference/__pycache__/model_runners.cpython-311.pyc differ diff --git a/rfdiffusion/inference/__pycache__/model_runners.cpython-39.pyc b/rfdiffusion/inference/__pycache__/model_runners.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d2a064e241efb6e72bf343bc27433421c641ffd Binary files /dev/null and b/rfdiffusion/inference/__pycache__/model_runners.cpython-39.pyc differ diff --git a/rfdiffusion/inference/__pycache__/symmetry.cpython-310.pyc b/rfdiffusion/inference/__pycache__/symmetry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e39092552d7d0d55b27b791edf989764fbe9a96a Binary files /dev/null and b/rfdiffusion/inference/__pycache__/symmetry.cpython-310.pyc differ diff --git a/rfdiffusion/inference/__pycache__/symmetry.cpython-311.pyc b/rfdiffusion/inference/__pycache__/symmetry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c0190da115cb422be481cad9de39def5a9d9ac6 Binary files /dev/null and b/rfdiffusion/inference/__pycache__/symmetry.cpython-311.pyc differ diff --git a/rfdiffusion/inference/__pycache__/symmetry.cpython-39.pyc b/rfdiffusion/inference/__pycache__/symmetry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94247f37344b257685fba8e3f9dfc5282a2b06bc Binary files /dev/null and b/rfdiffusion/inference/__pycache__/symmetry.cpython-39.pyc differ diff --git a/rfdiffusion/inference/__pycache__/utils.cpython-310.pyc b/rfdiffusion/inference/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2066cacd44f916612d52fac145834278528bbbfd Binary files /dev/null and b/rfdiffusion/inference/__pycache__/utils.cpython-310.pyc differ diff --git a/rfdiffusion/inference/__pycache__/utils.cpython-311.pyc b/rfdiffusion/inference/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b6eb401a884874e12caeebb73e5330f4f478dfe Binary files /dev/null and b/rfdiffusion/inference/__pycache__/utils.cpython-311.pyc differ diff --git a/rfdiffusion/inference/__pycache__/utils.cpython-39.pyc b/rfdiffusion/inference/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e24325153efb2eb226e13a7f3f46b5023e012313 Binary files /dev/null and b/rfdiffusion/inference/__pycache__/utils.cpython-39.pyc differ diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py new file mode 100644 index 0000000000000000000000000000000000000000..27d553084e45a0d65feefd057c051d6f20781d4d --- /dev/null +++ b/rfdiffusion/inference/model_runners.py @@ -0,0 +1,956 @@ +import torch +import numpy as np +from omegaconf import DictConfig, OmegaConf +from rfdiffusion.RoseTTAFoldModel import RoseTTAFoldModule +from rfdiffusion.kinematics import get_init_xyz, xyz_to_t2d +from rfdiffusion.diffusion import Diffuser +from rfdiffusion.chemical import seq2chars +from rfdiffusion.util_module import ComputeAllAtomCoords +from rfdiffusion.contigs import ContigMap +from rfdiffusion.inference import utils as iu, symmetry +from rfdiffusion.potentials.manager import PotentialManager +import logging +import torch.nn.functional as nn +from rfdiffusion import util +from hydra.core.hydra_config import HydraConfig +import os + +from rfdiffusion.model_input_logger import pickle_function_call +import sys + +SCRIPT_DIR=os.path.dirname(os.path.realpath(__file__)) + +TOR_INDICES = util.torsion_indices +TOR_CAN_FLIP = util.torsion_can_flip +REF_ANGLES = util.reference_angles + + +class Sampler: + + def __init__(self, conf: DictConfig): + """ + Initialize sampler. + Args: + conf: Configuration. + """ + self.initialized = False + self.initialize(conf) + + def initialize(self, conf: DictConfig) -> None: + """ + Initialize sampler. + Args: + conf: Configuration + + - Selects appropriate model from input + - Assembles Config from model checkpoint and command line overrides + + """ + self._log = logging.getLogger(__name__) + if torch.cuda.is_available(): + self.device = torch.device('cuda') + else: + self.device = torch.device('cpu') + needs_model_reload = not self.initialized or conf.inference.ckpt_override_path != self._conf.inference.ckpt_override_path + + # Assign config to Sampler + self._conf = conf + + ################################ + ### Select Appropriate Model ### + ################################ + + if conf.inference.model_directory_path is not None: + model_directory = conf.inference.model_directory_path + else: + model_directory = f"{SCRIPT_DIR}/../../models" + + print(f"Reading models from {model_directory}") + + # Initialize inference only helper objects to Sampler + if conf.inference.ckpt_override_path is not None: + self.ckpt_path = conf.inference.ckpt_override_path + print("WARNING: You're overriding the checkpoint path from the defaults. Check that the model you're providing can run with the inputs you're providing.") + else: + if conf.contigmap.inpaint_seq is not None or conf.contigmap.provide_seq is not None: + # use model trained for inpaint_seq + if conf.contigmap.provide_seq is not None: + # this is only used for partial diffusion + assert conf.diffuser.partial_T is not None, "The provide_seq input is specifically for partial diffusion" + if conf.scaffoldguided.scaffoldguided: + self.ckpt_path = f'{model_directory}/InpaintSeq_Fold_ckpt.pt' + else: + self.ckpt_path = f'{model_directory}/InpaintSeq_ckpt.pt' + elif conf.ppi.hotspot_res is not None and conf.scaffoldguided.scaffoldguided is False: + # use complex trained model + self.ckpt_path = f'{model_directory}/Complex_base_ckpt.pt' + elif conf.scaffoldguided.scaffoldguided is True: + # use complex and secondary structure-guided model + self.ckpt_path = f'{model_directory}/Complex_Fold_base_ckpt.pt' + else: + # use default model + self.ckpt_path = f'{model_directory}/Base_ckpt.pt' + # for saving in trb file: + assert self._conf.inference.trb_save_ckpt_path is None, "trb_save_ckpt_path is not the place to specify an input model. Specify in inference.ckpt_override_path" + self._conf['inference']['trb_save_ckpt_path']=self.ckpt_path + + ####################### + ### Assemble Config ### + ####################### + + if needs_model_reload: + # Load checkpoint, so that we can assemble the config + self.load_checkpoint() + self.assemble_config_from_chk() + # Now actually load the model weights into RF + self.model = self.load_model() + else: + self.assemble_config_from_chk() + + # self.initialize_sampler(conf) + self.initialized=True + + # Initialize helper objects + self.inf_conf = self._conf.inference + self.contig_conf = self._conf.contigmap + self.denoiser_conf = self._conf.denoiser + self.ppi_conf = self._conf.ppi + self.potential_conf = self._conf.potentials + self.diffuser_conf = self._conf.diffuser + self.preprocess_conf = self._conf.preprocess + + if conf.inference.schedule_directory_path is not None: + schedule_directory = conf.inference.schedule_directory_path + else: + schedule_directory = f"{SCRIPT_DIR}/../../schedules" + + # Check for cache schedule + if not os.path.exists(schedule_directory): + os.mkdir(schedule_directory) + self.diffuser = Diffuser(**self._conf.diffuser, cache_dir=schedule_directory) + + ########################### + ### Initialise Symmetry ### + ########################### + + if self.inf_conf.symmetry is not None: + self.symmetry = symmetry.SymGen( + self.inf_conf.symmetry, + self.inf_conf.recenter, + self.inf_conf.radius, + self.inf_conf.model_only_neighbors, + ) + else: + self.symmetry = None + + self.allatom = ComputeAllAtomCoords().to(self.device) + + if self.inf_conf.input_pdb is None: + # set default pdb + script_dir=os.path.dirname(os.path.realpath(__file__)) + self.inf_conf.input_pdb=os.path.join(script_dir, '../../examples/input_pdbs/1qys.pdb') + self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) + self.chain_idx = None + + ############################## + ### Handle Partial Noising ### + ############################## + + if self.diffuser_conf.partial_T: + assert self.diffuser_conf.partial_T <= self.diffuser_conf.T + self.t_step_input = int(self.diffuser_conf.partial_T) + else: + self.t_step_input = int(self.diffuser_conf.T) + + @property + def T(self): + ''' + Return the maximum number of timesteps + that this design protocol will perform. + + Output: + T (int): The maximum number of timesteps to perform + ''' + return self.diffuser_conf.T + + def load_checkpoint(self) -> None: + """Loads RF checkpoint, from which config can be generated.""" + self._log.info(f'Reading checkpoint from {self.ckpt_path}') + print('This is inf_conf.ckpt_path') + print(self.ckpt_path) + self.ckpt = torch.load( + self.ckpt_path, map_location=self.device) + + def assemble_config_from_chk(self) -> None: + """ + Function for loading model config from checkpoint directly. + + Takes: + - config file + + Actions: + - Replaces all -model and -diffuser items + - Throws a warning if there are items in -model and -diffuser that aren't in the checkpoint + + This throws an error if there is a flag in the checkpoint 'config_dict' that isn't in the inference config. + This should ensure that whenever a feature is added in the training setup, it is accounted for in the inference script. + + """ + # get overrides to re-apply after building the config from the checkpoint + overrides = [] + if HydraConfig.initialized(): + overrides = HydraConfig.get().overrides.task + print("Assembling -model, -diffuser and -preprocess configs from checkpoint") + + for cat in ['model','diffuser','preprocess']: + for key in self._conf[cat]: + try: + print(f"USING MODEL CONFIG: self._conf[{cat}][{key}] = {self.ckpt['config_dict'][cat][key]}") + self._conf[cat][key] = self.ckpt['config_dict'][cat][key] + except: + pass + + # add overrides back in again + for override in overrides: + if override.split(".")[0] in ['model','diffuser','preprocess']: + print(f'WARNING: You are changing {override.split("=")[0]} from the value this model was trained with. Are you sure you know what you are doing?') + mytype = type(self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]]) + self._conf[override.split(".")[0]][override.split(".")[1].split("=")[0]] = mytype(override.split("=")[1]) + + def load_model(self): + """Create RosettaFold model from preloaded checkpoint.""" + + # Read input dimensions from checkpoint. + self.d_t1d=self._conf.preprocess.d_t1d + self.d_t2d=self._conf.preprocess.d_t2d + model = RoseTTAFoldModule(**self._conf.model, d_t1d=self.d_t1d, d_t2d=self.d_t2d, T=self._conf.diffuser.T).to(self.device) + if self._conf.logging.inputs: + pickle_dir = pickle_function_call(model, 'forward', 'inference') + print(f'pickle_dir: {pickle_dir}') + model = model.eval() + self._log.info(f'Loading checkpoint.') + model.load_state_dict(self.ckpt['model_state_dict'], strict=True) + return model + + def construct_contig(self, target_feats): + """ + Construct contig class describing the protein to be generated + """ + self._log.info(f'Using contig: {self.contig_conf.contigs}') + return ContigMap(target_feats, **self.contig_conf) + + def construct_denoiser(self, L, visible): + """Make length-specific denoiser.""" + denoise_kwargs = OmegaConf.to_container(self.diffuser_conf) + denoise_kwargs.update(OmegaConf.to_container(self.denoiser_conf)) + denoise_kwargs.update({ + 'L': L, + 'diffuser': self.diffuser, + 'potential_manager': self.potential_manager, + }) + return iu.Denoise(**denoise_kwargs) + + def sample_init(self, return_forward_trajectory=False): + """ + Initial features to start the sampling process. + + Modify signature and function body for different initialization + based on the config. + + Returns: + xt: Starting positions with a portion of them randomly sampled. + seq_t: Starting sequence with a portion of them set to unknown. + """ + + ####################### + ### Parse input pdb ### + ####################### + + self.target_feats = iu.process_target(self.inf_conf.input_pdb, parse_hetatom=True, center=False) + + ################################ + ### Generate specific contig ### + ################################ + + # Generate a specific contig from the range of possibilities specified at input + + self.contig_map = self.construct_contig(self.target_feats) + self.mappings = self.contig_map.get_mappings() + self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:] + self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:] + self.binderlen = len(self.contig_map.inpaint) + + #################### + ### Get Hotspots ### + #################### + + self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen) + + + ##################################### + ### Initialise Potentials Manager ### + ##################################### + + self.potential_manager = PotentialManager(self.potential_conf, + self.ppi_conf, + self.diffuser_conf, + self.inf_conf, + self.hotspot_0idx, + self.binderlen) + + ################################### + ### Initialize other attributes ### + ################################### + + xyz_27 = self.target_feats['xyz_27'] + mask_27 = self.target_feats['mask_27'] + seq_orig = self.target_feats['seq'].long() + L_mapped = len(self.contig_map.ref) + contig_map=self.contig_map + + self.diffusion_mask = self.mask_str + self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(L_mapped)] + + #################################### + ### Generate initial coordinates ### + #################################### + + if self.diffuser_conf.partial_T: + assert xyz_27.shape[0] == L_mapped, f"there must be a coordinate in the input PDB for \ + each residue implied by the contig string for partial diffusion. length of \ + input PDB != length of contig string: {xyz_27.shape[0]} != {L_mapped}" + assert contig_map.hal_idx0 == contig_map.ref_idx0, f'for partial diffusion there can \ + be no offset between the index of a residue in the input and the index of the \ + residue in the output, {contig_map.hal_idx0} != {contig_map.ref_idx0}' + # Partially diffusing from a known structure + xyz_mapped=xyz_27 + atom_mask_mapped = mask_27 + else: + # Fully diffusing from points initialised at the origin + # adjust size of input xt according to residue map + xyz_mapped = torch.full((1,1,L_mapped,27,3), np.nan) + xyz_mapped[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...] + xyz_motif_prealign = xyz_mapped.clone() + motif_prealign_com = xyz_motif_prealign[0,0,:,1].mean(dim=0) + self.motif_com = xyz_27[contig_map.ref_idx0,1].mean(dim=0) + xyz_mapped = get_init_xyz(xyz_mapped).squeeze() + # adjust the size of the input atom map + atom_mask_mapped = torch.full((L_mapped, 27), False) + atom_mask_mapped[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0] + + # Diffuse the contig-mapped coordinates + if self.diffuser_conf.partial_T: + assert self.diffuser_conf.partial_T <= self.diffuser_conf.T, "Partial_T must be less than T" + self.t_step_input = int(self.diffuser_conf.partial_T) + else: + self.t_step_input = int(self.diffuser_conf.T) + t_list = np.arange(1, self.t_step_input+1) + + ################################# + ### Generate initial sequence ### + ################################# + + seq_t = torch.full((1,L_mapped), 21).squeeze() # 21 is the mask token + seq_t[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0] + + # Unmask sequence if desired + if self._conf.contigmap.provide_seq is not None: + seq_t[self.mask_seq.squeeze()] = seq_orig[self.mask_seq.squeeze()] + + seq_t[~self.mask_seq.squeeze()] = 21 + seq_t = torch.nn.functional.one_hot(seq_t, num_classes=22).float() # [L,22] + seq_orig = torch.nn.functional.one_hot(seq_orig, num_classes=22).float() # [L,22] + + fa_stack, xyz_true = self.diffuser.diffuse_pose( + xyz_mapped, + torch.clone(seq_t), + atom_mask_mapped.squeeze(), + diffusion_mask=self.diffusion_mask.squeeze(), + t_list=t_list) + xT = fa_stack[-1].squeeze()[:,:14,:] + xt = torch.clone(xT) + + self.denoiser = self.construct_denoiser(len(self.contig_map.ref), visible=self.mask_seq.squeeze()) + + ###################### + ### Apply Symmetry ### + ###################### + + if self.symmetry is not None: + xt, seq_t = self.symmetry.apply_symmetry(xt, seq_t) + self._log.info(f'Sequence init: {seq2chars(torch.argmax(seq_t, dim=-1))}') + + self.msa_prev = None + self.pair_prev = None + self.state_prev = None + + ######################################### + ### Parse ligand for ligand potential ### + ######################################### + + if self.potential_conf.guiding_potentials is not None: + if any(list(filter(lambda x: "substrate_contacts" in x, self.potential_conf.guiding_potentials))): + assert len(self.target_feats['xyz_het']) > 0, "If you're using the Substrate Contact potential, \ + you need to make sure there's a ligand in the input_pdb file!" + het_names = np.array([i['name'].strip() for i in self.target_feats['info_het']]) + xyz_het = self.target_feats['xyz_het'][het_names == self._conf.potentials.substrate] + xyz_het = torch.from_numpy(xyz_het) + assert xyz_het.shape[0] > 0, f'expected >0 heteroatoms from ligand with name {self._conf.potentials.substrate}' + xyz_motif_prealign = xyz_motif_prealign[0,0][self.diffusion_mask.squeeze()] + motif_prealign_com = xyz_motif_prealign[:,1].mean(dim=0) + xyz_het_com = xyz_het.mean(dim=0) + for pot in self.potential_manager.potentials_to_apply: + pot.motif_substrate_atoms = xyz_het + pot.diffusion_mask = self.diffusion_mask.squeeze() + pot.xyz_motif = xyz_motif_prealign + pot.diffuser = self.diffuser + return xt, seq_t + + def _preprocess(self, seq, xyz_t, t, repack=False): + + """ + Function to prepare inputs to diffusion model + + seq (L,22) one-hot sequence + + msa_masked (1,1,L,48) + + msa_full (1,1,L,25) + + xyz_t (L,14,3) template crds (diffused) + + t1d (1,L,28) this is the t1d before tacking on the chi angles: + - seq + unknown/mask (21) + - global timestep (1-t/T if not motif else 1) (1) + + MODEL SPECIFIC: + - contacting residues: for ppi. Target residues in contact with binder (1) + - empty feature (legacy) (1) + - ss (H, E, L, MASK) (4) + + t2d (1, L, L, 45) + - last plane is block adjacency + """ + + L = seq.shape[0] + T = self.T + binderlen = self.binderlen + target_res = self.ppi_conf.hotspot_res + + ################## + ### msa_masked ### + ################## + msa_masked = torch.zeros((1,1,L,48)) + msa_masked[:,:,:,:22] = seq[None, None] + msa_masked[:,:,:,22:44] = seq[None, None] + msa_masked[:,:,0,46] = 1.0 + msa_masked[:,:,-1,47] = 1.0 + + ################ + ### msa_full ### + ################ + msa_full = torch.zeros((1,1,L,25)) + msa_full[:,:,:,:22] = seq[None, None] + msa_full[:,:,0,23] = 1.0 + msa_full[:,:,-1,24] = 1.0 + + ########### + ### t1d ### + ########### + + # Here we need to go from one hot with 22 classes to one hot with 21 classes (last plane is missing token) + t1d = torch.zeros((1,1,L,21)) + + seqt1d = torch.clone(seq) + for idx in range(L): + if seqt1d[idx,21] == 1: + seqt1d[idx,20] = 1 + seqt1d[idx,21] = 0 + + t1d[:,:,:,:21] = seqt1d[None,None,:,:21] + + + # Set timestep feature to 1 where diffusion mask is True, else 1-t/T + timefeature = torch.zeros((L)).float() + timefeature[self.mask_str.squeeze()] = 1 + timefeature[~self.mask_str.squeeze()] = 1 - t/self.T + timefeature = timefeature[None,None,...,None] + + t1d = torch.cat((t1d, timefeature), dim=-1).float() + + ############# + ### xyz_t ### + ############# + if self.preprocess_conf.sidechain_input: + xyz_t[torch.where(seq == 21, True, False),3:,:] = float('nan') + else: + xyz_t[~self.mask_str.squeeze(),3:,:] = float('nan') + + xyz_t=xyz_t[None, None] + xyz_t = torch.cat((xyz_t, torch.full((1,1,L,13,3), float('nan'))), dim=3) + + ########### + ### t2d ### + ########### + t2d = xyz_to_t2d(xyz_t) + + ########### + ### idx ### + ########### + idx = torch.tensor(self.contig_map.rf)[None] + + ############### + ### alpha_t ### + ############### + seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L) + alpha, _, alpha_mask, _ = util.get_torsions(xyz_t.reshape(-1, L, 27, 3), seq_tmp, TOR_INDICES, TOR_CAN_FLIP, REF_ANGLES) + alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) + alpha[torch.isnan(alpha)] = 0.0 + alpha = alpha.reshape(1,-1,L,10,2) + alpha_mask = alpha_mask.reshape(1,-1,L,10,1) + alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, L, 30) + + #put tensors on device + msa_masked = msa_masked.to(self.device) + msa_full = msa_full.to(self.device) + seq = seq.to(self.device) + xyz_t = xyz_t.to(self.device) + idx = idx.to(self.device) + t1d = t1d.to(self.device) + t2d = t2d.to(self.device) + alpha_t = alpha_t.to(self.device) + + ###################### + ### added_features ### + ###################### + if self.preprocess_conf.d_t1d >= 24: # add hotspot residues + hotspot_tens = torch.zeros(L).float() + if self.ppi_conf.hotspot_res is None: + print("WARNING: you're using a model trained on complexes and hotspot residues, without specifying hotspots.\ + If you're doing monomer diffusion this is fine") + hotspot_idx=[] + else: + hotspots = [(i[0],int(i[1:])) for i in self.ppi_conf.hotspot_res] + hotspot_idx=[] + for i,res in enumerate(self.contig_map.con_ref_pdb_idx): + if res in hotspots: + hotspot_idx.append(self.contig_map.hal_idx0[i]) + hotspot_tens[hotspot_idx] = 1.0 + + # Add blank (legacy) feature and hotspot tensor + t1d=torch.cat((t1d, torch.zeros_like(t1d[...,:1]), hotspot_tens[None,None,...,None].to(self.device)), dim=-1) + + return msa_masked, msa_full, seq[None], torch.squeeze(xyz_t, dim=0), idx, t1d, t2d, xyz_t, alpha_t + + def sample_step(self, *, t, x_t, seq_init, final_step): + '''Generate the next pose that the model should be supplied at timestep t-1. + + Args: + t (int): The timestep that has just been predicted + seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep + x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep + seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence. + + Returns: + px0: (L,14,3) The model's prediction of x0. + x_t_1: (L,14,3) The updated positions of the next step. + seq_t_1: (L,22) The updated sequence of the next step. + tors_t_1: (L, ?) The updated torsion angles of the next step. + plddt: (L, 1) Predicted lDDT of x0. + ''' + msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess( + seq_init, x_t, t) + + N,L = msa_masked.shape[:2] + + if self.symmetry is not None: + idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb) + + msa_prev = None + pair_prev = None + state_prev = None + + with torch.no_grad(): + msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked, + msa_full, + seq_in, + xt_in, + idx_pdb, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev = msa_prev, + pair_prev = pair_prev, + state_prev = state_prev, + t=torch.tensor(t), + return_infer=True, + motif_mask=self.diffusion_mask.squeeze().to(self.device)) + + # prediction of X0 + _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) + px0 = px0.squeeze()[:,:14] + + ##################### + ### Get next pose ### + ##################### + + if t > final_step: + seq_t_1 = nn.one_hot(seq_init,num_classes=22).to(self.device) + x_t_1, px0 = self.denoiser.get_next_pose( + xt=x_t, + px0=px0, + t=t, + diffusion_mask=self.mask_str.squeeze(), + align_motif=self.inf_conf.align_motif + ) + else: + x_t_1 = torch.clone(px0).to(x_t.device) + seq_t_1 = torch.clone(seq_init) + px0 = px0.to(x_t.device) + + if self.symmetry is not None: + x_t_1, seq_t_1 = self.symmetry.apply_symmetry(x_t_1, seq_t_1) + + return px0, x_t_1, seq_t_1, plddt + + +class SelfConditioning(Sampler): + """ + Model Runner for self conditioning + pX0[t+1] is provided as a template input to the model at time t + """ + + def sample_step(self, *, t, x_t, seq_init, final_step): + ''' + Generate the next pose that the model should be supplied at timestep t-1. + Args: + t (int): The timestep that has just been predicted + seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep + x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep + seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence. + Returns: + px0: (L,14,3) The model's prediction of x0. + x_t_1: (L,14,3) The updated positions of the next step. + seq_t_1: (L) The sequence to the next step (== seq_init) + plddt: (L, 1) Predicted lDDT of x0. + ''' + + msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess( + seq_init, x_t, t) + B,N,L = xyz_t.shape[:3] + + ################################## + ######## Str Self Cond ########### + ################################## + if (t < self.diffuser.T) and (t != self.diffuser_conf.partial_T): + zeros = torch.zeros(B,1,L,24,3).float().to(xyz_t.device) + xyz_t = torch.cat((self.prev_pred.unsqueeze(1),zeros), dim=-2) # [B,T,L,27,3] + t2d_44 = xyz_to_t2d(xyz_t) # [B,T,L,L,44] + else: + xyz_t = torch.zeros_like(xyz_t) + t2d_44 = torch.zeros_like(t2d[...,:44]) + # No effect if t2d is only dim 44 + t2d[...,:44] = t2d_44 + + if self.symmetry is not None: + idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb) + + #################### + ### Forward Pass ### + #################### + + with torch.no_grad(): + msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked, + msa_full, + seq_in, + xt_in, + idx_pdb, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev = None, + pair_prev = None, + state_prev = None, + t=torch.tensor(t), + return_infer=True, + motif_mask=self.diffusion_mask.squeeze().to(self.device)) + + if self.symmetry is not None and self.inf_conf.symmetric_self_cond: + px0 = self.symmetrise_prev_pred(px0=px0,seq_in=seq_in, alpha=alpha)[:,:,:3] + + self.prev_pred = torch.clone(px0) + + # prediction of X0 + _, px0 = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) + px0 = px0.squeeze()[:,:14] + + ########################### + ### Generate Next Input ### + ########################### + + seq_t_1 = torch.clone(seq_init) + if t > final_step: + x_t_1, px0 = self.denoiser.get_next_pose( + xt=x_t, + px0=px0, + t=t, + diffusion_mask=self.mask_str.squeeze(), + align_motif=self.inf_conf.align_motif, + include_motif_sidechains=self.preprocess_conf.motif_sidechain_input + ) + self._log.info( + f'Timestep {t}, input to next step: { seq2chars(torch.argmax(seq_t_1, dim=-1).tolist())}') + else: + x_t_1 = torch.clone(px0).to(x_t.device) + px0 = px0.to(x_t.device) + + ###################### + ### Apply symmetry ### + ###################### + + if self.symmetry is not None: + x_t_1, seq_t_1 = self.symmetry.apply_symmetry(x_t_1, seq_t_1) + + return px0, x_t_1, seq_t_1, plddt + + def symmetrise_prev_pred(self, px0, seq_in, alpha): + """ + Method for symmetrising px0 output for self-conditioning + """ + _,px0_aa = self.allatom(torch.argmax(seq_in, dim=-1), px0, alpha) + px0_sym,_ = self.symmetry.apply_symmetry(px0_aa.to('cpu').squeeze()[:,:14], torch.argmax(seq_in, dim=-1).squeeze().to('cpu')) + px0_sym = px0_sym[None].to(self.device) + return px0_sym + +class ScaffoldedSampler(SelfConditioning): + """ + Model Runner for Scaffold-Constrained diffusion + """ + def __init__(self, conf: DictConfig): + """ + Initialize scaffolded sampler. + Two basic approaches here: + i) Given a block adjacency/secondary structure input, generate a fold (in the presence or absence of a target) + - This allows easy generation of binders or specific folds + - Allows simple expansion of an input, to sample different lengths + ii) Providing a contig input and corresponding block adjacency/secondary structure input + - This allows mixed motif scaffolding and fold-conditioning. + - Adjacency/secondary structure inputs must correspond exactly in length to the contig string + """ + super().__init__(conf) + # initialize BlockAdjacency sampling class + self.blockadjacency = iu.BlockAdjacency(conf, conf.inference.num_designs) + + ################################################# + ### Initialize target, if doing binder design ### + ################################################# + + if conf.scaffoldguided.target_pdb: + self.target = iu.Target(conf.scaffoldguided, conf.ppi.hotspot_res) + self.target_pdb = self.target.get_target() + if conf.scaffoldguided.target_ss is not None: + self.target_ss = torch.load(conf.scaffoldguided.target_ss).long() + self.target_ss = torch.nn.functional.one_hot(self.target_ss, num_classes=4) + if self._conf.scaffoldguided.contig_crop is not None: + self.target_ss=self.target_ss[self.target_pdb['crop_mask']] + if conf.scaffoldguided.target_adj is not None: + self.target_adj = torch.load(conf.scaffoldguided.target_adj).long() + self.target_adj=torch.nn.functional.one_hot(self.target_adj, num_classes=3) + if self._conf.scaffoldguided.contig_crop is not None: + self.target_adj=self.target_adj[self.target_pdb['crop_mask']] + self.target_adj=self.target_adj[:,self.target_pdb['crop_mask']] + else: + self.target = None + self.target_pdb=False + + def sample_init(self): + """ + Wrapper method for taking secondary structure + adj, and outputting xt, seq_t + """ + + ########################## + ### Process Fold Input ### + ########################## + self.L, self.ss, self.adj = self.blockadjacency.get_scaffold() + self.adj = nn.one_hot(self.adj.long(), num_classes=3) + + ############################## + ### Auto-contig generation ### + ############################## + + if self.contig_conf.contigs is None: + # process target + xT = torch.full((self.L, 27,3), np.nan) + xT = get_init_xyz(xT[None,None]).squeeze() + seq_T = torch.full((self.L,),21) + self.diffusion_mask = torch.full((self.L,),False) + atom_mask = torch.full((self.L,27), False) + self.binderlen=self.L + + if self.target: + target_L = np.shape(self.target_pdb['xyz'])[0] + # xyz + target_xyz = torch.full((target_L, 27, 3), np.nan) + target_xyz[:,:14,:] = torch.from_numpy(self.target_pdb['xyz']) + xT = torch.cat((xT, target_xyz), dim=0) + # seq + seq_T = torch.cat((seq_T, torch.from_numpy(self.target_pdb['seq'])), dim=0) + # diffusion mask + self.diffusion_mask = torch.cat((self.diffusion_mask, torch.full((target_L,), True)),dim=0) + # atom mask + mask_27 = torch.full((target_L, 27), False) + mask_27[:,:14] = torch.from_numpy(self.target_pdb['mask']) + atom_mask = torch.cat((atom_mask, mask_27), dim=0) + self.L += target_L + # generate contigmap object + contig = [] + for idx,i in enumerate(self.target_pdb['pdb_idx'][:-1]): + if idx==0: + start=i[1] + if i[1] + 1 != self.target_pdb['pdb_idx'][idx+1][1] or i[0] != self.target_pdb['pdb_idx'][idx+1][0]: + contig.append(f'{i[0]}{start}-{i[1]}/0 ') + start = self.target_pdb['pdb_idx'][idx+1][1] + contig.append(f"{self.target_pdb['pdb_idx'][-1][0]}{start}-{self.target_pdb['pdb_idx'][-1][1]}/0 ") + contig.append(f"{self.binderlen}-{self.binderlen}") + contig = ["".join(contig)] + else: + contig = [f"{self.binderlen}-{self.binderlen}"] + self.contig_map=ContigMap(self.target_pdb, contig) + self.mappings = self.contig_map.get_mappings() + self.mask_seq = self.diffusion_mask + self.mask_str = self.diffusion_mask + L_mapped=len(self.contig_map.ref) + + ############################ + ### Specific Contig mode ### + ############################ + + else: + # get contigmap from command line + assert self.target is None, "Giving a target is the wrong way of handling this is you're doing contigs and secondary structure" + + # process target and reinitialise potential_manager. This is here because the 'target' is always set up to be the second chain in out inputs. + self.target_feats = iu.process_target(self.inf_conf.input_pdb) + self.contig_map = self.construct_contig(self.target_feats) + self.mappings = self.contig_map.get_mappings() + self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:] + self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:] + self.binderlen = len(self.contig_map.inpaint) + target_feats = self.target_feats + contig_map = self.contig_map + + xyz_27 = target_feats['xyz_27'] + mask_27 = target_feats['mask_27'] + seq_orig = target_feats['seq'] + L_mapped = len(self.contig_map.ref) + seq_T=torch.full((L_mapped,),21) + seq_T[contig_map.hal_idx0] = seq_orig[contig_map.ref_idx0] + seq_T[~self.mask_seq.squeeze()] = 21 + assert L_mapped==self.adj.shape[0] + diffusion_mask = self.mask_str + self.diffusion_mask = diffusion_mask + + xT = torch.full((1,1,L_mapped,27,3), np.nan) + xT[:, :, contig_map.hal_idx0, ...] = xyz_27[contig_map.ref_idx0,...] + xT = get_init_xyz(xT).squeeze() + atom_mask = torch.full((L_mapped, 27), False) + atom_mask[contig_map.hal_idx0] = mask_27[contig_map.ref_idx0] + + #################### + ### Get hotspots ### + #################### + self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen) + + ######################### + ### Set up potentials ### + ######################### + + self.potential_manager = PotentialManager(self.potential_conf, + self.ppi_conf, + self.diffuser_conf, + self.inf_conf, + self.hotspot_0idx, + self.binderlen) + + self.chain_idx=['A' if i < self.binderlen else 'B' for i in range(self.L)] + + ######################## + ### Handle Partial T ### + ######################## + + if self.diffuser_conf.partial_T: + assert self.diffuser_conf.partial_T <= self.diffuser_conf.T + self.t_step_input = int(self.diffuser_conf.partial_T) + else: + self.t_step_input = int(self.diffuser_conf.T) + t_list = np.arange(1, self.t_step_input+1) + seq_T=torch.nn.functional.one_hot(seq_T, num_classes=22).float() + + fa_stack, xyz_true = self.diffuser.diffuse_pose( + xT, + torch.clone(seq_T), + atom_mask.squeeze(), + diffusion_mask=self.diffusion_mask.squeeze(), + t_list=t_list, + include_motif_sidechains=self.preprocess_conf.motif_sidechain_input) + + ####################### + ### Set up Denoiser ### + ####################### + + self.denoiser = self.construct_denoiser(self.L, visible=self.mask_seq.squeeze()) + + + xT = torch.clone(fa_stack[-1].squeeze()[:,:14,:]) + return xT, seq_T + + def _preprocess(self, seq, xyz_t, t): + msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t = super()._preprocess(seq, xyz_t, t, repack=False) + + ################################### + ### Add Adj/Secondary Structure ### + ################################### + + assert self.preprocess_conf.d_t1d == 28, "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features" + assert self.preprocess_conf.d_t2d == 47, "The checkpoint you're using hasn't been trained with sec-struc/block adjacency features" + + ##################### + ### Handle Target ### + ##################### + + if self.target: + blank_ss = torch.nn.functional.one_hot(torch.full((self.L-self.binderlen,), 3), num_classes=4) + full_ss = torch.cat((self.ss, blank_ss), dim=0) + if self._conf.scaffoldguided.target_ss is not None: + full_ss[self.binderlen:] = self.target_ss + else: + full_ss = self.ss + t1d=torch.cat((t1d, full_ss[None,None].to(self.device)), dim=-1) + + t1d = t1d.float() + + ########### + ### t2d ### + ########### + + if self.d_t2d == 47: + if self.target: + full_adj = torch.zeros((self.L, self.L, 3)) + full_adj[:,:,-1] = 1. #set to mask + full_adj[:self.binderlen, :self.binderlen] = self.adj + if self._conf.scaffoldguided.target_adj is not None: + full_adj[self.binderlen:,self.binderlen:] = self.target_adj + else: + full_adj = self.adj + t2d=torch.cat((t2d, full_adj[None,None].to(self.device)),dim=-1) + + ########### + ### idx ### + ########### + + if self.target: + idx_pdb[:,self.binderlen:] += 200 + + return msa_masked, msa_full, seq, xyz_prev, idx_pdb, t1d, t2d, xyz_t, alpha_t diff --git a/rfdiffusion/inference/sym_rots.npz b/rfdiffusion/inference/sym_rots.npz new file mode 100644 index 0000000000000000000000000000000000000000..18a7f032d6a6ba323b9c2ff8ea3e6e9795b9094e --- /dev/null +++ b/rfdiffusion/inference/sym_rots.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67c7ad09f2c09e465ecb6b785ea92de92d79c99ef7381a60bcfca1b9b1c8d4f2 +size 7694 diff --git a/rfdiffusion/inference/symmetry.py b/rfdiffusion/inference/symmetry.py new file mode 100644 index 0000000000000000000000000000000000000000..864a5abe73aed33a4d68545d571bf48aa3bf606b --- /dev/null +++ b/rfdiffusion/inference/symmetry.py @@ -0,0 +1,236 @@ +"""Helper class for handle symmetric assemblies.""" +from pyrsistent import v +from scipy.spatial.transform import Rotation +import functools as fn +import torch +import string +import logging +import numpy as np +import pathlib + +format_rots = lambda r: torch.tensor(r).float() + +T3_ROTATIONS = [ + torch.Tensor([ + [ 1., 0., 0.], + [ 0., 1., 0.], + [ 0., 0., 1.]]).float(), + torch.Tensor([ + [-1., -0., 0.], + [-0., 1., 0.], + [-0., 0., -1.]]).float(), + torch.Tensor([ + [-1., 0., 0.], + [ 0., -1., 0.], + [ 0., 0., 1.]]).float(), + torch.Tensor([ + [ 1., 0., 0.], + [ 0., -1., 0.], + [ 0., 0., -1.]]).float(), +] + +saved_symmetries = ['tetrahedral', 'octahedral', 'icosahedral'] + +class SymGen: + + def __init__(self, global_sym, recenter, radius, model_only_neighbors=False): + self._log = logging.getLogger(__name__) + self._recenter = recenter + self._radius = radius + + if global_sym.lower().startswith('c'): + # Cyclic symmetry + if not global_sym[1:].isdigit(): + raise ValueError(f'Invalid cyclic symmetry {global_sym}') + self._log.info( + f'Initializing cyclic symmetry order {global_sym[1:]}.') + self._init_cyclic(int(global_sym[1:])) + self.apply_symmetry = self._apply_cyclic + + elif global_sym.lower().startswith('d'): + # Dihedral symmetry + if not global_sym[1:].isdigit(): + raise ValueError(f'Invalid dihedral symmetry {global_sym}') + self._log.info( + f'Initializing dihedral symmetry order {global_sym[1:]}.') + self._init_dihedral(int(global_sym[1:])) + # Applied the same way as cyclic symmetry + self.apply_symmetry = self._apply_cyclic + + elif global_sym.lower() == 't3': + # Tetrahedral (T3) symmetry + self._log.info('Initializing T3 symmetry order.') + self.sym_rots = T3_ROTATIONS + self.order = 4 + # Applied the same way as cyclic symmetry + self.apply_symmetry = self._apply_cyclic + + elif global_sym == 'octahedral': + # Octahedral symmetry + self._log.info( + 'Initializing octahedral symmetry.') + self._init_octahedral() + self.apply_symmetry = self._apply_octahedral + + elif global_sym.lower() in saved_symmetries: + # Using a saved symmetry + self._log.info('Initializing %s symmetry order.'%global_sym) + self._init_from_symrots_file(global_sym) + + # Applied the same way as cyclic symmetry + self.apply_symmetry = self._apply_cyclic + else: + raise ValueError(f'Unrecognized symmetry {global_sym}') + + self.res_idx_procesing = fn.partial( + self._lin_chainbreaks, num_breaks=self.order) + + ##################### + ## Cyclic symmetry ## + ##################### + def _init_cyclic(self, order): + sym_rots = [] + for i in range(order): + deg = i * 360.0 / order + r = Rotation.from_euler('z', deg, degrees=True) + sym_rots.append(format_rots(r.as_matrix())) + self.sym_rots = sym_rots + self.order = order + + def _apply_cyclic(self, coords_in, seq_in): + coords_out = torch.clone(coords_in) + seq_out = torch.clone(seq_in) + if seq_out.shape[0] % self.order != 0: + raise ValueError( + f'Sequence length must be divisble by {self.order}') + subunit_len = seq_out.shape[0] // self.order + for i in range(self.order): + start_i = subunit_len * i + end_i = subunit_len * (i+1) + coords_out[start_i:end_i] = torch.einsum( + 'bnj,kj->bnk', coords_out[:subunit_len], self.sym_rots[i]) + seq_out[start_i:end_i] = seq_out[:subunit_len] + return coords_out, seq_out + + def _lin_chainbreaks(self, num_breaks, res_idx, offset=None): + assert res_idx.ndim == 2 + res_idx = torch.clone(res_idx) + subunit_len = res_idx.shape[-1] // num_breaks + chain_delimiters = [] + if offset is None: + offset = res_idx.shape[-1] + for i in range(num_breaks): + start_i = subunit_len * i + end_i = subunit_len * (i+1) + chain_labels = list(string.ascii_uppercase) + [str(i+j) for i in + string.ascii_uppercase for j in string.ascii_uppercase] + chain_delimiters.extend( + [chain_labels[i] for _ in range(subunit_len)] + ) + res_idx[:, start_i:end_i] = res_idx[:, start_i:end_i] + offset * (i+1) + return res_idx, chain_delimiters + + ####################### + ## Dihedral symmetry ## + ####################### + def _init_dihedral(self, order): + sym_rots = [] + flip = Rotation.from_euler('x', 180, degrees=True).as_matrix() + for i in range(order): + deg = i * 360.0 / order + rot = Rotation.from_euler('z', deg, degrees=True).as_matrix() + sym_rots.append(format_rots(rot)) + rot2 = flip @ rot + sym_rots.append(format_rots(rot2)) + self.sym_rots = sym_rots + self.order = order * 2 + + ######################### + ## Octahedral symmetry ## + ######################### + def _init_octahedral(self): + sym_rots = np.load(f"{pathlib.Path(__file__).parent.resolve()}/sym_rots.npz") + self.sym_rots = [ + torch.tensor(v_i, dtype=torch.float32) + for v_i in sym_rots['octahedral'] + ] + self.order = len(self.sym_rots) + + def _apply_octahedral(self, coords_in, seq_in): + coords_out = torch.clone(coords_in) + seq_out = torch.clone(seq_in) + if seq_out.shape[0] % self.order != 0: + raise ValueError( + f'Sequence length must be divisble by {self.order}') + subunit_len = seq_out.shape[0] // self.order + base_axis = torch.tensor([self._radius, 0., 0.])[None] + for i in range(self.order): + start_i = subunit_len * i + end_i = subunit_len * (i+1) + subunit_chain = torch.einsum( + 'bnj,kj->bnk', coords_in[:subunit_len], self.sym_rots[i]) + + if self._recenter: + center = torch.mean(subunit_chain[:, 1, :], axis=0) + subunit_chain -= center[None, None, :] + rotated_axis = torch.einsum( + 'nj,kj->nk', base_axis, self.sym_rots[i]) + subunit_chain += rotated_axis[:, None, :] + + coords_out[start_i:end_i] = subunit_chain + seq_out[start_i:end_i] = seq_out[:subunit_len] + return coords_out, seq_out + + ####################### + ## symmetry from file # + ####################### + def _init_from_symrots_file(self, name): + """ _init_from_symrots_file initializes using + ./inference/sym_rots.npz + + Args: + name: name of symmetry (of tetrahedral, octahedral, icosahedral) + + sets self.sym_rots to be a list of torch.tensor of shape [3, 3] + """ + assert name in saved_symmetries, name + " not in " + str(saved_symmetries) + + # Load in list of rotation matrices for `name` + fn = f"{pathlib.Path(__file__).parent.resolve()}/sym_rots.npz" + obj = np.load(fn) + symms = None + for k, v in obj.items(): + if str(k) == name: symms = v + assert symms is not None, "%s not found in %s"%(name, fn) + + + self.sym_rots = [torch.tensor(v_i, dtype=torch.float32) for v_i in symms] + self.order = len(self.sym_rots) + + # Return if identity is the first rotation + if not np.isclose(((self.sym_rots[0]-np.eye(3))**2).sum(), 0): + + # Move identity to be the first rotation + for i, rot in enumerate(self.sym_rots): + if np.isclose(((rot-np.eye(3))**2).sum(), 0): + self.sym_rots = [self.sym_rots.pop(i)] + self.sym_rots + + assert len(self.sym_rots) == self.order + assert np.isclose(((self.sym_rots[0]-np.eye(3))**2).sum(), 0) + + def close_neighbors(self): + """close_neighbors finds the rotations within self.sym_rots that + correspond to close neighbors. + + Returns: + list of rotation matrices corresponding to the identity and close neighbors + """ + # set of small rotation angle rotations + rel_rot = lambda M: np.linalg.norm(Rotation.from_matrix(M).as_rotvec()) + rel_rots = [(i+1, rel_rot(M)) for i, M in enumerate(self.sym_rots[1:])] + min_rot = min(rel_rot_val[1] for rel_rot_val in rel_rots) + close_rots = [np.eye(3)] + [ + self.sym_rots[i] for i, rel_rot_val in rel_rots if + np.isclose(rel_rot_val, min_rot) + ] + return close_rots diff --git a/rfdiffusion/inference/utils.py b/rfdiffusion/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..43e7e99b7e6bd272690d6a97b79c5ce77ca9606b --- /dev/null +++ b/rfdiffusion/inference/utils.py @@ -0,0 +1,1003 @@ +import numpy as np +import os +from omegaconf import DictConfig +import torch +import torch.nn.functional as nn +from rfdiffusion.diffusion import get_beta_schedule +from scipy.spatial.transform import Rotation as scipy_R +from rfdiffusion.util import rigid_from_3_points +from rfdiffusion.util_module import ComputeAllAtomCoords +from rfdiffusion import util +import random +import logging +from rfdiffusion.inference import model_runners +import glob + +########################################################### +#### Functions which can be called outside of Denoiser #### +########################################################### + + +def get_next_frames(xt, px0, t, diffuser, so3_type, diffusion_mask, noise_scale=1.0): + """ + get_next_frames gets updated frames using IGSO(3) + score_based reverse diffusion. + + + based on self.so3_type use score based update. + + Generate frames at t-1 + Rather than generating random rotations (as occurs during forward process), calculate rotation between xt and px0 + + Args: + xt: noised coordinates of shape [L, 14, 3] + px0: prediction of coordinates at t=0, of shape [L, 14, 3] + t: integer time step + diffuser: Diffuser object for reverse igSO3 sampling + so3_type: The type of SO3 noising being used ('igso3') + diffusion_mask: of shape [L] of type bool, True means not to be + updated (e.g. mask is true for motif residues) + noise_scale: scale factor for the noise added (IGSO3 only) + + Returns: + backbone coordinates for step x_t-1 of shape [L, 3, 3] + """ + N_0 = px0[None, :, 0, :] + Ca_0 = px0[None, :, 1, :] + C_0 = px0[None, :, 2, :] + + R_0, Ca_0 = rigid_from_3_points(N_0, Ca_0, C_0) + + N_t = xt[None, :, 0, :] + Ca_t = xt[None, :, 1, :] + C_t = xt[None, :, 2, :] + + R_t, Ca_t = rigid_from_3_points(N_t, Ca_t, C_t) + + # this must be to normalize them or something + R_0 = scipy_R.from_matrix(R_0.squeeze().numpy()).as_matrix() + R_t = scipy_R.from_matrix(R_t.squeeze().numpy()).as_matrix() + + L = R_t.shape[0] + all_rot_transitions = np.broadcast_to(np.identity(3), (L, 3, 3)).copy() + # Sample next frame for each residue + if so3_type == "igso3": + # don't do calculations on masked positions since they end up as identity matrix + all_rot_transitions[ + ~diffusion_mask + ] = diffuser.so3_diffuser.reverse_sample_vectorized( + R_t[~diffusion_mask], + R_0[~diffusion_mask], + t, + noise_level=noise_scale, + mask=None, + return_perturb=True, + ) + else: + assert False, "so3 diffusion type %s not implemented" % so3_type + + all_rot_transitions = all_rot_transitions[:, None, :, :] + + # Apply the interpolated rotation matrices to the coordinates + next_crds = ( + np.einsum( + "lrij,laj->lrai", + all_rot_transitions, + xt[:, :3, :] - Ca_t.squeeze()[:, None, ...].numpy(), + ) + + Ca_t.squeeze()[:, None, None, ...].numpy() + ) + + # (L,3,3) set of backbone coordinates with slight rotation + return next_crds.squeeze(1) + + +def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6): + """ + Given xt, predicted x0 and the timestep t, give mu of x(t-1) + Assumes t is 0 indexed + """ + # sigma is predefined from beta. Often referred to as beta tilde t + t_idx = t - 1 + sigma = ( + (1 - alphabar_schedule[t_idx - 1]) / (1 - alphabar_schedule[t_idx]) + ) * beta_schedule[t_idx] + + xt_ca = xt[:, 1, :] + px0_ca = px0[:, 1, :] + + a = ( + (torch.sqrt(alphabar_schedule[t_idx - 1] + eps) * beta_schedule[t_idx]) + / (1 - alphabar_schedule[t_idx]) + ) * px0_ca + b = ( + ( + torch.sqrt(1 - beta_schedule[t_idx] + eps) + * (1 - alphabar_schedule[t_idx - 1]) + ) + / (1 - alphabar_schedule[t_idx]) + ) * xt_ca + + mu = a + b + + return mu, sigma + + +def get_next_ca( + xt, + px0, + t, + diffusion_mask, + crd_scale, + beta_schedule, + alphabar_schedule, + noise_scale=1.0, +): + """ + Given full atom x0 prediction (xyz coordinates), diffuse to x(t-1) + + Parameters: + + xt (L, 14/27, 3) set of coordinates + + px0 (L, 14/27, 3) set of coordinates + + t: time step. Note this is zero-index current time step, so are generating t-1 + + logits_aa (L x 20 ) amino acid probabilities at each position + + seq_schedule (L): Tensor of bools, True is unmasked, False is masked. For this specific t + + diffusion_mask (torch.tensor, required): Tensor of bools, True means NOT diffused at this residue, False means diffused + + noise_scale: scale factor for the noise being added + + """ + get_allatom = ComputeAllAtomCoords().to(device=xt.device) + L = len(xt) + + # bring to origin after global alignment (when don't have a motif) or replace input motif and bring to origin, and then scale + px0 = px0 * crd_scale + xt = xt * crd_scale + + # get mu(xt, x0) + mu, sigma = get_mu_xt_x0( + xt, px0, t, beta_schedule=beta_schedule, alphabar_schedule=alphabar_schedule + ) + + sampled_crds = torch.normal(mu, torch.sqrt(sigma * noise_scale)) + delta = sampled_crds - xt[:, 1, :] # check sign of this is correct + + if not diffusion_mask is None: + # Don't move motif + delta[diffusion_mask, ...] = 0 + + out_crds = xt + delta[:, None, :] + + return out_crds / crd_scale, delta / crd_scale + + +def get_noise_schedule(T, noiseT, noise1, schedule_type): + """ + Function to create a schedule that varies the scale of noise given to the model over time + + Parameters: + + T: The total number of timesteps in the denoising trajectory + + noiseT: The inital (t=T) noise scale + + noise1: The final (t=1) noise scale + + schedule_type: The type of function to use to interpolate between noiseT and noise1 + + Returns: + + noise_schedule: A function which maps timestep to noise scale + + """ + + noise_schedules = { + "constant": lambda t: noiseT, + "linear": lambda t: ((t - 1) / (T - 1)) * (noiseT - noise1) + noise1, + } + + assert ( + schedule_type in noise_schedules + ), f"noise_schedule must be one of {noise_schedules.keys()}. Received noise_schedule={schedule_type}. Exiting." + + return noise_schedules[schedule_type] + + +class Denoise: + """ + Class for getting x(t-1) from predicted x0 and x(t) + Strategy: + Ca coordinates: Rediffuse to x(t-1) from predicted x0 + Frames: Approximate update from rotation score + Torsions: 1/t of the way to the x0 prediction + + """ + + def __init__( + self, + T, + L, + diffuser, + b_0=0.001, + b_T=0.1, + min_b=1.0, + max_b=12.5, + min_sigma=0.05, + max_sigma=1.5, + noise_level=0.5, + schedule_type="linear", + so3_schedule_type="linear", + schedule_kwargs={}, + so3_type="igso3", + noise_scale_ca=1.0, + final_noise_scale_ca=1, + ca_noise_schedule_type="constant", + noise_scale_frame=0.5, + final_noise_scale_frame=0.5, + frame_noise_schedule_type="constant", + crd_scale=1 / 15, + potential_manager=None, + partial_T=None, + ): + """ + + Parameters: + noise_level: scaling on the noise added (set to 0 to use no noise, + to 1 to have full noise) + + """ + self.T = T + self.L = L + self.diffuser = diffuser + self.b_0 = b_0 + self.b_T = b_T + self.noise_level = noise_level + self.schedule_type = schedule_type + self.so3_type = so3_type + self.crd_scale = crd_scale + self.noise_scale_ca = noise_scale_ca + self.final_noise_scale_ca = final_noise_scale_ca + self.ca_noise_schedule_type = ca_noise_schedule_type + self.noise_scale_frame = noise_scale_frame + self.final_noise_scale_frame = final_noise_scale_frame + self.frame_noise_schedule_type = frame_noise_schedule_type + self.potential_manager = potential_manager + self._log = logging.getLogger(__name__) + + self.schedule, self.alpha_schedule, self.alphabar_schedule = get_beta_schedule( + self.T, self.b_0, self.b_T, self.schedule_type, inference=True + ) + + self.noise_schedule_ca = get_noise_schedule( + self.T, + self.noise_scale_ca, + self.final_noise_scale_ca, + self.ca_noise_schedule_type, + ) + self.noise_schedule_frame = get_noise_schedule( + self.T, + self.noise_scale_frame, + self.final_noise_scale_frame, + self.frame_noise_schedule_type, + ) + + @property + def idx2steps(self): + return self.decode_scheduler.idx2steps.numpy() + + def align_to_xt_motif(self, px0, xT, diffusion_mask, eps=1e-6): + """ + Need to align px0 to motif in xT. This is to permit the swapping of residue positions in the px0 motif for the true coordinates. + First, get rotation matrix from px0 to xT for the motif residues. + Second, rotate px0 (whole structure) by that rotation matrix + Third, centre at origin + """ + + def rmsd(V, W, eps=0): + # First sum down atoms, then sum down xyz + N = V.shape[-2] + return np.sqrt(np.sum((V - W) * (V - W), axis=(-2, -1)) / N + eps) + + assert ( + xT.shape[1] == px0.shape[1] + ), f"xT has shape {xT.shape} and px0 has shape {px0.shape}" + + L, n_atom, _ = xT.shape # A is number of atoms + atom_mask = ~torch.isnan(px0) + # convert to numpy arrays + px0 = px0.cpu().detach().numpy() + xT = xT.cpu().detach().numpy() + diffusion_mask = diffusion_mask.cpu().detach().numpy() + + # 1 centre motifs at origin and get rotation matrix + px0_motif = px0[diffusion_mask, :3].reshape(-1, 3) + xT_motif = xT[diffusion_mask, :3].reshape(-1, 3) + px0_motif_mean = np.copy(px0_motif.mean(0)) # need later + xT_motif_mean = np.copy(xT_motif.mean(0)) + + # center at origin + px0_motif = px0_motif - px0_motif_mean + xT_motif = xT_motif - xT_motif_mean + + # A = px0_motif + # B = xT_motif + A = xT_motif + B = px0_motif + + C = np.matmul(A.T, B) + + # compute optimal rotation matrix using SVD + U, S, Vt = np.linalg.svd(C) + + # ensure right handed coordinate system + d = np.eye(3) + d[-1, -1] = np.sign(np.linalg.det(Vt.T @ U.T)) + + # construct rotation matrix + R = Vt.T @ d @ U.T + + # get rotated coords + rB = B @ R + + # calculate rmsd + rms = rmsd(A, rB) + self._log.info(f"Sampled motif RMSD: {rms:.2f}") + + # 2 rotate whole px0 by rotation matrix + atom_mask = atom_mask.cpu() + px0[~atom_mask] = 0 # convert nans to 0 + px0 = px0.reshape(-1, 3) - px0_motif_mean + px0_ = px0 @ R + + # 3 put in same global position as xT + px0_ = px0_ + xT_motif_mean + px0_ = px0_.reshape([L, n_atom, 3]) + px0_[~atom_mask] = float("nan") + return torch.Tensor(px0_) + + def get_potential_gradients(self, xyz, diffusion_mask): + """ + This could be moved into potential manager if desired - NRB + + Function to take a structure (x) and get per-atom gradients used to guide diffusion update + + Inputs: + + xyz (torch.tensor, required): [L,27,3] Coordinates at which the gradient will be computed + + Outputs: + + Ca_grads (torch.tensor): [L,3] The gradient at each Ca atom + """ + + if self.potential_manager == None or self.potential_manager.is_empty(): + return torch.zeros(xyz.shape[0], 3) + + use_Cb = False + + # seq.requires_grad = True + xyz.requires_grad = True + + if not xyz.grad is None: + xyz.grad.zero_() + + current_potential = self.potential_manager.compute_all_potentials(xyz) + current_potential.backward() + + # Since we are not moving frames, Cb grads are same as Ca grads + # Need access to calculated Cb coordinates to be able to get Cb grads though + Ca_grads = xyz.grad[:, 1, :] + + if not diffusion_mask == None: + Ca_grads[diffusion_mask, :] = 0 + + # check for NaN's + if torch.isnan(Ca_grads).any(): + print("WARNING: NaN in potential gradients, replacing with zero grad.") + Ca_grads[:] = 0 + + return Ca_grads + + def get_next_pose( + self, + xt, + px0, + t, + diffusion_mask, + fix_motif=True, + align_motif=True, + include_motif_sidechains=True, + ): + """ + Wrapper function to take px0, xt and t, and to produce xt-1 + First, aligns px0 to xt + Then gets coordinates, frames and torsion angles + + Parameters: + + xt (torch.tensor, required): Current coordinates at timestep t + + px0 (torch.tensor, required): Prediction of x0 + + t (int, required): timestep t + + diffusion_mask (torch.tensor, required): Mask for structure diffusion + + fix_motif (bool): Fix the motif structure + + align_motif (bool): Align the model's prediction of the motif to the input motif + + include_motif_sidechains (bool): Provide sidechains of the fixed motif to the model + """ + + get_allatom = ComputeAllAtomCoords().to(device=xt.device) + L, n_atom = xt.shape[:2] + assert (xt.shape[1] == 14) or (xt.shape[1] == 27) + assert (px0.shape[1] == 14) or (px0.shape[1] == 27) + + ############################### + ### Align pX0 onto Xt motif ### + ############################### + + if align_motif and diffusion_mask.any(): + px0 = self.align_to_xt_motif(px0, xt, diffusion_mask) + # xT_motif_aligned = self.align_to_xt_motif(px0, xt, diffusion_mask) + + px0 = px0.to(xt.device) + # Now done with diffusion mask. if fix motif is False, just set diffusion mask to be all True, and all coordinates can diffuse + if not fix_motif: + diffusion_mask[:] = False + + # get the next set of CA coordinates + noise_scale_ca = self.noise_schedule_ca(t) + _, ca_deltas = get_next_ca( + xt, + px0, + t, + diffusion_mask, + crd_scale=self.crd_scale, + beta_schedule=self.schedule, + alphabar_schedule=self.alphabar_schedule, + noise_scale=noise_scale_ca, + ) + + # get the next set of backbone frames (coordinates) + noise_scale_frame = self.noise_schedule_frame(t) + frames_next = get_next_frames( + xt, + px0, + t, + diffuser=self.diffuser, + so3_type=self.so3_type, + diffusion_mask=diffusion_mask, + noise_scale=noise_scale_frame, + ) + + # Apply gradient step from guiding potentials + # This can be moved to below where the full atom representation is calculated to allow for potentials involving sidechains + + grad_ca = self.get_potential_gradients( + xt.clone(), diffusion_mask=diffusion_mask + ) + + ca_deltas += self.potential_manager.get_guide_scale(t) * grad_ca + + # add the delta to the new frames + frames_next = torch.from_numpy(frames_next) + ca_deltas[:, None, :] # translate + + fullatom_next = torch.full_like(xt, float("nan")).unsqueeze(0) + fullatom_next[:, :, :3] = frames_next[None] + # This is never used so just make it a fudged tensor - NRB + torsions_next = torch.zeros(1, 1) + + if include_motif_sidechains: + fullatom_next[:, diffusion_mask, :14] = xt[None, diffusion_mask] + + return fullatom_next.squeeze()[:, :14, :], px0 + + +def sampler_selector(conf: DictConfig): + if conf.scaffoldguided.scaffoldguided: + sampler = model_runners.ScaffoldedSampler(conf) + else: + if conf.inference.model_runner == "default": + sampler = model_runners.Sampler(conf) + elif conf.inference.model_runner == "SelfConditioning": + sampler = model_runners.SelfConditioning(conf) + elif conf.inference.model_runner == "ScaffoldedSampler": + sampler = model_runners.ScaffoldedSampler(conf) + else: + raise ValueError(f"Unrecognized sampler {conf.model_runner}") + return sampler + + +def parse_pdb(filename, **kwargs): + """extract xyz coords for all heavy atoms""" + with open(filename,"r") as f: + lines=f.readlines() + return parse_pdb_lines(lines, **kwargs) + + +def parse_pdb_lines(lines, parse_hetatom=False, ignore_het_h=True): + # indices of residues observed in the structure + res, pdb_idx = [],[] + for l in lines: + if l[:4] == "ATOM" and l[12:16].strip() == "CA": + res.append((l[22:26], l[17:20])) + # chain letter, res num + pdb_idx.append((l[21:22].strip(), int(l[22:26].strip()))) + seq = [util.aa2num[r[1]] if r[1] in util.aa2num.keys() else 20 for r in res] + pdb_idx = [ + (l[21:22].strip(), int(l[22:26].strip())) + for l in lines + if l[:4] == "ATOM" and l[12:16].strip() == "CA" + ] # chain letter, res num + + # 4 BB + up to 10 SC atoms + xyz = np.full((len(res), 14, 3), np.nan, dtype=np.float32) + for l in lines: + if l[:4] != "ATOM": + continue + chain, resNo, atom, aa = ( + l[21:22], + int(l[22:26]), + " " + l[12:16].strip().ljust(3), + l[17:20], + ) + if (chain,resNo) in pdb_idx: + idx = pdb_idx.index((chain, resNo)) + # for i_atm, tgtatm in enumerate(util.aa2long[util.aa2num[aa]]): + for i_atm, tgtatm in enumerate( + util.aa2long[util.aa2num[aa]][:14] + ): + if ( + tgtatm is not None and tgtatm.strip() == atom.strip() + ): # ignore whitespace + xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])] + break + + # save atom mask + mask = np.logical_not(np.isnan(xyz[..., 0])) + xyz[np.isnan(xyz[..., 0])] = 0.0 + + # remove duplicated (chain, resi) + new_idx = [] + i_unique = [] + for i, idx in enumerate(pdb_idx): + if idx not in new_idx: + new_idx.append(idx) + i_unique.append(i) + + pdb_idx = new_idx + xyz = xyz[i_unique] + mask = mask[i_unique] + + seq = np.array(seq)[i_unique] + + out = { + "xyz": xyz, # cartesian coordinates, [Lx14] + "mask": mask, # mask showing which atoms are present in the PDB file, [Lx14] + "idx": np.array( + [i[1] for i in pdb_idx] + ), # residue numbers in the PDB file, [L] + "seq": np.array(seq), # amino acid sequence, [L] + "pdb_idx": pdb_idx, # list of (chain letter, residue number) in the pdb file, [L] + } + + # heteroatoms (ligands, etc) + if parse_hetatom: + xyz_het, info_het = [], [] + for l in lines: + if l[:6] == "HETATM" and not (ignore_het_h and l[77] == "H"): + info_het.append( + dict( + idx=int(l[7:11]), + atom_id=l[12:16], + atom_type=l[77], + name=l[16:20], + ) + ) + xyz_het.append([float(l[30:38]), float(l[38:46]), float(l[46:54])]) + + out["xyz_het"] = np.array(xyz_het) + out["info_het"] = info_het + + return out + + +def process_target(pdb_path, parse_hetatom=False, center=True): + # Read target pdb and extract features. + target_struct = parse_pdb(pdb_path, parse_hetatom=parse_hetatom) + + # Zero-center positions + ca_center = target_struct["xyz"][:, :1, :].mean(axis=0, keepdims=True) + if not center: + ca_center = 0 + xyz = torch.from_numpy(target_struct["xyz"] - ca_center) + seq_orig = torch.from_numpy(target_struct["seq"]) + atom_mask = torch.from_numpy(target_struct["mask"]) + seq_len = len(xyz) + + # Make 27 atom representation + xyz_27 = torch.full((seq_len, 27, 3), np.nan).float() + xyz_27[:, :14, :] = xyz[:, :14, :] + + mask_27 = torch.full((seq_len, 27), False) + mask_27[:, :14] = atom_mask + out = { + "xyz_27": xyz_27, + "mask_27": mask_27, + "seq": seq_orig, + "pdb_idx": target_struct["pdb_idx"], + } + if parse_hetatom: + out["xyz_het"] = target_struct["xyz_het"] + out["info_het"] = target_struct["info_het"] + return out + + +def get_idx0_hotspots(mappings, ppi_conf, binderlen): + """ + Take pdb-indexed hotspot resudes and the length of the binder, and makes the 0-indexed tensor of hotspots + """ + + hotspot_idx = None + if binderlen > 0: + if ppi_conf.hotspot_res is not None: + assert all( + [i[0].isalpha() for i in ppi_conf.hotspot_res] + ), "Hotspot residues need to be provided in pdb-indexed form. E.g. A100,A103" + hotspots = [(i[0], int(i[1:])) for i in ppi_conf.hotspot_res] + hotspot_idx = [] + for i, res in enumerate(mappings["receptor_con_ref_pdb_idx"]): + if res in hotspots: + hotspot_idx.append(mappings["receptor_con_hal_idx0"][i]) + return hotspot_idx + + +class BlockAdjacency: + """ + Class for handling PPI design inference with ss/block_adj inputs. + Basic idea is to provide a list of scaffolds, and to output ss and adjacency + matrices based off of these, while sampling additional lengths. + Inputs: + - scaffold_list: list of scaffolds (e.g. ['2kl8','1cif']). Can also be a .txt file. + - scaffold dir: directory where scaffold ss and adj are precalculated + - sampled_insertion: how many additional residues do you want to add to each loop segment? Randomly sampled 0-this number (or within given range) + - sampled_N: randomly sample up to this number of additional residues at N-term + - sampled_C: randomly sample up to this number of additional residues at C-term + - ss_mask: how many residues do you want to mask at either end of a ss (H or E) block. Fixed value + - num_designs: how many designs are you wanting to generate? Currently only used for bookkeeping + - systematic: do you want to systematically work through the list of scaffolds, or randomly sample (default) + - num_designs_per_input: Not really implemented yet. Maybe not necessary + Outputs: + - L: new length of chain to be diffused + - ss: all loops and insertions, and ends of ss blocks (up to ss_mask) set to mask token (3). Onehot encoded. (L,4) + - adj: block adjacency with equivalent masking as ss (L,L) + """ + + def __init__(self, conf, num_designs): + """ + Parameters: + inputs: + conf.scaffold_list as conf + conf.inference.num_designs for sanity checking + """ + + self.conf=conf + # either list or path to .txt file with list of scaffolds + if self.conf.scaffoldguided.scaffold_list is not None: + if type(self.conf.scaffoldguided.scaffold_list) == list: + self.scaffold_list = scaffold_list + elif self.conf.scaffoldguided.scaffold_list[-4:] == ".txt": + # txt file with list of ids + list_from_file = [] + with open(self.conf.scaffoldguided.scaffold_list, "r") as f: + for line in f: + list_from_file.append(line.strip()) + self.scaffold_list = list_from_file + else: + raise NotImplementedError + else: + self.scaffold_list = [ + os.path.split(i)[1][:-6] + for i in glob.glob(f"{self.conf.scaffoldguided.scaffold_dir}/*_ss.pt") + ] + self.scaffold_list.sort() + + # path to directory with scaffolds, ss files and block_adjacency files + self.scaffold_dir = self.conf.scaffoldguided.scaffold_dir + + # maximum sampled insertion in each loop segment + if "-" in str(self.conf.scaffoldguided.sampled_insertion): + self.sampled_insertion = [ + int(str(self.conf.scaffoldguided.sampled_insertion).split("-")[0]), + int(str(self.conf.scaffoldguided.sampled_insertion).split("-")[1]), + ] + else: + self.sampled_insertion = [0, int(self.conf.scaffoldguided.sampled_insertion)] + + # maximum sampled insertion at N- and C-terminus + if "-" in str(self.conf.scaffoldguided.sampled_N): + self.sampled_N = [ + int(str(self.conf.scaffoldguided.sampled_N).split("-")[0]), + int(str(self.conf.scaffoldguided.sampled_N).split("-")[1]), + ] + else: + self.sampled_N = [0, int(self.conf.scaffoldguided.sampled_N)] + if "-" in str(self.conf.scaffoldguided.sampled_C): + self.sampled_C = [ + int(str(self.conf.scaffoldguided.sampled_C).split("-")[0]), + int(str(self.conf.scaffoldguided.sampled_C).split("-")[1]), + ] + else: + self.sampled_C = [0, int(self.conf.scaffoldguided.sampled_C)] + + # number of residues to mask ss identity of in H/E regions (from junction) + # e.g. if ss_mask = 2, L,L,L,H,H,H,H,H,H,H,L,L,E,E,E,E,E,E,L,L,L,L,L,L would become\ + # M,M,M,M,M,H,H,H,M,M,M,M,M,M,E,E,M,M,M,M,M,M,M,M where M is mask + self.ss_mask = self.conf.scaffoldguided.ss_mask + + # whether or not to work systematically through the list + self.systematic = self.conf.scaffoldguided.systematic + + self.num_designs = num_designs + + if len(self.scaffold_list) > self.num_designs: + print( + "WARNING: Scaffold set is bigger than num_designs, so not every scaffold type will be sampled" + ) + + # for tracking number of designs + self.num_completed = 0 + if self.systematic: + self.item_n = 0 + + # whether to mask loops or not + if not self.conf.scaffoldguided.mask_loops: + assert self.conf.scaffoldguided.sampled_N == 0, "can't add length if not masking loops" + assert self.conf.scaffoldguided.sampled_C == 0, "can't add lemgth if not masking loops" + assert self.conf.scaffoldguided.sampled_insertion == 0, "can't add length if not masking loops" + self.mask_loops = False + else: + self.mask_loops = True + + def get_ss_adj(self, item): + """ + Given at item, get the ss tensor and block adjacency matrix for that item + """ + ss = torch.load(os.path.join(self.scaffold_dir, f'{item.split(".")[0]}_ss.pt')) + adj = torch.load( + os.path.join(self.scaffold_dir, f'{item.split(".")[0]}_adj.pt') + ) + + return ss, adj + + def mask_to_segments(self, mask): + """ + Takes a mask of True (loop) and False (non-loop), and outputs list of tuples (loop or not, length of element) + """ + segments = [] + begin = -1 + end = -1 + for i in range(mask.shape[0]): + # Starting edge case + if i == 0: + begin = 0 + continue + + if not mask[i] == mask[i - 1]: + end = i + if mask[i - 1].item() is True: + segments.append(("loop", end - begin)) + else: + segments.append(("ss", end - begin)) + begin = i + + # Ending edge case: last segment is length one + if not end == mask.shape[0]: + if mask[i].item() is True: + segments.append(("loop", mask.shape[0] - begin)) + else: + segments.append(("ss", mask.shape[0] - begin)) + return segments + + def expand_mask(self, mask, segments): + """ + Function to generate a new mask with dilated loops and N and C terminal additions + """ + N_add = random.randint(self.sampled_N[0], self.sampled_N[1]) + C_add = random.randint(self.sampled_C[0], self.sampled_C[1]) + + output = N_add * [False] + for ss, length in segments: + if ss == "ss": + output.extend(length * [True]) + else: + # randomly sample insertion length + ins = random.randint( + self.sampled_insertion[0], self.sampled_insertion[1] + ) + output.extend((length + ins) * [False]) + output.extend(C_add * [False]) + assert torch.sum(torch.tensor(output)) == torch.sum(~mask) + return torch.tensor(output) + + def expand_ss(self, ss, adj, mask, expanded_mask): + """ + Given an expanded mask, populate a new ss and adj based on this + """ + ss_out = torch.ones(expanded_mask.shape[0]) * 3 # set to mask token + adj_out = torch.full((expanded_mask.shape[0], expanded_mask.shape[0]), 0.0) + ss_out[expanded_mask] = ss[~mask] + expanded_mask_2d = torch.full(adj_out.shape, True) + # mask out loops/insertions, which is ~expanded_mask + expanded_mask_2d[~expanded_mask, :] = False + expanded_mask_2d[:, ~expanded_mask] = False + + mask_2d = torch.full(adj.shape, True) + # mask out loops. This mask is True=loop + mask_2d[mask, :] = False + mask_2d[:, mask] = False + adj_out[expanded_mask_2d] = adj[mask_2d] + adj_out = adj_out.reshape((expanded_mask.shape[0], expanded_mask.shape[0])) + + return ss_out, adj_out + + def mask_ss_adj(self, ss, adj, expanded_mask): + """ + Given an expanded ss and adj, mask some number of residues at either end of non-loop ss + """ + original_mask = torch.clone(expanded_mask) + if self.ss_mask > 0: + for i in range(1, self.ss_mask + 1): + expanded_mask[i:] *= original_mask[:-i] + expanded_mask[:-i] *= original_mask[i:] + + if self.mask_loops: + ss[~expanded_mask] = 3 + adj[~expanded_mask, :] = 0 + adj[:, ~expanded_mask] = 0 + + # mask adjacency + adj[~expanded_mask] = 2 + adj[:, ~expanded_mask] = 2 + + return ss, adj + + def get_scaffold(self): + """ + Wrapper method for pulling an item from the list, and preparing ss and block adj features + """ + + # Handle determinism. Useful for integration tests + if self.conf.inference.deterministic: + torch.manual_seed(self.num_completed) + np.random.seed(self.num_completed) + random.seed(self.num_completed) + + if self.systematic: + # reset if num designs > num_scaffolds + if self.item_n >= len(self.scaffold_list): + self.item_n = 0 + item = self.scaffold_list[self.item_n] + self.item_n += 1 + else: + item = random.choice(self.scaffold_list) + print("Scaffold constrained based on file: ", item) + # load files + ss, adj = self.get_ss_adj(item) + adj_orig = torch.clone(adj) + # separate into segments (loop or not) + mask = torch.where(ss == 2, 1, 0).bool() + segments = self.mask_to_segments(mask) + + # insert into loops to generate new mask + expanded_mask = self.expand_mask(mask, segments) + + # expand ss and adj + ss, adj = self.expand_ss(ss, adj, mask, expanded_mask) + + # finally, mask some proportion of the ss at either end of the non-loop ss blocks + ss, adj = self.mask_ss_adj(ss, adj, expanded_mask) + + # and then update num_completed + self.num_completed += 1 + + return ss.shape[0], torch.nn.functional.one_hot(ss.long(), num_classes=4), adj + + +class Target: + """ + Class to handle targets (fixed chains). + Inputs: + - path to pdb file + - hotspot residues, in the form B10,B12,B60 etc + - whether or not to crop, and with which method + Outputs: + - Dictionary of xyz coordinates, indices, pdb_indices, pdb mask + """ + + def __init__(self, conf: DictConfig, hotspots=None): + self.pdb = parse_pdb(conf.target_path) + + if hotspots is not None: + self.hotspots = hotspots + else: + self.hotspots = [] + self.pdb["hotspots"] = np.array( + [ + True if f"{i[0]}{i[1]}" in self.hotspots else False + for i in self.pdb["pdb_idx"] + ] + ) + + if conf.contig_crop: + self.contig_crop(conf.contig_crop) + + def parse_contig(self, contig_crop): + """ + Takes contig input and parses + """ + contig_list = [] + for contig in contig_crop[0].split(" "): + subcon = [] + for crop in contig.split("/"): + if crop[0].isalpha(): + subcon.extend( + [ + (crop[0], p) + for p in np.arange( + int(crop.split("-")[0][1:]), int(crop.split("-")[1]) + 1 + ) + ] + ) + contig_list.append(subcon) + return contig_list + + def contig_crop(self, contig_crop, residue_offset=200) -> None: + """ + Method to take a contig string referring to the receptor and output a pdb dictionary with just this crop + NB there are two ways to provide inputs: + - 1) e.g. B1-30,0 B50-60,0. This will add a residue offset between each chunk + - 2) e.g. B1-30,B50-60,B80-100. This will keep the original indexing of the pdb file. + Can handle the target being on multiple chains + """ + + # add residue offset between chains if multiple chains in receptor file + for idx, val in enumerate(self.pdb["pdb_idx"]): + if idx != 0 and val != self.pdb["pdb_idx"][idx - 1]: + self.pdb["idx"][idx:] += residue_offset + idx + + # convert contig to mask + contig_list = self.parse_contig(contig_crop) + + # add residue offset to different parts of contig_list + for contig in contig_list[1:]: + start = int(contig[0][1]) + self.pdb["idx"][start:] += residue_offset + # flatten list + contig_list = [i for j in contig_list for i in j] + mask = np.array( + [True if i in contig_list else False for i in self.pdb["pdb_idx"]] + ) + + # sanity check + assert np.sum(self.pdb["hotspots"]) == np.sum( + self.pdb["hotspots"][mask] + ), "Supplied hotspot residues are missing from the target contig!" + # crop pdb + for key, val in self.pdb.items(): + try: + self.pdb[key] = val[mask] + except: + self.pdb[key] = [i for idx, i in enumerate(val) if mask[idx]] + self.pdb["crop_mask"] = mask + + def get_target(self): + return self.pdb diff --git a/rfdiffusion/kinematics.py b/rfdiffusion/kinematics.py new file mode 100644 index 0000000000000000000000000000000000000000..8d548394ce6f3819297812c622646aa02e4252d6 --- /dev/null +++ b/rfdiffusion/kinematics.py @@ -0,0 +1,309 @@ +import numpy as np +import torch +from rfdiffusion.chemical import INIT_CRDS +from rfdiffusion.util import generate_Cbeta + +PARAMS = { + "DMIN" : 2.0, + "DMAX" : 20.0, + "DBINS" : 36, + "ABINS" : 36, +} + +# ============================================================ +def get_pair_dist(a, b): + """calculate pair distances between two sets of points + + Parameters + ---------- + a,b : pytorch tensors of shape [batch,nres,3] + store Cartesian coordinates of two sets of atoms + Returns + ------- + dist : pytorch tensor of shape [batch,nres,nres] + stores paitwise distances between atoms in a and b + """ + + dist = torch.cdist(a, b, p=2) + return dist + +# ============================================================ +def get_ang(a, b, c): + """calculate planar angles for all consecutive triples (a[i],b[i],c[i]) + from Cartesian coordinates of three sets of atoms a,b,c + + Parameters + ---------- + a,b,c : pytorch tensors of shape [batch,nres,3] + store Cartesian coordinates of three sets of atoms + Returns + ------- + ang : pytorch tensor of shape [batch,nres] + stores resulting planar angles + """ + v = a - b + w = c - b + v /= torch.norm(v, dim=-1, keepdim=True) + w /= torch.norm(w, dim=-1, keepdim=True) + vw = torch.sum(v*w, dim=-1) + + return torch.acos(vw) + +# ============================================================ +def get_dih(a, b, c, d): + """calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i]) + given Cartesian coordinates of four sets of atoms a,b,c,d + + Parameters + ---------- + a,b,c,d : pytorch tensors or numpy array of shape [batch,nres,3] + store Cartesian coordinates of four sets of atoms + Returns + ------- + dih : pytorch tensor or numpy array of shape [batch,nres] + stores resulting dihedrals + """ + convert_to_torch = lambda *arrays: [torch.from_numpy(arr) for arr in arrays] + output_np=False + if isinstance(a, np.ndarray): + output_np=True + a,b,c,d = convert_to_torch(a,b,c,d) + b0 = a - b + b1 = c - b + b2 = d - c + + b1 /= torch.norm(b1, dim=-1, keepdim=True) + + v = b0 - torch.sum(b0*b1, dim=-1, keepdim=True)*b1 + w = b2 - torch.sum(b2*b1, dim=-1, keepdim=True)*b1 + + x = torch.sum(v*w, dim=-1) + y = torch.sum(torch.cross(b1,v,dim=-1)*w, dim=-1) + output = torch.atan2(y, x) + if output_np: + return output.numpy() + return output + +# ============================================================ +def xyz_to_c6d(xyz, params=PARAMS): + """convert cartesian coordinates into 2d distance + and orientation maps + + Parameters + ---------- + xyz : pytorch tensor of shape [batch,nres,3,3] + stores Cartesian coordinates of backbone N,Ca,C atoms + Returns + ------- + c6d : pytorch tensor of shape [batch,nres,nres,4] + stores stacked dist,omega,theta,phi 2D maps + """ + + batch = xyz.shape[0] + nres = xyz.shape[1] + + # three anchor atoms + N = xyz[:,:,0] + Ca = xyz[:,:,1] + C = xyz[:,:,2] + Cb = generate_Cbeta(N, Ca, C) + + # 6d coordinates order: (dist,omega,theta,phi) + c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device) + + dist = get_pair_dist(Cb,Cb) + dist[torch.isnan(dist)] = 999.9 + c6d[...,0] = dist + 999.9*torch.eye(nres,device=xyz.device)[None,...] + b,i,j = torch.where(c6d[...,0]=params['DMAX']] = 999.9 + + mask = torch.zeros((batch, nres,nres), dtype=xyz.dtype, device=xyz.device) + mask[b,i,j] = 1.0 + return c6d, mask + +def xyz_to_t2d(xyz_t, params=PARAMS): + """convert template cartesian coordinates into 2d distance + and orientation maps + + Parameters + ---------- + xyz_t : pytorch tensor of shape [batch,templ,nres,3,3] + stores Cartesian coordinates of template backbone N,Ca,C atoms + + Returns + ------- + t2d : pytorch tensor of shape [batch,nres,nres,37+6+3] + stores stacked dist,omega,theta,phi 2D maps + """ + B, T, L = xyz_t.shape[:3] + c6d, mask = xyz_to_c6d(xyz_t[:,:,:,:3].view(B*T,L,3,3), params=params) + c6d = c6d.view(B, T, L, L, 4) + mask = mask.view(B, T, L, L, 1) + # + # dist to one-hot encoded + dist = dist_to_onehot(c6d[...,0], params) + orien = torch.cat((torch.sin(c6d[...,1:]), torch.cos(c6d[...,1:])), dim=-1)*mask # (B, T, L, L, 6) + # + mask = ~torch.isnan(c6d[:,:,:,:,0]) # (B, T, L, L) + t2d = torch.cat((dist, orien, mask.unsqueeze(-1)), dim=-1) + t2d[torch.isnan(t2d)] = 0.0 + return t2d + +def xyz_to_chi1(xyz_t): + '''convert template cartesian coordinates into chi1 angles + + Parameters + ---------- + xyz_t: pytorch tensor of shape [batch, templ, nres, 14, 3] + stores Cartesian coordinates of template atoms. For missing atoms, it should be NaN + + Returns + ------- + chi1 : pytorch tensor of shape [batch, templ, nres, 2] + stores cos and sin chi1 angle + ''' + B, T, L = xyz_t.shape[:3] + xyz_t = xyz_t.reshape(B*T, L, 14, 3) + + # chi1 angle: N, CA, CB, CG + chi1 = get_dih(xyz_t[:,:,0], xyz_t[:,:,1], xyz_t[:,:,4], xyz_t[:,:,5]) # (B*T, L) + cos_chi1 = torch.cos(chi1) + sin_chi1 = torch.sin(chi1) + mask_chi1 = ~torch.isnan(chi1) + chi1 = torch.stack((cos_chi1, sin_chi1, mask_chi1), dim=-1) # (B*T, L, 3) + chi1[torch.isnan(chi1)] = 0.0 + chi1 = chi1.reshape(B, T, L, 3) + return chi1 + +def xyz_to_bbtor(xyz, params=PARAMS): + batch = xyz.shape[0] + nres = xyz.shape[1] + + # three anchor atoms + N = xyz[:,:,0] + Ca = xyz[:,:,1] + C = xyz[:,:,2] + + # recreate Cb given N,Ca,C + next_N = torch.roll(N, -1, dims=1) + prev_C = torch.roll(C, 1, dims=1) + phi = get_dih(prev_C, N, Ca, C) + psi = get_dih(N, Ca, C, next_N) + # + phi[:,0] = 0.0 + psi[:,-1] = 0.0 + # + astep = 2.0*np.pi / params['ABINS'] + phi_bin = torch.round((phi+np.pi-astep/2)/astep) + psi_bin = torch.round((psi+np.pi-astep/2)/astep) + return torch.stack([phi_bin, psi_bin], axis=-1).long() + +# ============================================================ +def dist_to_onehot(dist, params=PARAMS): + dist[torch.isnan(dist)] = 999.9 + dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] + dbins = torch.linspace(params['DMIN']+dstep, params['DMAX'], params['DBINS'],dtype=dist.dtype,device=dist.device) + db = torch.bucketize(dist.contiguous(),dbins).long() + dist = torch.nn.functional.one_hot(db, num_classes=params['DBINS']+1).float() + return dist + +def c6d_to_bins(c6d,params=PARAMS): + """bin 2d distance and orientation maps + """ + + dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] + astep = 2.0*np.pi / params['ABINS'] + + dbins = torch.linspace(params['DMIN']+dstep, params['DMAX'], params['DBINS'],dtype=c6d.dtype,device=c6d.device) + ab360 = torch.linspace(-np.pi+astep, np.pi, params['ABINS'],dtype=c6d.dtype,device=c6d.device) + ab180 = torch.linspace(astep, np.pi, params['ABINS']//2,dtype=c6d.dtype,device=c6d.device) + + db = torch.bucketize(c6d[...,0].contiguous(),dbins) + ob = torch.bucketize(c6d[...,1].contiguous(),ab360) + tb = torch.bucketize(c6d[...,2].contiguous(),ab360) + pb = torch.bucketize(c6d[...,3].contiguous(),ab180) + + ob[db==params['DBINS']] = params['ABINS'] + tb[db==params['DBINS']] = params['ABINS'] + pb[db==params['DBINS']] = params['ABINS']//2 + + return torch.stack([db,ob,tb,pb],axis=-1).to(torch.uint8) + + +# ============================================================ +def dist_to_bins(dist,params=PARAMS): + """bin 2d distance maps + """ + + dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] + db = torch.round((dist-params['DMIN']-dstep/2)/dstep) + + db[db<0] = 0 + db[db>params['DBINS']] = params['DBINS'] + + return db.long() + + +# ============================================================ +def c6d_to_bins2(c6d, same_chain, negative=False, params=PARAMS): + """bin 2d distance and orientation maps + """ + + dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] + astep = 2.0*np.pi / params['ABINS'] + + db = torch.round((c6d[...,0]-params['DMIN']-dstep/2)/dstep) + ob = torch.round((c6d[...,1]+np.pi-astep/2)/astep) + tb = torch.round((c6d[...,2]+np.pi-astep/2)/astep) + pb = torch.round((c6d[...,3]-astep/2)/astep) + + # put all dparams['DBINS']] = params['DBINS'] + ob[db==params['DBINS']] = params['ABINS'] + tb[db==params['DBINS']] = params['ABINS'] + pb[db==params['DBINS']] = params['ABINS']//2 + + if negative: + db = torch.where(same_chain.bool(), db.long(), params['DBINS']) + ob = torch.where(same_chain.bool(), ob.long(), params['ABINS']) + tb = torch.where(same_chain.bool(), tb.long(), params['ABINS']) + pb = torch.where(same_chain.bool(), pb.long(), params['ABINS']//2) + + return torch.stack([db,ob,tb,pb],axis=-1).long() + +def get_init_xyz(xyz_t): + # input: xyz_t (B, T, L, 14, 3) + # ouput: xyz (B, T, L, 14, 3) + B, T, L = xyz_t.shape[:3] + init = INIT_CRDS.to(xyz_t.device).reshape(1,1,1,27,3).repeat(B,T,L,1,1) + if torch.isnan(xyz_t).all(): + return init + + mask = torch.isnan(xyz_t[:,:,:,:3]).any(dim=-1).any(dim=-1) # (B, T, L) + # + center_CA = ((~mask[:,:,:,None]) * torch.nan_to_num(xyz_t[:,:,:,1,:])).sum(dim=2) / ((~mask[:,:,:,None]).sum(dim=2)+1e-4) # (B, T, 3) + xyz_t = xyz_t - center_CA.view(B,T,1,1,3) + # + idx_s = list() + for i_b in range(B): + for i_T in range(T): + if mask[i_b, i_T].all(): + continue + exist_in_templ = torch.where(~mask[i_b, i_T])[0] # (L_sub) + seqmap = (torch.arange(L, device=xyz_t.device)[:,None] - exist_in_templ[None,:]).abs() # (L, L_sub) + seqmap = torch.argmin(seqmap, dim=-1) # (L) + idx = torch.gather(exist_in_templ, -1, seqmap) # (L) + offset_CA = torch.gather(xyz_t[i_b, i_T, :, 1, :], 0, idx.reshape(L,1).expand(-1,3)) + init[i_b,i_T] += offset_CA.reshape(L,1,3) + # + xyz = torch.where(mask.view(B, T, L, 1, 1), init, xyz_t) + return xyz diff --git a/rfdiffusion/model_input_logger.py b/rfdiffusion/model_input_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..bb41433252913c7cc6e7f289d18453fbc665be6e --- /dev/null +++ b/rfdiffusion/model_input_logger.py @@ -0,0 +1,71 @@ +import traceback +import os +from inspect import signature +import pickle +import datetime + +def pickle_function_call_wrapper(func, output_dir='pickled_inputs'): + i = 0 + os.makedirs(output_dir) + # pickle.dump({'args': args, 'kwargs': kwargs}, fh) + def wrapper(*args, **kwargs): + """ + Wrap the original function call to print the arguments before + calling the intended function + """ + nonlocal i + i += 1 + func_sig = signature(func) + # Create the argument binding so we can determine what + # parameters are given what values + argument_binding = func_sig.bind(*args, **kwargs) + argument_map = argument_binding.arguments + + # Perform the print so that it shows the function name + # and arguments as a dictionary + path = os.path.join(output_dir, f'{i:05d}.pkl') + print(f"logging {func.__name__} arguments: {[k for k in argument_map]} to {path}") + argument_map['stack'] = traceback.format_stack() + + for k, v in argument_map.items(): + if hasattr(v, 'detach'): + argument_map[k] = v.cpu().detach() + with open(path, 'wb') as fh: + pickle.dump(argument_map, fh) + + return func(*args, **kwargs) + + return wrapper + +def wrap_it(wrapper, instance, method, **kwargs): + class_method = getattr(instance, method) + wrapped_method = wrapper(class_method, **kwargs) + setattr(instance, method, wrapped_method) + + + +def pickle_function_call(instance, method, subdir): + output_dir = os.path.join(os.getcwd(), 'pickled_inputs', subdir, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) + wrap_it(pickle_function_call_wrapper, instance, method, output_dir=output_dir) + return output_dir + +# For testing +if __name__=='__main__': + import glob + class Dog: + def __init__(self, name): + self.name = name + def bark(self, arg, kwarg=None): + print(f'{self.name}:{arg}:{kwarg}') + + dog = Dog('fido') + dog.bark('ruff') + + output_dir = pickle_function_call(dog, 'bark', 'debugging') + + dog.bark('ruff', kwarg='wooof') + + for p in glob.glob(os.path.join(output_dir, '*')): + print(p) + with open(p, 'rb') as fh: + print(pickle.load(fh)) diff --git a/rfdiffusion/potentials/__init__.py b/rfdiffusion/potentials/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rfdiffusion/potentials/__pycache__/__init__.cpython-311.pyc b/rfdiffusion/potentials/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b66d02b1932e009cdfccae16de7e18021a6fb2e Binary files /dev/null and b/rfdiffusion/potentials/__pycache__/__init__.cpython-311.pyc differ diff --git a/rfdiffusion/potentials/__pycache__/__init__.cpython-39.pyc b/rfdiffusion/potentials/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d000b0a40c8edd2bb140d4ae330d3b73bb94ee19 Binary files /dev/null and b/rfdiffusion/potentials/__pycache__/__init__.cpython-39.pyc differ diff --git a/rfdiffusion/potentials/__pycache__/manager.cpython-311.pyc b/rfdiffusion/potentials/__pycache__/manager.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f56a663ee3681b832cd5cf691ff83ec56ee99127 Binary files /dev/null and b/rfdiffusion/potentials/__pycache__/manager.cpython-311.pyc differ diff --git a/rfdiffusion/potentials/__pycache__/manager.cpython-39.pyc b/rfdiffusion/potentials/__pycache__/manager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c73d614dec21c7664948c784117c6caf6c66ddb Binary files /dev/null and b/rfdiffusion/potentials/__pycache__/manager.cpython-39.pyc differ diff --git a/rfdiffusion/potentials/__pycache__/potentials.cpython-311.pyc b/rfdiffusion/potentials/__pycache__/potentials.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1421d56dbfe3b2fbc60cd7aa2b0ca7f68269fdc2 Binary files /dev/null and b/rfdiffusion/potentials/__pycache__/potentials.cpython-311.pyc differ diff --git a/rfdiffusion/potentials/__pycache__/potentials.cpython-39.pyc b/rfdiffusion/potentials/__pycache__/potentials.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4a86c456946ddd9bb1a9d7a8253669d21b5749e Binary files /dev/null and b/rfdiffusion/potentials/__pycache__/potentials.cpython-39.pyc differ diff --git a/rfdiffusion/potentials/manager.py b/rfdiffusion/potentials/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..98b0d5eb7dc8f17298edc0a869e1694558019e55 --- /dev/null +++ b/rfdiffusion/potentials/manager.py @@ -0,0 +1,208 @@ +import torch +from rfdiffusion.potentials import potentials as potentials +import numpy as np + + +def make_contact_matrix(nchain, intra_all=False, inter_all=False, contact_string=None): + """ + Calculate a matrix of inter/intra chain contact indicators + + Parameters: + nchain (int, required): How many chains are in this design + + contact_str (str, optional): String denoting how to define contacts, comma delimited between pairs of chains + '!' denotes repulsive, '&' denotes attractive + """ + alphabet = [a for a in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'] + letter2num = {a:i for i,a in enumerate(alphabet)} + + contacts = np.zeros((nchain,nchain)) + written = np.zeros((nchain,nchain)) + + + # intra_all - everything on the diagonal has contact potential + if intra_all: + contacts[np.arange(nchain),np.arange(nchain)] = 1 + + # inter all - everything off the diagonal has contact potential + if inter_all: + mask2d = np.full_like(contacts,False) + for i in range(len(contacts)): + for j in range(len(contacts)): + if i!=j: + mask2d[i,j] = True + + contacts[mask2d.astype(bool)] = 1 + + + # custom contacts/repulsions from user + if contact_string != None: + contact_list = contact_string.split(',') + for c in contact_list: + assert len(c) == 3 + i,j = letter2num[c[0]],letter2num[c[2]] + + symbol = c[1] + + assert symbol in ['!','&'] + if symbol == '!': + contacts[i,j] = -1 + contacts[j,i] = -1 + else: + contacts[i,j] = 1 + contacts[j,i] = 1 + + return contacts + + +def calc_nchains(symbol, components=1): + """ + Calculates number of chains for given symmetry + """ + S = symbol.lower() + + if S.startswith('c'): + return int(S[1:])*components + elif S.startswith('d'): + return 2*int(S[1:])*components + elif S.startswith('o'): + raise NotImplementedError() + elif S.startswith('t'): + return 12*components + else: + raise RuntimeError('Unknown symmetry symbol ',S) + + +class PotentialManager: + ''' + Class to define a set of potentials from the given config object and to apply all of the specified potentials + during each cycle of the inference loop. + + Author: NRB + ''' + + def __init__(self, + potentials_config, + ppi_config, + diffuser_config, + inference_config, + hotspot_0idx, + binderlen, + ): + + self.potentials_config = potentials_config + self.ppi_config = ppi_config + self.inference_config = inference_config + + self.guide_scale = potentials_config.guide_scale + self.guide_decay = potentials_config.guide_decay + + if potentials_config.guiding_potentials is None: + setting_list = [] + else: + setting_list = [self.parse_potential_string(potstr) for potstr in potentials_config.guiding_potentials] + + + # PPI potentials require knowledge about the binderlen which may be detected at runtime + # This is a mechanism to still allow this info to be used in potentials - NRB + if binderlen > 0: + binderlen_update = { 'binderlen': binderlen } + hotspot_res_update = { 'hotspot_res': hotspot_0idx } + + for setting in setting_list: + if setting['type'] in potentials.require_binderlen: + setting.update(binderlen_update) + + self.potentials_to_apply = self.initialize_all_potentials(setting_list) + self.T = diffuser_config.T + + def is_empty(self): + ''' + Check whether this instance of PotentialManager actually contains any potentials + ''' + + return len(self.potentials_to_apply) == 0 + + def parse_potential_string(self, potstr): + ''' + Parse a single entry in the list of potentials to be run to a dictionary of settings for that potential. + + An example of how this parsing is done: + 'setting1:val1,setting2:val2,setting3:val3' -> {setting1:val1,setting2:val2,setting3:val3} + ''' + + setting_dict = {entry.split(':')[0]:entry.split(':')[1] for entry in potstr.split(',')} + + for key in setting_dict: + if not key == 'type': setting_dict[key] = float(setting_dict[key]) + + return setting_dict + + def initialize_all_potentials(self, setting_list): + ''' + Given a list of potential dictionaries where each dictionary defines the configurations for a single potential, + initialize all potentials and add to the list of potentials to be applies + ''' + + to_apply = [] + + for potential_dict in setting_list: + assert(potential_dict['type'] in potentials.implemented_potentials), f'potential with name: {potential_dict["type"]} is not one of the implemented potentials: {potentials.implemented_potentials.keys()}' + + kwargs = {k: potential_dict[k] for k in potential_dict.keys() - {'type'}} + + # symmetric oligomer contact potential args + if self.inference_config.symmetry: + + num_chains = calc_nchains(symbol=self.inference_config.symmetry, components=1) # hard code 1 for now + contact_kwargs={'nchain':num_chains, + 'intra_all':self.potentials_config.olig_intra_all, + 'inter_all':self.potentials_config.olig_inter_all, + 'contact_string':self.potentials_config.olig_custom_contact } + contact_matrix = make_contact_matrix(**contact_kwargs) + kwargs.update({'contact_matrix':contact_matrix}) + + + to_apply.append(potentials.implemented_potentials[potential_dict['type']](**kwargs)) + + return to_apply + + def compute_all_potentials(self, xyz): + ''' + This is the money call. Take the current sequence and structure information and get the sum of all of the potentials that are being used + ''' + + potential_list = [potential.compute(xyz) for potential in self.potentials_to_apply] + potential_stack = torch.stack(potential_list, dim=0) + + return torch.sum(potential_stack, dim=0) + + def get_guide_scale(self, t): + ''' + Given a timestep and a decay type, get the appropriate scale factor to use for applying guiding potentials + + Inputs: + + t (int, required): The current timestep + + Output: + + scale (int): The scale factor to use for applying guiding potentials + + ''' + + implemented_decay_types = { + 'constant': lambda t: self.guide_scale, + # Linear interpolation with y2: 0, y1: guide_scale, x2: 0, x1: T, x: t + 'linear' : lambda t: t/self.T * self.guide_scale, + 'quadratic' : lambda t: t**2/self.T**2 * self.guide_scale, + 'cubic' : lambda t: t**3/self.T**3 * self.guide_scale + } + + if self.guide_decay not in implemented_decay_types: + sys.exit(f'decay_type must be one of {implemented_decay_types.keys()}. Received decay_type={self.guide_decay}. Exiting.') + + return implemented_decay_types[self.guide_decay](t) + + + diff --git a/rfdiffusion/potentials/potentials.py b/rfdiffusion/potentials/potentials.py new file mode 100644 index 0000000000000000000000000000000000000000..b43a2a6b5425e51f0da8c5e233216412ad32ab8d --- /dev/null +++ b/rfdiffusion/potentials/potentials.py @@ -0,0 +1,475 @@ +import torch +import numpy as np +from rfdiffusion.util import generate_Cbeta + +class Potential: + ''' + Interface class that defines the functions a potential must implement + ''' + + def compute(self, xyz): + ''' + Given the current structure of the model prediction, return the current + potential as a PyTorch tensor with a single entry + + Args: + xyz (torch.tensor, size: [L,27,3]: The current coordinates of the sample + + Returns: + potential (torch.tensor, size: [1]): A potential whose value will be MAXIMIZED + by taking a step along it's gradient + ''' + raise NotImplementedError('Potential compute function was not overwritten') + +class monomer_ROG(Potential): + ''' + Radius of Gyration potential for encouraging monomer compactness + + Written by DJ and refactored into a class by NRB + ''' + + def __init__(self, weight=1, min_dist=15): + + self.weight = weight + self.min_dist = min_dist + + def compute(self, xyz): + Ca = xyz[:,1] # [L,3] + + centroid = torch.mean(Ca, dim=0, keepdim=True) # [1,3] + + dgram = torch.cdist(Ca[None,...].contiguous(), centroid[None,...].contiguous(), p=2) # [1,L,1,3] + + dgram = torch.maximum(self.min_dist * torch.ones_like(dgram.squeeze(0)), dgram.squeeze(0)) # [L,1,3] + + rad_of_gyration = torch.sqrt( torch.sum(torch.square(dgram)) / Ca.shape[0] ) # [1] + + return -1 * self.weight * rad_of_gyration + +class binder_ROG(Potential): + ''' + Radius of Gyration potential for encouraging binder compactness + + Author: NRB + ''' + + def __init__(self, binderlen, weight=1, min_dist=15): + + self.binderlen = binderlen + self.min_dist = min_dist + self.weight = weight + + def compute(self, xyz): + + # Only look at binder residues + Ca = xyz[:self.binderlen,1] # [Lb,3] + + centroid = torch.mean(Ca, dim=0, keepdim=True) # [1,3] + + # cdist needs a batch dimension - NRB + dgram = torch.cdist(Ca[None,...].contiguous(), centroid[None,...].contiguous(), p=2) # [1,Lb,1,3] + + dgram = torch.maximum(self.min_dist * torch.ones_like(dgram.squeeze(0)), dgram.squeeze(0)) # [Lb,1,3] + + rad_of_gyration = torch.sqrt( torch.sum(torch.square(dgram)) / Ca.shape[0] ) # [1] + + return -1 * self.weight * rad_of_gyration + + +class dimer_ROG(Potential): + ''' + Radius of Gyration potential for encouraging compactness of both monomers when designing dimers + + Author: PV + ''' + + def __init__(self, binderlen, weight=1, min_dist=15): + + self.binderlen = binderlen + self.min_dist = min_dist + self.weight = weight + + def compute(self, xyz): + + # Only look at monomer 1 residues + Ca_m1 = xyz[:self.binderlen,1] # [Lb,3] + + # Only look at monomer 2 residues + Ca_m2 = xyz[self.binderlen:,1] # [Lb,3] + + centroid_m1 = torch.mean(Ca_m1, dim=0, keepdim=True) # [1,3] + centroid_m2 = torch.mean(Ca_m1, dim=0, keepdim=True) # [1,3] + + # cdist needs a batch dimension - NRB + #This calculates RoG for Monomer 1 + dgram_m1 = torch.cdist(Ca_m1[None,...].contiguous(), centroid_m1[None,...].contiguous(), p=2) # [1,Lb,1,3] + dgram_m1 = torch.maximum(self.min_dist * torch.ones_like(dgram_m1.squeeze(0)), dgram_m1.squeeze(0)) # [Lb,1,3] + rad_of_gyration_m1 = torch.sqrt( torch.sum(torch.square(dgram_m1)) / Ca_m1.shape[0] ) # [1] + + # cdist needs a batch dimension - NRB + #This calculates RoG for Monomer 2 + dgram_m2 = torch.cdist(Ca_m2[None,...].contiguous(), centroid_m2[None,...].contiguous(), p=2) # [1,Lb,1,3] + dgram_m2 = torch.maximum(self.min_dist * torch.ones_like(dgram_m2.squeeze(0)), dgram_m2.squeeze(0)) # [Lb,1,3] + rad_of_gyration_m2 = torch.sqrt( torch.sum(torch.square(dgram_m2)) / Ca_m2.shape[0] ) # [1] + + #Potential value is the average of both radii of gyration (is avg. the best way to do this?) + return -1 * self.weight * (rad_of_gyration_m1 + rad_of_gyration_m2)/2 + +class binder_ncontacts(Potential): + ''' + Differentiable way to maximise number of contacts within a protein + + Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html + + ''' + + def __init__(self, binderlen, weight=1, r_0=8, d_0=4): + + self.binderlen = binderlen + self.r_0 = r_0 + self.weight = weight + self.d_0 = d_0 + + def compute(self, xyz): + + # Only look at binder Ca residues + Ca = xyz[:self.binderlen,1] # [Lb,3] + + #cdist needs a batch dimension - NRB + dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2) # [1,Lb,Lb] + divide_by_r_0 = (dgram - self.d_0) / self.r_0 + numerator = torch.pow(divide_by_r_0,6) + denominator = torch.pow(divide_by_r_0,12) + binder_ncontacts = (1 - numerator) / (1 - denominator) + + print("BINDER CONTACTS:", binder_ncontacts.sum()) + #Potential value is the average of both radii of gyration (is avg. the best way to do this?) + return self.weight * binder_ncontacts.sum() + +class interface_ncontacts(Potential): + + ''' + Differentiable way to maximise number of contacts between binder and target + + Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html + + Author: PV + ''' + + + def __init__(self, binderlen, weight=1, r_0=8, d_0=6): + + self.binderlen = binderlen + self.r_0 = r_0 + self.weight = weight + self.d_0 = d_0 + + def compute(self, xyz): + + # Extract binder Ca residues + Ca_b = xyz[:self.binderlen,1] # [Lb,3] + + # Extract target Ca residues + Ca_t = xyz[self.binderlen:,1] # [Lt,3] + + #cdist needs a batch dimension - NRB + dgram = torch.cdist(Ca_b[None,...].contiguous(), Ca_t[None,...].contiguous(), p=2) # [1,Lb,Lt] + divide_by_r_0 = (dgram - self.d_0) / self.r_0 + numerator = torch.pow(divide_by_r_0,6) + denominator = torch.pow(divide_by_r_0,12) + interface_ncontacts = (1 - numerator) / (1 - denominator) + #Potential is the sum of values in the tensor + interface_ncontacts = interface_ncontacts.sum() + + print("INTERFACE CONTACTS:", interface_ncontacts.sum()) + + return self.weight * interface_ncontacts + + +class monomer_contacts(Potential): + ''' + Differentiable way to maximise number of contacts within a protein + + Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html + Author: PV + + NOTE: This function sometimes produces NaN's -- added check in reverse diffusion for nan grads + ''' + + def __init__(self, weight=1, r_0=8, d_0=2, eps=1e-6): + + self.r_0 = r_0 + self.weight = weight + self.d_0 = d_0 + self.eps = eps + + def compute(self, xyz): + + Ca = xyz[:,1] # [L,3] + + #cdist needs a batch dimension - NRB + dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2) # [1,Lb,Lb] + divide_by_r_0 = (dgram - self.d_0) / self.r_0 + numerator = torch.pow(divide_by_r_0,6) + denominator = torch.pow(divide_by_r_0,12) + + ncontacts = (1 - numerator) / ((1 - denominator)) + + + #Potential value is the average of both radii of gyration (is avg. the best way to do this?) + return self.weight * ncontacts.sum() + + +class olig_contacts(Potential): + """ + Applies PV's num contacts potential within/between chains in symmetric oligomers + + Author: DJ + """ + + def __init__(self, + contact_matrix, + weight_intra=1, + weight_inter=1, + r_0=8, d_0=2): + """ + Parameters: + chain_lengths (list, required): List of chain lengths, length is (Nchains) + + contact_matrix (torch.tensor/np.array, required): + square matrix of shape (Nchains,Nchains) whose (i,j) enry represents + attractive (1), repulsive (-1), or non-existent (0) contact potentials + between chains in the complex + + weight (int/float, optional): Scaling/weighting factor + """ + self.contact_matrix = contact_matrix + self.weight_intra = weight_intra + self.weight_inter = weight_inter + self.r_0 = r_0 + self.d_0 = d_0 + + # check contact matrix only contains valid entries + assert all([i in [-1,0,1] for i in contact_matrix.flatten()]), 'Contact matrix must contain only 0, 1, or -1 in entries' + # assert the matrix is square and symmetric + shape = contact_matrix.shape + assert len(shape) == 2 + assert shape[0] == shape[1] + for i in range(shape[0]): + for j in range(shape[1]): + assert contact_matrix[i,j] == contact_matrix[j,i] + self.nchain=shape[0] + + + def _get_idx(self,i,L): + """ + Returns the zero-indexed indices of the residues in chain i + """ + assert L%self.nchain == 0 + Lchain = L//self.nchain + return i*Lchain + torch.arange(Lchain) + + + def compute(self, xyz): + """ + Iterate through the contact matrix, compute contact potentials between chains that need it, + and negate contacts for any + """ + L = xyz.shape[0] + + all_contacts = 0 + start = 0 + for i in range(self.nchain): + for j in range(self.nchain): + # only compute for upper triangle, disregard zeros in contact matrix + if (i <= j) and (self.contact_matrix[i,j] != 0): + + # get the indices for these two chains + idx_i = self._get_idx(i,L) + idx_j = self._get_idx(j,L) + + Ca_i = xyz[idx_i,1] # slice out crds for this chain + Ca_j = xyz[idx_j,1] # slice out crds for that chain + dgram = torch.cdist(Ca_i[None,...].contiguous(), Ca_j[None,...].contiguous(), p=2) # [1,Lb,Lb] + + divide_by_r_0 = (dgram - self.d_0) / self.r_0 + numerator = torch.pow(divide_by_r_0,6) + denominator = torch.pow(divide_by_r_0,12) + ncontacts = (1 - numerator) / (1 - denominator) + + # weight, don't double count intra + scalar = (i==j)*self.weight_intra/2 + (i!=j)*self.weight_inter + + # contacts attr/repuls relative weights + all_contacts += ncontacts.sum() * self.contact_matrix[i,j] * scalar + + return all_contacts + +def get_damped_lj(r_min, r_lin,p1=6,p2=12): + + y_at_r_lin = lj(r_lin, r_min, p1, p2) + ydot_at_r_lin = lj_grad(r_lin, r_min,p1,p2) + + def inner(dgram): + return (dgram < r_lin) * (ydot_at_r_lin * (dgram - r_lin) + y_at_r_lin) + (dgram >= r_lin) * lj(dgram, r_min, p1, p2) + return inner + +def lj(dgram, r_min,p1=6, p2=12): + return 4 * ((r_min / (2**(1/p1) * dgram))**p2 - (r_min / (2**(1/p1) * dgram))**p1) + +def lj_grad(dgram, r_min,p1=6,p2=12): + return -p2 * r_min**p1*(r_min**p1-dgram**p1) / (dgram**(p2+1)) + +def mask_expand(mask, n=1): + mask_out = mask.clone() + assert mask.ndim == 1 + for i in torch.where(mask)[0]: + for j in range(i-n, i+n+1): + if j >= 0 and j < len(mask): + mask_out[j] = True + return mask_out + +def contact_energy(dgram, d_0, r_0): + divide_by_r_0 = (dgram - d_0) / r_0 + numerator = torch.pow(divide_by_r_0,6) + denominator = torch.pow(divide_by_r_0,12) + + ncontacts = (1 - numerator) / ((1 - denominator)).float() + return - ncontacts + +def poly_repulse(dgram, r, slope, p=1): + a = slope / (p * r**(p-1)) + + return (dgram < r) * a * torch.abs(r - dgram)**p * slope + +#def only_top_n(dgram + + +class substrate_contacts(Potential): + ''' + Implicitly models a ligand with an attractive-repulsive potential. + ''' + + def __init__(self, weight=1, r_0=8, d_0=2, s=1, eps=1e-6, rep_r_0=5, rep_s=2, rep_r_min=1): + + self.r_0 = r_0 + self.weight = weight + self.d_0 = d_0 + self.eps = eps + + # motif frame coordinates + # NOTE: these probably need to be set after sample_init() call, because the motif sequence position in design must be known + self.motif_frame = None # [4,3] xyz coordinates from 4 atoms of input motif + self.motif_mapping = None # list of tuples giving positions of above atoms in design [(resi, atom_idx)] + self.motif_substrate_atoms = None # xyz coordinates of substrate from input motif + r_min = 2 + self.energies = [] + self.energies.append(lambda dgram: s * contact_energy(torch.min(dgram, dim=-1)[0], d_0, r_0)) + if rep_r_min: + self.energies.append(lambda dgram: poly_repulse(torch.min(dgram, dim=-1)[0], rep_r_0, rep_s, p=1.5)) + else: + self.energies.append(lambda dgram: poly_repulse(dgram, rep_r_0, rep_s, p=1.5)) + + + def compute(self, xyz): + + # First, get random set of atoms + # This operates on self.xyz_motif, which is assigned to this class in the model runner (for horrible plumbing reasons) + self._grab_motif_residues(self.xyz_motif) + + # for checking affine transformation is corect + first_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(self.motif_substrate_atoms[0] - self.motif_frame[0]), dim=-1))) + + # grab the coordinates of the corresponding atoms in the new frame using mapping + res = torch.tensor([k[0] for k in self.motif_mapping]) + atoms = torch.tensor([k[1] for k in self.motif_mapping]) + new_frame = xyz[self.diffusion_mask][res,atoms,:] + # calculate affine transformation matrix and translation vector b/w new frame and motif frame + A, t = self._recover_affine(self.motif_frame, new_frame) + # apply affine transformation to substrate atoms + substrate_atoms = torch.mm(A, self.motif_substrate_atoms.transpose(0,1)).transpose(0,1) + t + second_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(new_frame[0] - substrate_atoms[0]), dim=-1))) + assert abs(first_distance - second_distance) < 0.01, "Alignment seems to be bad" + diffusion_mask = mask_expand(self.diffusion_mask, 1) + Ca = xyz[~diffusion_mask, 1] + + #cdist needs a batch dimension - NRB + dgram = torch.cdist(Ca[None,...].contiguous(), substrate_atoms.float()[None], p=2)[0] # [Lb,Lb] + + all_energies = [] + for i, energy_fn in enumerate(self.energies): + energy = energy_fn(dgram) + all_energies.append(energy.sum()) + return - self.weight * sum(all_energies) + + #Potential value is the average of both radii of gyration (is avg. the best way to do this?) + return self.weight * ncontacts.sum() + + def _recover_affine(self,frame1, frame2): + """ + Uses Simplex Affine Matrix (SAM) formula to recover affine transform between two sets of 4 xyz coordinates + See: https://www.researchgate.net/publication/332410209_Beginner%27s_guide_to_mapping_simplexes_affinely + + Args: + frame1 - 4 coordinates from starting frame [4,3] + frame2 - 4 coordinates from ending frame [4,3] + + Outputs: + A - affine transformation matrix from frame1->frame2 + t - affine translation vector from frame1->frame2 + """ + + l = len(frame1) + # construct SAM denominator matrix + B = torch.vstack([frame1.T, torch.ones(l)]) + D = 1.0 / torch.linalg.det(B) # SAM denominator + + M = torch.zeros((3,4), dtype=torch.float64) + for i, R in enumerate(frame2.T): + for j in range(l): + num = torch.vstack([R, B]) + # make SAM numerator matrix + num = torch.cat((num[:j+1],num[j+2:])) # make numerator matrix + # calculate SAM entry + M[i][j] = (-1)**j * D * torch.linalg.det(num) + + A, t = torch.hsplit(M, [l-1]) + t = t.transpose(0,1) + return A, t + + def _grab_motif_residues(self, xyz) -> None: + """ + Grabs 4 atoms in the motif. + Currently random subset of Ca atoms if the motif is >= 4 residues, or else 4 random atoms from a single residue + """ + idx = torch.arange(self.diffusion_mask.shape[0]) + idx = idx[self.diffusion_mask].float() + if torch.sum(self.diffusion_mask) >= 4: + rand_idx = torch.multinomial(idx, 4).long() + # get Ca atoms + self.motif_frame = xyz[rand_idx, 1] + self.motif_mapping = [(i,1) for i in rand_idx] + else: + rand_idx = torch.multinomial(idx, 1).long() + self.motif_frame = xyz[rand_idx[0],:4] + self.motif_mapping = [(rand_idx, i) for i in range(4)] + +# Dictionary of types of potentials indexed by name of potential. Used by PotentialManager. +# If you implement a new potential you must add it to this dictionary for it to be used by +# the PotentialManager +implemented_potentials = { 'monomer_ROG': monomer_ROG, + 'binder_ROG': binder_ROG, + 'dimer_ROG': dimer_ROG, + 'binder_ncontacts': binder_ncontacts, + 'interface_ncontacts': interface_ncontacts, + 'monomer_contacts': monomer_contacts, + 'olig_contacts': olig_contacts, + 'substrate_contacts': substrate_contacts} + +require_binderlen = { 'binder_ROG', + 'binder_distance_ReLU', + 'binder_any_ReLU', + 'dimer_ROG', + 'binder_ncontacts', + 'interface_ncontacts'} + diff --git a/rfdiffusion/scoring.py b/rfdiffusion/scoring.py new file mode 100644 index 0000000000000000000000000000000000000000..21377f66cec15b7b01c23031f9b5b5357cf38e38 --- /dev/null +++ b/rfdiffusion/scoring.py @@ -0,0 +1,300 @@ + +## +## lk and lk term +#(LJ_RADIUS LJ_WDEPTH LK_DGFREE LK_LAMBDA LK_VOLUME) +type2ljlk = { + "CNH2":(1.968297,0.094638,3.077030,3.5000,13.500000), + "COO":(1.916661,0.141799,-3.332648,3.5000,14.653000), + "CH0":(2.011760,0.062642,1.409284,3.5000,8.998000), + "CH1":(2.011760,0.062642,-3.538387,3.5000,10.686000), + "CH2":(2.011760,0.062642,-1.854658,3.5000,18.331000), + "CH3":(2.011760,0.062642,7.292929,3.5000,25.855000), + "aroC":(2.016441,0.068775,1.797950,3.5000,16.704000), + "Ntrp":(1.802452,0.161725,-8.413116,3.5000,9.522100), + "Nhis":(1.802452,0.161725,-9.739606,3.5000,9.317700), + "NtrR":(1.802452,0.161725,-5.158080,3.5000,9.779200), + "NH2O":(1.802452,0.161725,-8.101638,3.5000,15.689000), + "Nlys":(1.802452,0.161725,-20.864641,3.5000,16.514000), + "Narg":(1.802452,0.161725,-8.968351,3.5000,15.717000), + "Npro":(1.802452,0.161725,-0.984585,3.5000,3.718100), + "OH":(1.542743,0.161947,-8.133520,3.5000,10.722000), + "OHY":(1.542743,0.161947,-8.133520,3.5000,10.722000), + "ONH2":(1.548662,0.182924,-6.591644,3.5000,10.102000), + "OOC":(1.492871,0.099873,-9.239832,3.5000,9.995600), + "S":(1.975967,0.455970,-1.707229,3.5000,17.640000), + "SH1":(1.975967,0.455970,3.291643,3.5000,23.240000), + "Nbb":(1.802452,0.161725,-9.969494,3.5000,15.992000), + "CAbb":(2.011760,0.062642,2.533791,3.5000,12.137000), + "CObb":(1.916661,0.141799,3.104248,3.5000,13.221000), + "OCbb":(1.540580,0.142417,-8.006829,3.5000,12.196000), + "HNbb":(0.901681,0.005000,0.0000,3.5000,0.0000), + "Hapo":(1.421272,0.021808,0.0000,3.5000,0.0000), + "Haro":(1.374914,0.015909,0.0000,3.5000,0.0000), + "Hpol":(0.901681,0.005000,0.0000,3.5000,0.0000), + "HS":(0.363887,0.050836,0.0000,3.5000,0.0000), +} + +# hbond donor/acceptors +class HbAtom: + NO = 0 + DO = 1 # donor + AC = 2 # acceptor + DA = 3 # donor & acceptor + HP = 4 # polar H + +type2hb = { + "CNH2":HbAtom.NO, "COO":HbAtom.NO, "CH0":HbAtom.NO, "CH1":HbAtom.NO, + "CH2":HbAtom.NO, "CH3":HbAtom.NO, "aroC":HbAtom.NO, "Ntrp":HbAtom.DO, + "Nhis":HbAtom.AC, "NtrR":HbAtom.DO, "NH2O":HbAtom.DO, "Nlys":HbAtom.DO, + "Narg":HbAtom.DO, "Npro":HbAtom.NO, "OH":HbAtom.DA, "OHY":HbAtom.DA, + "ONH2":HbAtom.AC, "OOC":HbAtom.AC, "S":HbAtom.NO, "SH1":HbAtom.NO, + "Nbb":HbAtom.DO, "CAbb":HbAtom.NO, "CObb":HbAtom.NO, "OCbb":HbAtom.AC, + "HNbb":HbAtom.HP, "Hapo":HbAtom.NO, "Haro":HbAtom.NO, "Hpol":HbAtom.HP, + "HS":HbAtom.HP, # HP in rosetta(?) +} + +## +## hbond term +class HbDonType: + PBA = 0 + IND = 1 + IME = 2 + GDE = 3 + CXA = 4 + AMO = 5 + HXL = 6 + AHX = 7 + NTYPES = 8 + +class HbAccType: + PBA = 0 + CXA = 1 + CXL = 2 + HXL = 3 + AHX = 4 + IME = 5 + NTYPES = 6 + +class HbHybType: + SP2 = 0 + SP3 = 1 + RING = 2 + NTYPES = 3 + +type2dontype = { + "Nbb": HbDonType.PBA, + "Ntrp": HbDonType.IND, + "NtrR": HbDonType.GDE, + "Narg": HbDonType.GDE, + "NH2O": HbDonType.CXA, + "Nlys": HbDonType.AMO, + "OH": HbDonType.HXL, + "OHY": HbDonType.AHX, +} + +type2acctype = { + "OCbb": HbAccType.PBA, + "ONH2": HbAccType.CXA, + "OOC": HbAccType.CXL, + "OH": HbAccType.HXL, + "OHY": HbAccType.AHX, + "Nhis": HbAccType.IME, +} + +type2hybtype = { + "OCbb": HbHybType.SP2, + "ONH2": HbHybType.SP2, + "OOC": HbHybType.SP2, + "OHY": HbHybType.SP3, + "OH": HbHybType.SP3, + "Nhis": HbHybType.RING, +} + +dontype2wt = { + HbDonType.PBA: 1.45, + HbDonType.IND: 1.15, + HbDonType.IME: 1.42, + HbDonType.GDE: 1.11, + HbDonType.CXA: 1.29, + HbDonType.AMO: 1.17, + HbDonType.HXL: 0.99, + HbDonType.AHX: 1.00, +} + +acctype2wt = { + HbAccType.PBA: 1.19, + HbAccType.CXA: 1.21, + HbAccType.CXL: 1.10, + HbAccType.HXL: 1.15, + HbAccType.AHX: 1.15, + HbAccType.IME: 1.17, +} + +class HbPolyType: + ahdist_aASN_dARG = 0 + ahdist_aASN_dASN = 1 + ahdist_aASN_dGLY = 2 + ahdist_aASN_dHIS = 3 + ahdist_aASN_dLYS = 4 + ahdist_aASN_dSER = 5 + ahdist_aASN_dTRP = 6 + ahdist_aASN_dTYR = 7 + ahdist_aASP_dARG = 8 + ahdist_aASP_dASN = 9 + ahdist_aASP_dGLY = 10 + ahdist_aASP_dHIS = 11 + ahdist_aASP_dLYS = 12 + ahdist_aASP_dSER = 13 + ahdist_aASP_dTRP = 14 + ahdist_aASP_dTYR = 15 + ahdist_aGLY_dARG = 16 + ahdist_aGLY_dASN = 17 + ahdist_aGLY_dGLY = 18 + ahdist_aGLY_dHIS = 19 + ahdist_aGLY_dLYS = 20 + ahdist_aGLY_dSER = 21 + ahdist_aGLY_dTRP = 22 + ahdist_aGLY_dTYR = 23 + ahdist_aHIS_dARG = 24 + ahdist_aHIS_dASN = 25 + ahdist_aHIS_dGLY = 26 + ahdist_aHIS_dHIS = 27 + ahdist_aHIS_dLYS = 28 + ahdist_aHIS_dSER = 29 + ahdist_aHIS_dTRP = 30 + ahdist_aHIS_dTYR = 31 + ahdist_aSER_dARG = 32 + ahdist_aSER_dASN = 33 + ahdist_aSER_dGLY = 34 + ahdist_aSER_dHIS = 35 + ahdist_aSER_dLYS = 36 + ahdist_aSER_dSER = 37 + ahdist_aSER_dTRP = 38 + ahdist_aSER_dTYR = 39 + ahdist_aTYR_dARG = 40 + ahdist_aTYR_dASN = 41 + ahdist_aTYR_dGLY = 42 + ahdist_aTYR_dHIS = 43 + ahdist_aTYR_dLYS = 44 + ahdist_aTYR_dSER = 45 + ahdist_aTYR_dTRP = 46 + ahdist_aTYR_dTYR = 47 + cosBAH_off = 48 + cosBAH_7 = 49 + cosBAH_6i = 50 + AHD_1h = 51 + AHD_1i = 52 + AHD_1j = 53 + AHD_1k = 54 + +# map donor:acceptor pairs to polynomials +hbtypepair2poly = { + (HbDonType.PBA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1j), + (HbDonType.CXA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1j), + (HbDonType.IME,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1j), + (HbDonType.IND,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1j), + (HbDonType.AMO,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h), + (HbDonType.GDE,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1j), + (HbDonType.AHX,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.HXL,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.PBA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.CXA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.IME,HbAccType.CXA): (HbPolyType.ahdist_aASN_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.IND,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.AMO,HbAccType.CXA): (HbPolyType.ahdist_aASN_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h), + (HbDonType.GDE,HbAccType.CXA): (HbPolyType.ahdist_aASN_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.AHX,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.HXL,HbAccType.CXA): (HbPolyType.ahdist_aASN_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.PBA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.CXA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.IME,HbAccType.CXL): (HbPolyType.ahdist_aASP_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.IND,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.AMO,HbAccType.CXL): (HbPolyType.ahdist_aASP_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h), + (HbDonType.GDE,HbAccType.CXL): (HbPolyType.ahdist_aASP_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.AHX,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.HXL,HbAccType.CXL): (HbPolyType.ahdist_aASP_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.PBA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dGLY,HbPolyType.cosBAH_7,HbPolyType.AHD_1i), + (HbDonType.CXA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dASN,HbPolyType.cosBAH_7,HbPolyType.AHD_1i), + (HbDonType.IME,HbAccType.IME): (HbPolyType.ahdist_aHIS_dHIS,HbPolyType.cosBAH_7,HbPolyType.AHD_1h), + (HbDonType.IND,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTRP,HbPolyType.cosBAH_7,HbPolyType.AHD_1h), + (HbDonType.AMO,HbAccType.IME): (HbPolyType.ahdist_aHIS_dLYS,HbPolyType.cosBAH_7,HbPolyType.AHD_1i), + (HbDonType.GDE,HbAccType.IME): (HbPolyType.ahdist_aHIS_dARG,HbPolyType.cosBAH_7,HbPolyType.AHD_1h), + (HbDonType.AHX,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTYR,HbPolyType.cosBAH_7,HbPolyType.AHD_1i), + (HbDonType.HXL,HbAccType.IME): (HbPolyType.ahdist_aHIS_dSER,HbPolyType.cosBAH_7,HbPolyType.AHD_1i), + (HbDonType.PBA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.CXA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.IME,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.IND,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h), + (HbDonType.AMO,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.GDE,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.AHX,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.HXL,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.PBA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.CXA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.IME,HbAccType.HXL): (HbPolyType.ahdist_aSER_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.IND,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h), + (HbDonType.AMO,HbAccType.HXL): (HbPolyType.ahdist_aSER_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.GDE,HbAccType.HXL): (HbPolyType.ahdist_aSER_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.AHX,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.HXL,HbAccType.HXL): (HbPolyType.ahdist_aSER_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), +} + + +# polynomials are triplets, (x_min, x_max), (y[xx_max]), (c_9,...,c_0) +hbpolytype2coeffs = { # Parameters imported from rosetta sp2_elec_params @v2017.48-dev59886 + HbPolyType.ahdist_aASN_dARG: ((0.7019094761929999, 2.86820307153,),(1.1, 1.1,),( 0.58376113, -9.29345473, 64.86270904, -260.3946711, 661.43138077, -1098.01378958, 1183.58371466, -790.82929582, 291.33125475, -43.01629727,)), + HbPolyType.ahdist_aASN_dASN: ((0.625841094801, 2.75107708444,),(1.1, 1.1,),( -1.31243015, 18.6745072, -112.63858313, 373.32878091, -734.99145504, 861.38324861, -556.21026097, 143.5626977, 20.03238394, -11.52167705,)), + HbPolyType.ahdist_aASN_dGLY: ((0.7477341047139999, 2.6796350782799996,),(1.1, 1.1,),( -1.61294554, 23.3150793, -144.11313069, 496.13575, -1037.83809166, 1348.76826073, -1065.14368678, 473.89008925, -100.41142701, 7.44453515,)), + HbPolyType.ahdist_aASN_dHIS: ((0.344789524346, 2.8303582266000005,),(1.1, 1.1,),( -0.2657122, 4.1073775, -26.9099632, 97.10486507, -209.96002602, 277.33057268, -218.74766996, 97.42852213, -24.07382402, 3.73962807,)), + HbPolyType.ahdist_aASN_dLYS: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)), + HbPolyType.ahdist_aASN_dSER: ((1.0812774602500002, 2.6832123582599996,),(1.1, 1.1,),( -3.51524353, 47.54032873, -254.40168577, 617.84606386, -255.49935027, -2361.56230539, 6426.85797934, -7760.4403891, 4694.08106855, -1149.83549068,)), + HbPolyType.ahdist_aASN_dTRP: ((0.6689984999999999, 3.0704254,),(1.1, 1.1,),( -0.5284840422, 8.3510150838, -56.4100479414, 212.4884326254, -488.3178610608, 703.7762350506, -628.9936994633999, 331.4294356146, -93.265817571, 11.9691623698,)), + HbPolyType.ahdist_aASN_dTYR: ((1.08950268805, 2.6887046709400004,),(1.1, 1.1,),( -4.4488705, 63.27696281, -371.44187037, 1121.71921621, -1638.11394306, 142.99988401, 3436.65879147, -5496.07011787, 3709.30505237, -962.79669688,)), + HbPolyType.ahdist_aASP_dARG: ((0.8100404642229999, 2.9851230124799994,),(1.1, 1.1,),( -0.66430344, 10.41343145, -70.12656205, 265.12578414, -617.05849171, 911.39378582, -847.25013928, 472.09090981, -141.71513167, 18.57721132,)), + HbPolyType.ahdist_aASP_dASN: ((1.05401125073, 3.11129675908,),(1.1, 1.1,),( 0.02090728, -0.24144928, -0.19578075, 16.80904547, -117.70216251, 407.18551288, -809.95195924, 939.83137947, -593.94527692, 159.57610528,)), + HbPolyType.ahdist_aASP_dGLY: ((0.886260952629, 2.66843608743,),(1.1, 1.1,),( -7.00699267, 107.33021779, -713.45752385, 2694.43092298, -6353.05100287, 9667.94098394, -9461.9261027, 5721.0086877, -1933.97818198, 279.47763789,)), + HbPolyType.ahdist_aASP_dHIS: ((1.03597611139, 2.78208509117,),(1.1, 1.1,),( -1.34823406, 17.08925926, -78.75087193, 106.32795459, 400.18459698, -2041.04320193, 4033.83557387, -4239.60530204, 2324.00877252, -519.38410941,)), + HbPolyType.ahdist_aASP_dLYS: ((0.97789485082, 2.50496946108,),(1.1, 1.1,),( -0.41300315, 6.59243438, -44.44525308, 163.11796012, -351.2307798, 443.2463146, -297.84582856, 62.38600547, 33.77496227, -14.11652182,)), + HbPolyType.ahdist_aASP_dSER: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)), + HbPolyType.ahdist_aASP_dTRP: ((0.419155746414, 3.0486938610500003,),(1.1, 1.1,),( -0.24563471, 3.85598551, -25.75176874, 95.36525025, -214.13175785, 299.76133553, -259.0691378, 132.06975835, -37.15612683, 5.60445773,)), + HbPolyType.ahdist_aASP_dTYR: ((1.01057521468, 2.7207545786900003,),(1.1, 1.1,),( -0.15808672, -10.21398871, 178.80080949, -1238.0583801, 4736.25248274, -11071.96777725, 16239.07550047, -14593.21092621, 7335.66765017, -1575.08145078,)), + HbPolyType.ahdist_aGLY_dARG: ((0.499016667857, 2.9377031027599996,),(1.1, 1.1,),( -0.15923533, 2.5526639, -17.38788803, 65.71046957, -151.13491186, 218.78048387, -199.15882919, 110.56568974, -35.95143745, 6.47580213,)), + HbPolyType.ahdist_aGLY_dASN: ((0.7194388032060001, 2.9303772333599998,),(1.1, 1.1,),( -1.40718342, 23.65929694, -172.97144348, 720.64417348, -1882.85420815, 3194.87197776, -3515.52467458, 2415.75238278, -941.47705161, 159.84784277,)), + HbPolyType.ahdist_aGLY_dGLY: ((1.38403812683, 2.9981039433,),(1.1, 1.1,),( -0.5307601, 6.47949946, -22.39522814, -55.14303544, 708.30945242, -2619.49318162, 5227.8805795, -6043.31211632, 3806.04676175, -1007.66024144,)), + HbPolyType.ahdist_aGLY_dHIS: ((0.47406840932899996, 2.9234200830400003,),(1.1, 1.1,),( -0.12881679, 1.933838, -12.03134888, 39.92691227, -75.41519959, 78.87968016, -37.82769801, -0.13178679, 4.50193019, 0.45408359,)), + HbPolyType.ahdist_aGLY_dLYS: ((0.545347533475, 2.42624380351,),(1.1, 1.1,),( -0.22921901, 2.07015714, -6.2947417, 0.66645697, 45.21805416, -130.26668981, 176.32401031, -126.68226346, 43.96744431, -4.40105281,)), + HbPolyType.ahdist_aGLY_dSER: ((1.2803349239700001, 2.2465996077400003,),(1.1, 1.1,),( 6.72508613, -86.98495585, 454.18518444, -1119.89141452, 715.624663, 3172.36852982, -9455.49113097, 11797.38766934, -7363.28302948, 1885.50119665,)), + HbPolyType.ahdist_aGLY_dTRP: ((0.686512740494, 3.02901351815,),(1.1, 1.1,),( -0.1051487, 1.41597708, -7.42149173, 17.31830704, -6.98293652, -54.76605063, 130.95272289, -132.77575305, 62.75460448, -9.89110842,)), + HbPolyType.ahdist_aGLY_dTYR: ((1.28894687639, 2.26335316892,),(1.1, 1.1,),( 13.84536925, -169.40579865, 893.79467505, -2670.60617561, 5016.46234701, -6293.79378818, 5585.1049063, -3683.50722701, 1709.48661405, -399.5712153,)), + HbPolyType.ahdist_aHIS_dARG: ((0.8967400957230001, 2.96809434226,),(1.1, 1.1,),( 0.43460495, -10.52727665, 103.16979807, -551.42887412, 1793.25378923, -3701.08304991, 4861.05155388, -3922.4285529, 1763.82137881, -335.43441944,)), + HbPolyType.ahdist_aHIS_dASN: ((0.887120931718, 2.59166903153,),(1.1, 1.1,),( -3.50289894, 54.42813924, -368.14395507, 1418.90186454, -3425.60485859, 5360.92334837, -5428.54462336, 3424.68800187, -1221.49631986, 189.27122436,)), + HbPolyType.ahdist_aHIS_dGLY: ((1.01629363411, 2.58523052904,),(1.1, 1.1,),( -1.68095217, 21.31894078, -107.72203494, 251.81021758, -134.07465831, -707.64527046, 1894.6282743, -2156.85951846, 1216.83585872, -275.48078944,)), + HbPolyType.ahdist_aHIS_dHIS: ((0.9773010778919999, 2.72533796329,),(1.1, 1.1,),( -2.33350626, 35.66072412, -233.98966111, 859.13714961, -1925.30958567, 2685.35293578, -2257.48067507, 1021.49796136, -169.36082523, -12.1348055,)), + HbPolyType.ahdist_aHIS_dLYS: ((0.7080936539849999, 2.47191718632,),(1.1, 1.1,),( -1.88479369, 28.38084382, -185.74039957, 690.81875917, -1605.11404391, 2414.83545623, -2355.9723201, 1442.24496229, -506.45880637, 79.47512505,)), + HbPolyType.ahdist_aHIS_dSER: ((0.90846809159, 2.5477956147,),(1.1, 1.1,),( -0.92004641, 15.91841533, -117.83979251, 488.22211296, -1244.13047376, 2017.43704053, -2076.04468019, 1302.42621488, -451.29138643, 67.15812575,)), + HbPolyType.ahdist_aHIS_dTRP: ((0.991999676806, 2.81296584506,),(1.1, 1.1,),( -1.29358587, 19.97152857, -131.89796017, 485.29199356, -1084.0466445, 1497.3352889, -1234.58042682, 535.8048197, -75.58951691, -9.91148332,)), + HbPolyType.ahdist_aHIS_dTYR: ((0.882661836357, 2.5469016429900004,),(1.1, 1.1,),( -6.94700143, 109.07997256, -747.64035726, 2929.83959536, -7220.15788571, 11583.34170519, -12078.443492, 7881.85479715, -2918.19482068, 468.23988622,)), + HbPolyType.ahdist_aSER_dARG: ((1.0204658147399999, 2.8899566041900004,),(1.1, 1.1,),( 0.33887327, -7.54511361, 70.87316645, -371.88263665, 1206.67454443, -2516.82084076, 3379.45432693, -2819.73384601, 1325.33307517, -265.54533008,)), + HbPolyType.ahdist_aSER_dASN: ((1.01393052233, 3.0024434159299997,),(1.1, 1.1,),( 0.37012361, -7.46486204, 64.85775924, -318.6047209, 974.66322243, -1924.37334018, 2451.63840629, -1943.1915675, 867.07870559, -163.83771761,)), + HbPolyType.ahdist_aSER_dGLY: ((1.3856562156299999, 2.74160605537,),(1.1, 1.1,),( -1.32847415, 22.67528654, -172.53450064, 770.79034865, -2233.48829652, 4354.38807288, -5697.35144236, 4803.38686157, -2361.48028857, 518.28202382,)), + HbPolyType.ahdist_aSER_dHIS: ((0.550992321207, 2.68549261999,),(1.1, 1.1,),( -1.98041793, 29.59668639, -190.36751773, 688.43324385, -1534.68894765, 2175.66568976, -1952.07622113, 1066.28943929, -324.23381388, 43.41006168,)), + HbPolyType.ahdist_aSER_dLYS: ((0.8603189393170001, 2.77729502744,),(1.1, 1.1,),( 0.90884741, -17.24690746, 141.78469099, -661.85989315, 1929.7674992, -3636.43392779, 4419.00727923, -3332.43482061, 1410.78913266, -253.53829424,)), + HbPolyType.ahdist_aSER_dSER: ((1.10866545921, 2.61727781204,),(1.1, 1.1,),( -0.38264308, 4.41779675, -10.7016645, -81.91314845, 668.91174735, -2187.50684758, 3983.56103269, -4213.32320546, 2418.41531442, -580.28918569,)), + HbPolyType.ahdist_aSER_dTRP: ((1.4092077245899999, 2.8066121197099996,),(1.1, 1.1,),( 0.73762477, -11.70741276, 73.05154232, -205.00144794, 89.58794368, 1082.94541375, -3343.98293188, 4601.70815729, -3178.53568678, 896.59487831,)), + HbPolyType.ahdist_aSER_dTYR: ((1.10773547919, 2.60403567341,),(1.1, 1.1,),( -1.13249925, 14.66643161, -69.01708791, 93.96846742, 380.56063898, -1984.56675689, 4074.08891127, -4492.76927139, 2613.13168054, -627.71933508,)), + HbPolyType.ahdist_aTYR_dARG: ((1.05581400627, 2.85499888099,),(1.1, 1.1,),( -0.30396592, 5.30288548, -39.75788579, 167.5416547, -435.15958911, 716.52357586, -735.95195083, 439.76284677, -130.00400085, 13.23827556,)), + HbPolyType.ahdist_aTYR_dASN: ((1.0994919065200002, 2.8400869077900004,),(1.1, 1.1,),( 0.33548259, -3.5890451, 8.97769025, 48.1492734, -400.5983616, 1269.89613211, -2238.03101675, 2298.33009115, -1290.42961162, 308.43185147,)), + HbPolyType.ahdist_aTYR_dGLY: ((1.36546155066, 2.7303075916400004,),(1.1, 1.1,),( -1.55312915, 18.62092487, -70.91365499, -41.83066505, 1248.88835245, -4719.81948329, 9186.09528168, -10266.11434548, 6266.21959533, -1622.19652457,)), + HbPolyType.ahdist_aTYR_dHIS: ((0.5955982461899999, 2.6643551317500003,),(1.1, 1.1,),( -0.47442788, 7.16629863, -46.71287553, 171.46128947, -388.17484011, 558.45202337, -506.35587481, 276.46237273, -83.52554392, 12.05709329,)), + HbPolyType.ahdist_aTYR_dLYS: ((0.7978598238760001, 2.7620933782,),(1.1, 1.1,),( -0.20201464, 1.69684984, 0.27677515, -55.05786347, 286.29918332, -725.92372531, 1054.771746, -889.33602341, 401.11342256, -73.02221189,)), + HbPolyType.ahdist_aTYR_dSER: ((0.7083554962559999, 2.7032011990599996,),(1.1, 1.1,),( -0.70764192, 11.67978065, -82.80447482, 329.83401367, -810.58976486, 1269.57613941, -1261.04047117, 761.72890446, -254.37526011, 37.24301861,)), + HbPolyType.ahdist_aTYR_dTRP: ((1.10934023051, 2.8819112108,),(1.1, 1.1,),( -11.58453967, 204.88308091, -1589.77384548, 7100.84791905, -20113.61354433, 37457.83646055, -45850.02969172, 35559.8805122, -15854.78726237, 3098.04931146,)), + HbPolyType.ahdist_aTYR_dTYR: ((1.1105954899400001, 2.60081798685,),(1.1, 1.1,),( -1.63120628, 19.48493187, -81.0332905, 56.80517706, 687.42717782, -2842.77799908, 5385.52231471, -5656.74159307, 3178.83470588, -744.70042777,)), + HbPolyType.AHD_1h: ((1.76555274367, 3.1416,),(1.1, 1.1,),( 0.62725838, -9.98558225, 59.39060071, -120.82930213, -333.26536028, 2603.13082592, -6895.51207142, 9651.25238056, -7127.13394872, 2194.77244026,)), + HbPolyType.AHD_1i: ((1.59914724347, 3.1416,),(1.1, 1.1,),( -0.18888801, 3.48241679, -25.65508662, 89.57085435, -95.91708218, -367.93452341, 1589.6904702, -2662.3582135, 2184.40194483, -723.28383545,)), + HbPolyType.AHD_1j: ((1.1435646388, 3.1416,),(1.1, 1.1,),( 0.47683259, -9.54524724, 83.62557693, -420.55867774, 1337.19354878, -2786.26265686, 3803.178227, -3278.62879901, 1619.04116204, -347.50157909,)), + HbPolyType.AHD_1k: ((1.15651981164, 3.1416,),(1.1, 1.1,),( -0.10757999, 2.0276542, -16.51949978, 75.83866839, -214.18025678, 380.55117567, -415.47847283, 255.66998474, -69.94662165, 3.21313428,)), + HbPolyType.cosBAH_off: ((-1234.0, 1.1,),(1.1, 1.1,),( 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,)), + HbPolyType.cosBAH_6i: ((-0.23538144897100002, 1.1,),(1.1, 1.1,),( -0.822093, -3.75364636, 46.88852157, -129.5440564, 146.69151428, -67.60598792, 2.91683129, 9.26673173, -3.84488178, 0.05706659,)), + HbPolyType.cosBAH_7: ((-0.019373850666900002, 1.1,),(1.1, 1.1,),( 0.0, -27.942923450028, 136.039920253368, -268.06959056747, 275.400462507919, -153.502076215949, 39.741591385461, 0.693861510121, -3.885952320499, 1.024765090788892)), +} \ No newline at end of file diff --git a/rfdiffusion/util.py b/rfdiffusion/util.py new file mode 100644 index 0000000000000000000000000000000000000000..19c30f5f27532f7da592df723111f4bcf99c802f --- /dev/null +++ b/rfdiffusion/util.py @@ -0,0 +1,743 @@ +import scipy.sparse +from rfdiffusion.chemical import * +from rfdiffusion.scoring import * + + +def generate_Cbeta(N, Ca, C): + # recreate Cb given N,Ca,C + b = Ca - N + c = C - Ca + a = torch.cross(b, c, dim=-1) + # These are the values used during training + Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca + # fd: below matches sidechain generator (=Rosetta params) + # Cb = -0.57910144 * a + 0.5689693 * b - 0.5441217 * c + Ca + + return Cb + + +def th_ang_v(ab, bc, eps: float = 1e-8): + def th_norm(x, eps: float = 1e-8): + return x.square().sum(-1, keepdim=True).add(eps).sqrt() + + def th_N(x, alpha: float = 0): + return x / th_norm(x).add(alpha) + + ab, bc = th_N(ab), th_N(bc) + cos_angle = torch.clamp((ab * bc).sum(-1), -1, 1) + sin_angle = torch.sqrt(1 - cos_angle.square() + eps) + dih = torch.stack((cos_angle, sin_angle), -1) + return dih + + +def th_dih_v(ab, bc, cd): + def th_cross(a, b): + a, b = torch.broadcast_tensors(a, b) + return torch.cross(a, b, dim=-1) + + def th_norm(x, eps: float = 1e-8): + return x.square().sum(-1, keepdim=True).add(eps).sqrt() + + def th_N(x, alpha: float = 0): + return x / th_norm(x).add(alpha) + + ab, bc, cd = th_N(ab), th_N(bc), th_N(cd) + n1 = th_N(th_cross(ab, bc)) + n2 = th_N(th_cross(bc, cd)) + sin_angle = (th_cross(n1, bc) * n2).sum(-1) + cos_angle = (n1 * n2).sum(-1) + dih = torch.stack((cos_angle, sin_angle), -1) + return dih + + +def th_dih(a, b, c, d): + return th_dih_v(a - b, b - c, c - d) + + +# More complicated version splits error in CA-N and CA-C (giving more accurate CB position) +# It returns the rigid transformation from local frame to global frame +def rigid_from_3_points(N, Ca, C, non_ideal=False, eps=1e-8): + # N, Ca, C - [B,L, 3] + # R - [B,L, 3, 3], det(R)=1, inv(R) = R.T, R is a rotation matrix + B, L = N.shape[:2] + + v1 = C - Ca + v2 = N - Ca + e1 = v1 / (torch.norm(v1, dim=-1, keepdim=True) + eps) + u2 = v2 - (torch.einsum("bli, bli -> bl", e1, v2)[..., None] * e1) + e2 = u2 / (torch.norm(u2, dim=-1, keepdim=True) + eps) + e3 = torch.cross(e1, e2, dim=-1) + R = torch.cat( + [e1[..., None], e2[..., None], e3[..., None]], axis=-1 + ) # [B,L,3,3] - rotation matrix + + if non_ideal: + v2 = v2 / (torch.norm(v2, dim=-1, keepdim=True) + eps) + cosref = torch.sum(e1 * v2, dim=-1) # cosine of current N-CA-C bond angle + costgt = cos_ideal_NCAC.item() + cos2del = torch.clamp( + cosref * costgt + + torch.sqrt((1 - cosref * cosref) * (1 - costgt * costgt) + eps), + min=-1.0, + max=1.0, + ) + cosdel = torch.sqrt(0.5 * (1 + cos2del) + eps) + sindel = torch.sign(costgt - cosref) * torch.sqrt(1 - 0.5 * (1 + cos2del) + eps) + Rp = torch.eye(3, device=N.device).repeat(B, L, 1, 1) + Rp[:, :, 0, 0] = cosdel + Rp[:, :, 0, 1] = -sindel + Rp[:, :, 1, 0] = sindel + Rp[:, :, 1, 1] = cosdel + + R = torch.einsum("blij,bljk->blik", R, Rp) + + return R, Ca + + +def get_tor_mask(seq, torsion_indices, mask_in=None): + B, L = seq.shape[:2] + tors_mask = torch.ones((B, L, 10), dtype=torch.bool, device=seq.device) + tors_mask[..., 3:7] = torsion_indices[seq, :, -1] > 0 + tors_mask[:, 0, 1] = False + tors_mask[:, -1, 0] = False + + # mask for additional angles + tors_mask[:, :, 7] = seq != aa2num["GLY"] + tors_mask[:, :, 8] = seq != aa2num["GLY"] + tors_mask[:, :, 9] = torch.logical_and(seq != aa2num["GLY"], seq != aa2num["ALA"]) + tors_mask[:, :, 9] = torch.logical_and(tors_mask[:, :, 9], seq != aa2num["UNK"]) + tors_mask[:, :, 9] = torch.logical_and(tors_mask[:, :, 9], seq != aa2num["MAS"]) + + if mask_in != None: + # mask for missing atoms + # chis + ti0 = torch.gather(mask_in, 2, torsion_indices[seq, :, 0]) + ti1 = torch.gather(mask_in, 2, torsion_indices[seq, :, 1]) + ti2 = torch.gather(mask_in, 2, torsion_indices[seq, :, 2]) + ti3 = torch.gather(mask_in, 2, torsion_indices[seq, :, 3]) + is_valid = torch.stack((ti0, ti1, ti2, ti3), dim=-2).all(dim=-1) + tors_mask[..., 3:7] = torch.logical_and(tors_mask[..., 3:7], is_valid) + tors_mask[:, :, 7] = torch.logical_and( + tors_mask[:, :, 7], mask_in[:, :, 4] + ) # CB exist? + tors_mask[:, :, 8] = torch.logical_and( + tors_mask[:, :, 8], mask_in[:, :, 4] + ) # CB exist? + tors_mask[:, :, 9] = torch.logical_and( + tors_mask[:, :, 9], mask_in[:, :, 5] + ) # XG exist? + + return tors_mask + + +def get_torsions( + xyz_in, seq, torsion_indices, torsion_can_flip, ref_angles, mask_in=None +): + B, L = xyz_in.shape[:2] + + tors_mask = get_tor_mask(seq, torsion_indices, mask_in) + + # torsions to restrain to 0 or 180degree + tors_planar = torch.zeros((B, L, 10), dtype=torch.bool, device=xyz_in.device) + tors_planar[:, :, 5] = seq == aa2num["TYR"] # TYR chi 3 should be planar + + # idealize given xyz coordinates before computing torsion angles + xyz = xyz_in.clone() + Rs, Ts = rigid_from_3_points(xyz[..., 0, :], xyz[..., 1, :], xyz[..., 2, :]) + Nideal = torch.tensor([-0.5272, 1.3593, 0.000], device=xyz_in.device) + Cideal = torch.tensor([1.5233, 0.000, 0.000], device=xyz_in.device) + xyz[..., 0, :] = torch.einsum("brij,j->bri", Rs, Nideal) + Ts + xyz[..., 2, :] = torch.einsum("brij,j->bri", Rs, Cideal) + Ts + + torsions = torch.zeros((B, L, 10, 2), device=xyz.device) + # avoid undefined angles for H generation + torsions[:, 0, 1, 0] = 1.0 + torsions[:, -1, 0, 0] = 1.0 + + # omega + torsions[:, :-1, 0, :] = th_dih( + xyz[:, :-1, 1, :], xyz[:, :-1, 2, :], xyz[:, 1:, 0, :], xyz[:, 1:, 1, :] + ) + # phi + torsions[:, 1:, 1, :] = th_dih( + xyz[:, :-1, 2, :], xyz[:, 1:, 0, :], xyz[:, 1:, 1, :], xyz[:, 1:, 2, :] + ) + # psi + torsions[:, :, 2, :] = -1 * th_dih( + xyz[:, :, 0, :], xyz[:, :, 1, :], xyz[:, :, 2, :], xyz[:, :, 3, :] + ) + + # chis + ti0 = torch.gather(xyz, 2, torsion_indices[seq, :, 0, None].repeat(1, 1, 1, 3)) + ti1 = torch.gather(xyz, 2, torsion_indices[seq, :, 1, None].repeat(1, 1, 1, 3)) + ti2 = torch.gather(xyz, 2, torsion_indices[seq, :, 2, None].repeat(1, 1, 1, 3)) + ti3 = torch.gather(xyz, 2, torsion_indices[seq, :, 3, None].repeat(1, 1, 1, 3)) + torsions[:, :, 3:7, :] = th_dih(ti0, ti1, ti2, ti3) + + # CB bend + NC = 0.5 * (xyz[:, :, 0, :3] + xyz[:, :, 2, :3]) + CA = xyz[:, :, 1, :3] + CB = xyz[:, :, 4, :3] + t = th_ang_v(CB - CA, NC - CA) + t0 = ref_angles[seq][..., 0, :] + torsions[:, :, 7, :] = torch.stack( + (torch.sum(t * t0, dim=-1), t[..., 0] * t0[..., 1] - t[..., 1] * t0[..., 0]), + dim=-1, + ) + + # CB twist + NCCA = NC - CA + NCp = xyz[:, :, 2, :3] - xyz[:, :, 0, :3] + NCpp = ( + NCp + - torch.sum(NCp * NCCA, dim=-1, keepdim=True) + / torch.sum(NCCA * NCCA, dim=-1, keepdim=True) + * NCCA + ) + t = th_ang_v(CB - CA, NCpp) + t0 = ref_angles[seq][..., 1, :] + torsions[:, :, 8, :] = torch.stack( + (torch.sum(t * t0, dim=-1), t[..., 0] * t0[..., 1] - t[..., 1] * t0[..., 0]), + dim=-1, + ) + + # CG bend + CG = xyz[:, :, 5, :3] + t = th_ang_v(CG - CB, CA - CB) + t0 = ref_angles[seq][..., 2, :] + torsions[:, :, 9, :] = torch.stack( + (torch.sum(t * t0, dim=-1), t[..., 0] * t0[..., 1] - t[..., 1] * t0[..., 0]), + dim=-1, + ) + + mask0 = torch.isnan(torsions[..., 0]).nonzero() + mask1 = torch.isnan(torsions[..., 1]).nonzero() + torsions[mask0[:, 0], mask0[:, 1], mask0[:, 2], 0] = 1.0 + torsions[mask1[:, 0], mask1[:, 1], mask1[:, 2], 1] = 0.0 + + # alt chis + torsions_alt = torsions.clone() + torsions_alt[torsion_can_flip[seq, :]] *= -1 + + return torsions, torsions_alt, tors_mask, tors_planar + + +def get_tips(xyz, seq): + B, L = xyz.shape[:2] + + xyz_tips = torch.gather( + xyz, 2, tip_indices.to(xyz.device)[seq][:, :, None, None].expand(-1, -1, -1, 3) + ).reshape(B, L, 3) + mask = ~(torch.isnan(xyz_tips[:, :, 0])) + if torch.isnan(xyz_tips).any(): # replace NaN tip atom with virtual Cb atom + # three anchor atoms + N = xyz[:, :, 0] + Ca = xyz[:, :, 1] + C = xyz[:, :, 2] + + # recreate Cb given N,Ca,C + b = Ca - N + c = C - Ca + a = torch.cross(b, c, dim=-1) + Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca + + xyz_tips = torch.where(torch.isnan(xyz_tips), Cb, xyz_tips) + return xyz_tips, mask + + +# process ideal frames +def make_frame(X, Y): + Xn = X / torch.linalg.norm(X) + Y = Y - torch.dot(Y, Xn) * Xn + Yn = Y / torch.linalg.norm(Y) + Z = torch.cross(Xn, Yn) + Zn = Z / torch.linalg.norm(Z) + + return torch.stack((Xn, Yn, Zn), dim=-1) + + +def cross_product_matrix(u): + B, L = u.shape[:2] + matrix = torch.zeros((B, L, 3, 3), device=u.device) + matrix[:, :, 0, 1] = -u[..., 2] + matrix[:, :, 0, 2] = u[..., 1] + matrix[:, :, 1, 0] = u[..., 2] + matrix[:, :, 1, 2] = -u[..., 0] + matrix[:, :, 2, 0] = -u[..., 1] + matrix[:, :, 2, 1] = u[..., 0] + return matrix + + +# writepdb +def writepdb( + filename, atoms, seq, binderlen=None, idx_pdb=None, bfacts=None, chain_idx=None +): + f = open(filename, "w") + ctr = 1 + scpu = seq.cpu().squeeze() + atomscpu = atoms.cpu().squeeze() + if bfacts is None: + bfacts = torch.zeros(atomscpu.shape[0]) + if idx_pdb is None: + idx_pdb = 1 + torch.arange(atomscpu.shape[0]) + + Bfacts = torch.clamp(bfacts.cpu(), 0, 1) + for i, s in enumerate(scpu): + if chain_idx is None: + if binderlen is not None: + if i < binderlen: + chain = "A" + else: + chain = "B" + elif binderlen is None: + chain = "A" + else: + chain = chain_idx[i] + if len(atomscpu.shape) == 2: + f.write( + "%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n" + % ( + "ATOM", + ctr, + " CA ", + num2aa[s], + chain, + idx_pdb[i], + atomscpu[i, 0], + atomscpu[i, 1], + atomscpu[i, 2], + 1.0, + Bfacts[i], + ) + ) + ctr += 1 + elif atomscpu.shape[1] == 3: + for j, atm_j in enumerate([" N ", " CA ", " C "]): + f.write( + "%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n" + % ( + "ATOM", + ctr, + atm_j, + num2aa[s], + chain, + idx_pdb[i], + atomscpu[i, j, 0], + atomscpu[i, j, 1], + atomscpu[i, j, 2], + 1.0, + Bfacts[i], + ) + ) + ctr += 1 + elif atomscpu.shape[1] == 4: + for j, atm_j in enumerate([" N ", " CA ", " C ", " O "]): + f.write( + "%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n" + % ( + "ATOM", + ctr, + atm_j, + num2aa[s], + chain, + idx_pdb[i], + atomscpu[i, j, 0], + atomscpu[i, j, 1], + atomscpu[i, j, 2], + 1.0, + Bfacts[i], + ) + ) + ctr += 1 + + else: + natoms = atomscpu.shape[1] + if natoms != 14 and natoms != 27: + print("bad size!", atoms.shape) + assert False + atms = aa2long[s] + # his prot hack + if ( + s == 8 + and torch.linalg.norm(atomscpu[i, 9, :] - atomscpu[i, 5, :]) < 1.7 + ): + atms = ( + " N ", + " CA ", + " C ", + " O ", + " CB ", + " CG ", + " NE2", + " CD2", + " CE1", + " ND1", + None, + None, + None, + None, + " H ", + " HA ", + "1HB ", + "2HB ", + " HD2", + " HE1", + " HD1", + None, + None, + None, + None, + None, + None, + ) # his_d + + for j, atm_j in enumerate(atms): + if ( + j < natoms and atm_j is not None + ): # and not torch.isnan(atomscpu[i,j,:]).any()): + f.write( + "%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n" + % ( + "ATOM", + ctr, + atm_j, + num2aa[s], + chain, + idx_pdb[i], + atomscpu[i, j, 0], + atomscpu[i, j, 1], + atomscpu[i, j, 2], + 1.0, + Bfacts[i], + ) + ) + ctr += 1 + + +# resolve tip atom indices +tip_indices = torch.full((22,), 0) +for i in range(22): + tip_atm = aa2tip[i] + atm_long = aa2long[i] + tip_indices[i] = atm_long.index(tip_atm) + +# resolve torsion indices +torsion_indices = torch.full((22, 4, 4), 0) +torsion_can_flip = torch.full((22, 10), False, dtype=torch.bool) +for i in range(22): + i_l, i_a = aa2long[i], aa2longalt[i] + for j in range(4): + if torsions[i][j] is None: + continue + for k in range(4): + a = torsions[i][j][k] + torsion_indices[i, j, k] = i_l.index(a) + if i_l.index(a) != i_a.index(a): + torsion_can_flip[i, 3 + j] = True ##bb tors never flip +# HIS is a special case +torsion_can_flip[8, 4] = False + +# build the mapping from atoms in the full rep (Nx27) to the "alternate" rep +allatom_mask = torch.zeros((22, 27), dtype=torch.bool) +long2alt = torch.zeros((22, 27), dtype=torch.long) +for i in range(22): + i_l, i_lalt = aa2long[i], aa2longalt[i] + for j, a in enumerate(i_l): + if a is None: + long2alt[i, j] = j + else: + long2alt[i, j] = i_lalt.index(a) + allatom_mask[i, j] = True + +# bond graph traversal +num_bonds = torch.zeros((22, 27, 27), dtype=torch.long) +for i in range(22): + num_bonds_i = np.zeros((27, 27)) + for bnamei, bnamej in aabonds[i]: + bi, bj = aa2long[i].index(bnamei), aa2long[i].index(bnamej) + num_bonds_i[bi, bj] = 1 + num_bonds_i = scipy.sparse.csgraph.shortest_path(num_bonds_i, directed=False) + num_bonds_i[num_bonds_i >= 4] = 4 + num_bonds[i, ...] = torch.tensor(num_bonds_i) + + +# LJ/LK scoring parameters +ljlk_parameters = torch.zeros((22, 27, 5), dtype=torch.float) +lj_correction_parameters = torch.zeros( + (22, 27, 4), dtype=bool +) # donor/acceptor/hpol/disulf +for i in range(22): + for j, a in enumerate(aa2type[i]): + if a is not None: + ljlk_parameters[i, j, :] = torch.tensor(type2ljlk[a]) + lj_correction_parameters[i, j, 0] = (type2hb[a] == HbAtom.DO) + ( + type2hb[a] == HbAtom.DA + ) + lj_correction_parameters[i, j, 1] = (type2hb[a] == HbAtom.AC) + ( + type2hb[a] == HbAtom.DA + ) + lj_correction_parameters[i, j, 2] = type2hb[a] == HbAtom.HP + lj_correction_parameters[i, j, 3] = a == "SH1" or a == "HS" + + +# hbond scoring parameters +def donorHs(D, bonds, atoms): + dHs = [] + for i, j in bonds: + if i == D: + idx_j = atoms.index(j) + if idx_j >= 14: # if atom j is a hydrogen + dHs.append(idx_j) + if j == D: + idx_i = atoms.index(i) + if idx_i >= 14: # if atom j is a hydrogen + dHs.append(idx_i) + assert len(dHs) > 0 + return dHs + + +def acceptorBB0(A, hyb, bonds, atoms): + if hyb == HbHybType.SP2: + for i, j in bonds: + if i == A: + B = atoms.index(j) + if B < 14: + break + if j == A: + B = atoms.index(i) + if B < 14: + break + for i, j in bonds: + if i == atoms[B]: + B0 = atoms.index(j) + if B0 < 14: + break + if j == atoms[B]: + B0 = atoms.index(i) + if B0 < 14: + break + elif hyb == HbHybType.SP3 or hyb == HbHybType.RING: + for i, j in bonds: + if i == A: + B = atoms.index(j) + if B < 14: + break + if j == A: + B = atoms.index(i) + if B < 14: + break + for i, j in bonds: + if i == A and j != atoms[B]: + B0 = atoms.index(j) + break + if j == A and i != atoms[B]: + B0 = atoms.index(i) + break + + return B, B0 + + +hbtypes = torch.full( + (22, 27, 3), -1, dtype=torch.long +) # (donortype, acceptortype, acchybtype) +hbbaseatoms = torch.full( + (22, 27, 2), -1, dtype=torch.long +) # (B,B0) for acc; (D,-1) for don +hbpolys = torch.zeros( + (HbDonType.NTYPES, HbAccType.NTYPES, 3, 15) +) # weight,xmin,xmax,ymin,ymax,c9,...,c0 + +for i in range(22): + for j, a in enumerate(aa2type[i]): + if a in type2dontype: + j_hs = donorHs(aa2long[i][j], aabonds[i], aa2long[i]) + for j_h in j_hs: + hbtypes[i, j_h, 0] = type2dontype[a] + hbbaseatoms[i, j_h, 0] = j + if a in type2acctype: + j_b, j_b0 = acceptorBB0( + aa2long[i][j], type2hybtype[a], aabonds[i], aa2long[i] + ) + hbtypes[i, j, 1] = type2acctype[a] + hbtypes[i, j, 2] = type2hybtype[a] + hbbaseatoms[i, j, 0] = j_b + hbbaseatoms[i, j, 1] = j_b0 + +for i in range(HbDonType.NTYPES): + for j in range(HbAccType.NTYPES): + weight = dontype2wt[i] * acctype2wt[j] + + pdist, pbah, pahd = hbtypepair2poly[(i, j)] + xrange, yrange, coeffs = hbpolytype2coeffs[pdist] + hbpolys[i, j, 0, 0] = weight + hbpolys[i, j, 0, 1:3] = torch.tensor(xrange) + hbpolys[i, j, 0, 3:5] = torch.tensor(yrange) + hbpolys[i, j, 0, 5:] = torch.tensor(coeffs) + xrange, yrange, coeffs = hbpolytype2coeffs[pahd] + hbpolys[i, j, 1, 0] = weight + hbpolys[i, j, 1, 1:3] = torch.tensor(xrange) + hbpolys[i, j, 1, 3:5] = torch.tensor(yrange) + hbpolys[i, j, 1, 5:] = torch.tensor(coeffs) + xrange, yrange, coeffs = hbpolytype2coeffs[pbah] + hbpolys[i, j, 2, 0] = weight + hbpolys[i, j, 2, 1:3] = torch.tensor(xrange) + hbpolys[i, j, 2, 3:5] = torch.tensor(yrange) + hbpolys[i, j, 2, 5:] = torch.tensor(coeffs) + +# kinematic parameters +base_indices = torch.full((22, 27), 0, dtype=torch.long) +xyzs_in_base_frame = torch.ones((22, 27, 4)) +RTs_by_torsion = torch.eye(4).repeat(22, 7, 1, 1) +reference_angles = torch.ones((22, 3, 2)) + +for i in range(22): + i_l = aa2long[i] + for name, base, coords in ideal_coords[i]: + idx = i_l.index(name) + base_indices[i, idx] = base + xyzs_in_base_frame[i, idx, :3] = torch.tensor(coords) + + # omega frame + RTs_by_torsion[i, 0, :3, :3] = torch.eye(3) + RTs_by_torsion[i, 0, :3, 3] = torch.zeros(3) + + # phi frame + RTs_by_torsion[i, 1, :3, :3] = make_frame( + xyzs_in_base_frame[i, 0, :3] - xyzs_in_base_frame[i, 1, :3], + torch.tensor([1.0, 0.0, 0.0]), + ) + RTs_by_torsion[i, 1, :3, 3] = xyzs_in_base_frame[i, 0, :3] + + # psi frame + RTs_by_torsion[i, 2, :3, :3] = make_frame( + xyzs_in_base_frame[i, 2, :3] - xyzs_in_base_frame[i, 1, :3], + xyzs_in_base_frame[i, 1, :3] - xyzs_in_base_frame[i, 0, :3], + ) + RTs_by_torsion[i, 2, :3, 3] = xyzs_in_base_frame[i, 2, :3] + + # chi1 frame + if torsions[i][0] is not None: + a0, a1, a2 = torsion_indices[i, 0, 0:3] + RTs_by_torsion[i, 3, :3, :3] = make_frame( + xyzs_in_base_frame[i, a2, :3] - xyzs_in_base_frame[i, a1, :3], + xyzs_in_base_frame[i, a0, :3] - xyzs_in_base_frame[i, a1, :3], + ) + RTs_by_torsion[i, 3, :3, 3] = xyzs_in_base_frame[i, a2, :3] + + # chi2~4 frame + for j in range(1, 4): + if torsions[i][j] is not None: + a2 = torsion_indices[i, j, 2] + if (i == 18 and j == 2) or ( + i == 8 and j == 2 + ): # TYR CZ-OH & HIS CE1-HE1 a special case + a0, a1 = torsion_indices[i, j, 0:2] + RTs_by_torsion[i, 3 + j, :3, :3] = make_frame( + xyzs_in_base_frame[i, a2, :3] - xyzs_in_base_frame[i, a1, :3], + xyzs_in_base_frame[i, a0, :3] - xyzs_in_base_frame[i, a1, :3], + ) + else: + RTs_by_torsion[i, 3 + j, :3, :3] = make_frame( + xyzs_in_base_frame[i, a2, :3], + torch.tensor([-1.0, 0.0, 0.0]), + ) + RTs_by_torsion[i, 3 + j, :3, 3] = xyzs_in_base_frame[i, a2, :3] + + # CB/CG angles + NCr = 0.5 * (xyzs_in_base_frame[i, 0, :3] + xyzs_in_base_frame[i, 2, :3]) + CAr = xyzs_in_base_frame[i, 1, :3] + CBr = xyzs_in_base_frame[i, 4, :3] + CGr = xyzs_in_base_frame[i, 5, :3] + reference_angles[i, 0, :] = th_ang_v(CBr - CAr, NCr - CAr) + NCp = xyzs_in_base_frame[i, 2, :3] - xyzs_in_base_frame[i, 0, :3] + NCpp = NCp - torch.dot(NCp, NCr) / torch.dot(NCr, NCr) * NCr + reference_angles[i, 1, :] = th_ang_v(CBr - CAr, NCpp) + reference_angles[i, 2, :] = th_ang_v(CGr, torch.tensor([-1.0, 0.0, 0.0])) + +N_BACKBONE_ATOMS = 3 +N_HEAVY = 14 + + +def writepdb_multi( + filename, + atoms_stack, + bfacts, + seq_stack, + backbone_only=False, + chain_ids=None, + use_hydrogens=True, +): + """ + Function for writing multiple structural states of the same sequence into a single + pdb file. + """ + + f = open(filename, "w") + + if seq_stack.ndim != 2: + T = atoms_stack.shape[0] + seq_stack = torch.tile(seq_stack, (T, 1)) + seq_stack = seq_stack.cpu() + for atoms, scpu in zip(atoms_stack, seq_stack): + ctr = 1 + atomscpu = atoms.cpu() + Bfacts = torch.clamp(bfacts.cpu(), 0, 1) + for i, s in enumerate(scpu): + atms = aa2long[s] + for j, atm_j in enumerate(atms): + if backbone_only and j >= N_BACKBONE_ATOMS: + break + if not use_hydrogens and j >= N_HEAVY: + break + if (atm_j is None) or (torch.all(torch.isnan(atomscpu[i, j]))): + continue + chain_id = "A" + if chain_ids is not None: + chain_id = chain_ids[i] + f.write( + "%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n" + % ( + "ATOM", + ctr, + atm_j, + num2aa[s], + chain_id, + i + 1, + atomscpu[i, j, 0], + atomscpu[i, j, 1], + atomscpu[i, j, 2], + 1.0, + Bfacts[i], + ) + ) + ctr += 1 + + f.write("ENDMDL\n") + +def calc_rmsd(xyz1, xyz2, eps=1e-6): + """ + Calculates RMSD between two sets of atoms (L, 3) + """ + # center to CA centroid + xyz1 = xyz1 - xyz1.mean(0) + xyz2 = xyz2 - xyz2.mean(0) + + # Computation of the covariance matrix + C = xyz2.T @ xyz1 + + # Compute otimal rotation matrix using SVD + V, S, W = np.linalg.svd(C) + + # get sign to ensure right-handedness + d = np.ones([3,3]) + d[:,-1] = np.sign(np.linalg.det(V)*np.linalg.det(W)) + + # Rotation matrix U + U = (d*V) @ W + + # Rotate xyz2 + xyz2_ = xyz2 @ U + L = xyz2_.shape[0] + rmsd = np.sqrt(np.sum((xyz2_-xyz1)*(xyz2_-xyz1), axis=(0,1)) / L + eps) + + return rmsd, U diff --git a/rfdiffusion/util_module.py b/rfdiffusion/util_module.py new file mode 100644 index 0000000000000000000000000000000000000000..20ba2dc447a358b707507ee398881ace4e6f4ca6 --- /dev/null +++ b/rfdiffusion/util_module.py @@ -0,0 +1,310 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from opt_einsum import contract as einsum +import copy +import dgl +from rfdiffusion.util import base_indices, RTs_by_torsion, xyzs_in_base_frame, rigid_from_3_points + +def init_lecun_normal(module): + def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2): + normal = torch.distributions.normal.Normal(0, 1) + + alpha = (a - mu) / sigma + beta = (b - mu) / sigma + + alpha_normal_cdf = normal.cdf(torch.tensor(alpha)) + p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform + + v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8) + x = mu + sigma * np.sqrt(2) * torch.erfinv(v) + x = torch.clamp(x, a, b) + + return x + + def sample_truncated_normal(shape): + stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in + return stddev * truncated_normal(torch.rand(shape)) + + module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) ) + return module + +def init_lecun_normal_param(weight): + def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2): + normal = torch.distributions.normal.Normal(0, 1) + + alpha = (a - mu) / sigma + beta = (b - mu) / sigma + + alpha_normal_cdf = normal.cdf(torch.tensor(alpha)) + p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform + + v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8) + x = mu + sigma * np.sqrt(2) * torch.erfinv(v) + x = torch.clamp(x, a, b) + + return x + + def sample_truncated_normal(shape): + stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in + return stddev * truncated_normal(torch.rand(shape)) + + weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) ) + return weight + +# for gradient checkpointing +def create_custom_forward(module, **kwargs): + def custom_forward(*inputs): + return module(*inputs, **kwargs) + return custom_forward + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + +class Dropout(nn.Module): + # Dropout entire row or column + def __init__(self, broadcast_dim=None, p_drop=0.15): + super(Dropout, self).__init__() + # give ones with probability of 1-p_drop / zeros with p_drop + self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1-p_drop])) + self.broadcast_dim=broadcast_dim + self.p_drop=p_drop + def forward(self, x): + if not self.training: # no drophead during evaluation mode + return x + shape = list(x.shape) + if not self.broadcast_dim == None: + shape[self.broadcast_dim] = 1 + mask = self.sampler.sample(shape).to(x.device).view(shape) + + x = mask * x / (1.0 - self.p_drop) + return x + +def rbf(D): + # Distance radial basis function + D_min, D_max, D_count = 0., 20., 36 + D_mu = torch.linspace(D_min, D_max, D_count).to(D.device) + D_mu = D_mu[None,:] + D_sigma = (D_max - D_min) / D_count + D_expand = torch.unsqueeze(D, -1) + RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2) + return RBF + +def get_seqsep(idx): + ''' + Input: + - idx: residue indices of given sequence (B,L) + Output: + - seqsep: sequence separation feature with sign (B, L, L, 1) + Sergey found that having sign in seqsep features helps a little + ''' + seqsep = idx[:,None,:] - idx[:,:,None] + sign = torch.sign(seqsep) + neigh = torch.abs(seqsep) + neigh[neigh > 1] = 0.0 # if bonded -- 1.0 / else 0.0 + neigh = sign * neigh + return neigh.unsqueeze(-1) + +def make_full_graph(xyz, pair, idx, top_k=64, kmin=9): + ''' + Input: + - xyz: current backbone cooordinates (B, L, 3, 3) + - pair: pair features from Trunk (B, L, L, E) + - idx: residue index from ground truth pdb + Output: + - G: defined graph + ''' + + B, L = xyz.shape[:2] + device = xyz.device + + # seq sep + sep = idx[:,None,:] - idx[:,:,None] + b,i,j = torch.where(sep.abs() > 0) + + src = b*L+i + tgt = b*L+j + G = dgl.graph((src, tgt), num_nodes=B*L).to(device) + G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function + + return G, pair[b,i,j][...,None] + +def make_topk_graph(xyz, pair, idx, top_k=64, kmin=32, eps=1e-6): + ''' + Input: + - xyz: current backbone cooordinates (B, L, 3, 3) + - pair: pair features from Trunk (B, L, L, E) + - idx: residue index from ground truth pdb + Output: + - G: defined graph + ''' + + B, L = xyz.shape[:2] + device = xyz.device + + # distance map from current CA coordinates + D = torch.cdist(xyz, xyz) + torch.eye(L, device=device).unsqueeze(0)*999.9 # (B, L, L) + # seq sep + sep = idx[:,None,:] - idx[:,:,None] + sep = sep.abs() + torch.eye(L, device=device).unsqueeze(0)*999.9 + D = D + sep*eps + + # get top_k neighbors + D_neigh, E_idx = torch.topk(D, min(top_k, L), largest=False) # shape of E_idx: (B, L, top_k) + topk_matrix = torch.zeros((B, L, L), device=device) + topk_matrix.scatter_(2, E_idx, 1.0) + + # put an edge if any of the 3 conditions are met: + # 1) |i-j| <= kmin (connect sequentially adjacent residues) + # 2) top_k neighbors + cond = torch.logical_or(topk_matrix > 0.0, sep < kmin) + b,i,j = torch.where(cond) + + src = b*L+i + tgt = b*L+j + G = dgl.graph((src, tgt), num_nodes=B*L).to(device) + G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function + + return G, pair[b,i,j][...,None] + +def make_rotX(angs, eps=1e-6): + B,L = angs.shape[:2] + NORM = torch.linalg.norm(angs, dim=-1) + eps + + RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1) + + RTs[:,:,1,1] = angs[:,:,0]/NORM + RTs[:,:,1,2] = -angs[:,:,1]/NORM + RTs[:,:,2,1] = angs[:,:,1]/NORM + RTs[:,:,2,2] = angs[:,:,0]/NORM + return RTs + +# rotate about the z axis +def make_rotZ(angs, eps=1e-6): + B,L = angs.shape[:2] + NORM = torch.linalg.norm(angs, dim=-1) + eps + + RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1) + + RTs[:,:,0,0] = angs[:,:,0]/NORM + RTs[:,:,0,1] = -angs[:,:,1]/NORM + RTs[:,:,1,0] = angs[:,:,1]/NORM + RTs[:,:,1,1] = angs[:,:,0]/NORM + return RTs + +# rotate about an arbitrary axis +def make_rot_axis(angs, u, eps=1e-6): + B,L = angs.shape[:2] + NORM = torch.linalg.norm(angs, dim=-1) + eps + + RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1) + + ct = angs[:,:,0]/NORM + st = angs[:,:,1]/NORM + u0 = u[:,:,0] + u1 = u[:,:,1] + u2 = u[:,:,2] + + RTs[:,:,0,0] = ct+u0*u0*(1-ct) + RTs[:,:,0,1] = u0*u1*(1-ct)-u2*st + RTs[:,:,0,2] = u0*u2*(1-ct)+u1*st + RTs[:,:,1,0] = u0*u1*(1-ct)+u2*st + RTs[:,:,1,1] = ct+u1*u1*(1-ct) + RTs[:,:,1,2] = u1*u2*(1-ct)-u0*st + RTs[:,:,2,0] = u0*u2*(1-ct)-u1*st + RTs[:,:,2,1] = u1*u2*(1-ct)+u0*st + RTs[:,:,2,2] = ct+u2*u2*(1-ct) + return RTs + +class ComputeAllAtomCoords(nn.Module): + def __init__(self): + super(ComputeAllAtomCoords, self).__init__() + + self.base_indices = nn.Parameter(base_indices, requires_grad=False) + self.RTs_in_base_frame = nn.Parameter(RTs_by_torsion, requires_grad=False) + self.xyzs_in_base_frame = nn.Parameter(xyzs_in_base_frame, requires_grad=False) + + def forward(self, seq, xyz, alphas, non_ideal=False, use_H=True): + B,L = xyz.shape[:2] + + Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:], non_ideal=non_ideal) + + RTF0 = torch.eye(4).repeat(B,L,1,1).to(device=Rs.device) + + # bb + RTF0[:,:,:3,:3] = Rs + RTF0[:,:,:3,3] = Ts + + # omega + RTF1 = torch.einsum( + 'brij,brjk,brkl->bril', + RTF0, self.RTs_in_base_frame[seq,0,:], make_rotX(alphas[:,:,0,:])) + + # phi + RTF2 = torch.einsum( + 'brij,brjk,brkl->bril', + RTF0, self.RTs_in_base_frame[seq,1,:], make_rotX(alphas[:,:,1,:])) + + # psi + RTF3 = torch.einsum( + 'brij,brjk,brkl->bril', + RTF0, self.RTs_in_base_frame[seq,2,:], make_rotX(alphas[:,:,2,:])) + + # CB bend + basexyzs = self.xyzs_in_base_frame[seq] + NCr = 0.5*(basexyzs[:,:,2,:3]+basexyzs[:,:,0,:3]) + CAr = (basexyzs[:,:,1,:3]) + CBr = (basexyzs[:,:,4,:3]) + CBrotaxis1 = (CBr-CAr).cross(NCr-CAr) + CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True)+1e-8 + + # CB twist + NCp = basexyzs[:,:,2,:3] - basexyzs[:,:,0,:3] + NCpp = NCp - torch.sum(NCp*NCr, dim=-1, keepdim=True)/ torch.sum(NCr*NCr, dim=-1, keepdim=True) * NCr + CBrotaxis2 = (CBr-CAr).cross(NCpp) + CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True)+1e-8 + + CBrot1 = make_rot_axis(alphas[:,:,7,:], CBrotaxis1 ) + CBrot2 = make_rot_axis(alphas[:,:,8,:], CBrotaxis2 ) + + RTF8 = torch.einsum( + 'brij,brjk,brkl->bril', + RTF0, CBrot1,CBrot2) + + # chi1 + CG bend + RTF4 = torch.einsum( + 'brij,brjk,brkl,brlm->brim', + RTF8, + self.RTs_in_base_frame[seq,3,:], + make_rotX(alphas[:,:,3,:]), + make_rotZ(alphas[:,:,9,:])) + + # chi2 + RTF5 = torch.einsum( + 'brij,brjk,brkl->bril', + RTF4, self.RTs_in_base_frame[seq,4,:],make_rotX(alphas[:,:,4,:])) + + # chi3 + RTF6 = torch.einsum( + 'brij,brjk,brkl->bril', + RTF5,self.RTs_in_base_frame[seq,5,:],make_rotX(alphas[:,:,5,:])) + + # chi4 + RTF7 = torch.einsum( + 'brij,brjk,brkl->bril', + RTF6,self.RTs_in_base_frame[seq,6,:],make_rotX(alphas[:,:,6,:])) + + RTframes = torch.stack(( + RTF0,RTF1,RTF2,RTF3,RTF4,RTF5,RTF6,RTF7,RTF8 + ),dim=2) + + xyzs = torch.einsum( + 'brtij,brtj->brti', + RTframes.gather(2,self.base_indices[seq][...,None,None].repeat(1,1,1,4,4)), basexyzs + ) + + if use_H: + return RTframes, xyzs[...,:3] + else: + return RTframes, xyzs[...,:14,:3]