Commit 0793996
GlandVergil committed
1 Parent(s): 8eb0d85
Upload 95 files
This view is limited to 50 files because it contains too many changes.
- rfdiffusion.egg-info/PKG-INFO +7 -0
- rfdiffusion.egg-info/SOURCES.txt +34 -0
- rfdiffusion.egg-info/dependency_links.txt +1 -0
- rfdiffusion.egg-info/requires.txt +2 -0
- rfdiffusion.egg-info/top_level.txt +1 -0
- rfdiffusion/Attention_module.py +404 -0
- rfdiffusion/AuxiliaryPredictor.py +92 -0
- rfdiffusion/Embeddings.py +303 -0
- rfdiffusion/RoseTTAFoldModel.py +140 -0
- rfdiffusion/SE3_network.py +83 -0
- rfdiffusion/Track_module.py +474 -0
- rfdiffusion/__init__.py +0 -0
- rfdiffusion/__pycache__/Attention_module.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/Attention_module.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/Attention_module.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/Embeddings.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/Embeddings.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/Embeddings.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/SE3_network.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/SE3_network.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/SE3_network.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/Track_module.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/Track_module.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/Track_module.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/__init__.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/__init__.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/__init__.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/chemical.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/chemical.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/chemical.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/contigs.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/contigs.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/contigs.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/diffusion.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/diffusion.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/diffusion.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/igso3.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/igso3.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/igso3.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/kinematics.cpython-310.pyc +0 -0
- rfdiffusion/__pycache__/kinematics.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/kinematics.cpython-39.pyc +0 -0
- rfdiffusion/__pycache__/model_input_logger.cpython-311.pyc +0 -0
- rfdiffusion/__pycache__/model_input_logger.cpython-39.pyc +0 -0
rfdiffusion.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,7 @@
Metadata-Version: 2.1
Name: rfdiffusion
Version: 1.1.0
Summary: RFdiffusion is an open source method for protein structure generation.
Home-page: https://github.com/RosettaCommons/RFdiffusion
Author: Rosetta Commons
License-File: LICENSE
rfdiffusion.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,34 @@
LICENSE
README.md
setup.py
rfdiffusion/Attention_module.py
rfdiffusion/AuxiliaryPredictor.py
rfdiffusion/Embeddings.py
rfdiffusion/RoseTTAFoldModel.py
rfdiffusion/SE3_network.py
rfdiffusion/Track_module.py
rfdiffusion/__init__.py
rfdiffusion/chemical.py
rfdiffusion/contigs.py
rfdiffusion/coords6d.py
rfdiffusion/diffusion.py
rfdiffusion/igso3.py
rfdiffusion/kinematics.py
rfdiffusion/model_input_logger.py
rfdiffusion/scoring.py
rfdiffusion/util.py
rfdiffusion/util_module.py
rfdiffusion.egg-info/PKG-INFO
rfdiffusion.egg-info/SOURCES.txt
rfdiffusion.egg-info/dependency_links.txt
rfdiffusion.egg-info/requires.txt
rfdiffusion.egg-info/top_level.txt
rfdiffusion/inference/__init__.py
rfdiffusion/inference/model_runners.py
rfdiffusion/inference/symmetry.py
rfdiffusion/inference/utils.py
rfdiffusion/potentials/__init__.py
rfdiffusion/potentials/manager.py
rfdiffusion/potentials/potentials.py
scripts/run_inference.py
tests/test_diffusion.py
rfdiffusion.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
rfdiffusion.egg-info/requires.txt
ADDED
@@ -0,0 +1,2 @@
torch
se3-transformer
rfdiffusion.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
rfdiffusion
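The egg-info files above are generated by setuptools when the package is built or installed in editable mode. Below is a minimal sketch of the kind of setup.py that would produce this metadata; the repository's actual setup.py (listed in SOURCES.txt) is not shown in this diff, so only the field values are taken from PKG-INFO and requires.txt, and everything else is an assumption.

# Hypothetical sketch only: field values copied from PKG-INFO / requires.txt above,
# the rest (package discovery, excludes) is assumed, not the repo's real setup.py.
from setuptools import setup, find_packages

setup(
    name="rfdiffusion",
    version="1.1.0",
    description="RFdiffusion is an open source method for protein structure generation.",
    url="https://github.com/RosettaCommons/RFdiffusion",
    author="Rosetta Commons",
    packages=find_packages(exclude=["tests"]),
    install_requires=["torch", "se3-transformer"],
)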
rfdiffusion/Attention_module.py
ADDED
@@ -0,0 +1,404 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from opt_einsum import contract as einsum
from rfdiffusion.util_module import init_lecun_normal

class FeedForwardLayer(nn.Module):
    def __init__(self, d_model, r_ff, p_drop=0.1):
        super(FeedForwardLayer, self).__init__()
        self.norm = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, d_model*r_ff)
        self.dropout = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_model*r_ff, d_model)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize linear layer right before ReLU: He initializer (kaiming normal)
        nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear1.bias)

        # initialize linear layer right before residual connection: zero initialize
        nn.init.zeros_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)

    def forward(self, src):
        src = self.norm(src)
        src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
        return src

class Attention(nn.Module):
    # calculate multi-head attention
    def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
        super(Attention, self).__init__()
        self.h = n_head
        self.dim = d_hidden
        #
        self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
        #
        self.to_out = nn.Linear(n_head*d_hidden, d_out)
        self.scaling = 1/math.sqrt(d_hidden)
        #
        # initialize all parameters properly
        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projection: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # to_out: right before residual connection: zero initialize -- to make sure the residual connection is the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, query, key, value):
        B, Q = query.shape[:2]
        B, K = key.shape[:2]
        #
        query = self.to_q(query).reshape(B, Q, self.h, self.dim)
        key = self.to_k(key).reshape(B, K, self.h, self.dim)
        value = self.to_v(value).reshape(B, K, self.h, self.dim)
        #
        query = query * self.scaling
        attn = einsum('bqhd,bkhd->bhqk', query, key)
        attn = F.softmax(attn, dim=-1)
        #
        out = einsum('bhqk,bkhd->bqhd', attn, value)
        out = out.reshape(B, Q, self.h*self.dim)
        #
        out = self.to_out(out)

        return out

class AttentionWithBias(nn.Module):
    def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
        super(AttentionWithBias, self).__init__()
        self.norm_in = nn.LayerNorm(d_in)
        self.norm_bias = nn.LayerNorm(d_bias)
        #
        self.to_q = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_in, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)
        self.to_g = nn.Linear(d_in, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_in)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projection: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # bias: normal distribution
        self.to_b = init_lecun_normal(self.to_b)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out: right before residual connection: zero initialize -- to make sure the residual connection is the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, x, bias):
        B, L = x.shape[:2]
        #
        x = self.norm_in(x)
        bias = self.norm_bias(bias)
        #
        query = self.to_q(x).reshape(B, L, self.h, self.dim)
        key = self.to_k(x).reshape(B, L, self.h, self.dim)
        value = self.to_v(x).reshape(B, L, self.h, self.dim)
        bias = self.to_b(bias) # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(x))
        #
        key = key * self.scaling
        attn = einsum('bqhd,bkhd->bqkh', query, key)
        attn = attn + bias
        attn = F.softmax(attn, dim=-2)
        #
        out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
        out = gate * out
        #
        out = self.to_out(out)
        return out

# MSA Attention (row/column) from AlphaFold architecture
class SequenceWeight(nn.Module):
    def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
        super(SequenceWeight, self).__init__()
        self.h = n_head
        self.dim = d_hidden
        self.scale = 1.0 / math.sqrt(self.dim)

        self.to_query = nn.Linear(d_msa, n_head*d_hidden)
        self.to_key = nn.Linear(d_msa, n_head*d_hidden)
        self.dropout = nn.Dropout(p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projection: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_query.weight)
        nn.init.xavier_uniform_(self.to_key.weight)

    def forward(self, msa):
        B, N, L = msa.shape[:3]

        tar_seq = msa[:,0]

        q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim)
        k = self.to_key(msa).view(B, N, L, self.h, self.dim)

        q = q * self.scale
        attn = einsum('bqihd,bkihd->bkihq', q, k)
        attn = F.softmax(attn, dim=1)
        return self.dropout(attn)

class MSARowAttentionWithBias(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
        super(MSARowAttentionWithBias, self).__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)
        #
        self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_pair, n_head, bias=False)
        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projection: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # bias: normal distribution
        self.to_b = init_lecun_normal(self.to_b)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out: right before residual connection: zero initialize -- to make sure the residual connection is the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa, pair): # TODO: make this as tied-attention
        B, N, L = msa.shape[:3]
        #
        msa = self.norm_msa(msa)
        pair = self.norm_pair(pair)
        #
        seq_weight = self.seq_weight(msa) # (B, N, L, h, 1)
        query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
        key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
        value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
        bias = self.to_b(pair) # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(msa))
        #
        query = query * seq_weight.expand(-1, -1, -1, -1, self.dim)
        key = key * self.scaling
        attn = einsum('bsqhd,bskhd->bqkh', query, key)
        attn = attn + bias
        attn = F.softmax(attn, dim=-2)
        #
        out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1)
        out = gate * out
        #
        out = self.to_out(out)
        return out

class MSAColAttention(nn.Module):
    def __init__(self, d_msa=256, n_head=8, d_hidden=32):
        super(MSAColAttention, self).__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        #
        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projection: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out: right before residual connection: zero initialize -- to make sure the residual connection is the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa):
        B, N, L = msa.shape[:3]
        #
        msa = self.norm_msa(msa)
        #
        query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
        key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
        value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
        gate = torch.sigmoid(self.to_g(msa))
        #
        query = query * self.scaling
        attn = einsum('bqihd,bkihd->bihqk', query, key)
        attn = F.softmax(attn, dim=-1)
        #
        out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1)
        out = gate * out
        #
        out = self.to_out(out)
        return out

class MSAColGlobalAttention(nn.Module):
    def __init__(self, d_msa=64, n_head=8, d_hidden=8):
        super(MSAColGlobalAttention, self).__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        #
        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
        self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_msa)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projection: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out: right before residual connection: zero initialize -- to make sure the residual connection is the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, msa):
        B, N, L = msa.shape[:3]
        #
        msa = self.norm_msa(msa)
        #
        query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
        query = query.mean(dim=1) # (B, L, h, dim)
        key = self.to_k(msa) # (B, N, L, dim)
        value = self.to_v(msa) # (B, N, L, dim)
        gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim)
        #
        query = query * self.scaling
        attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N)
        attn = F.softmax(attn, dim=-1)
        #
        out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim)
        out = gate * out # (B, N, L, h*dim)
        #
        out = self.to_out(out)
        return out

# Instead of triangle attention, use tied axial attention with bias from coordinates..?
class BiasedAxialAttention(nn.Module):
    def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
        super(BiasedAxialAttention, self).__init__()
        #
        self.is_row = is_row
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_bias = nn.LayerNorm(d_bias)

        self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
        self.to_b = nn.Linear(d_bias, n_head, bias=False)
        self.to_g = nn.Linear(d_pair, n_head*d_hidden)
        self.to_out = nn.Linear(n_head*d_hidden, d_pair)

        self.scaling = 1/math.sqrt(d_hidden)
        self.h = n_head
        self.dim = d_hidden

        # initialize all parameters properly
        self.reset_parameter()

    def reset_parameter(self):
        # query/key/value projection: Glorot uniform / Xavier uniform
        nn.init.xavier_uniform_(self.to_q.weight)
        nn.init.xavier_uniform_(self.to_k.weight)
        nn.init.xavier_uniform_(self.to_v.weight)

        # bias: normal distribution
        self.to_b = init_lecun_normal(self.to_b)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_g.weight)
        nn.init.ones_(self.to_g.bias)

        # to_out: right before residual connection: zero initialize -- to make sure the residual connection is the identity at the beginning
        nn.init.zeros_(self.to_out.weight)
        nn.init.zeros_(self.to_out.bias)

    def forward(self, pair, bias):
        # pair: (B, L, L, d_pair)
        B, L = pair.shape[:2]

        if self.is_row:
            pair = pair.permute(0,2,1,3)
            bias = bias.permute(0,2,1,3)

        pair = self.norm_pair(pair)
        bias = self.norm_bias(bias)

        query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
        key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
        value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
        bias = self.to_b(bias) # (B, L, L, h)
        gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)

        query = query * self.scaling
        key = key / math.sqrt(L) # normalize for tied attention
        attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
        attn = attn + bias # apply bias
        attn = F.softmax(attn, dim=-2) # (B, L, L, h)

        out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1)
        out = gate * out

        out = self.to_out(out)
        if self.is_row:
            out = out.permute(0,2,1,3)
        return out
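For reference, a small shape check of the attention modules above. This is a sketch, not part of the upload: it assumes the rfdiffusion package from this commit is importable, and the tensor sizes are taken from the defaults and shape comments in the code (batch and length values are arbitrary).

# Sketch: shape check for Attention and AttentionWithBias (dimensions assumed from the code above)
import torch
from rfdiffusion.Attention_module import Attention, AttentionWithBias

B, Q, K, L = 1, 10, 20, 16
attn = Attention(d_query=64, d_key=64, n_head=4, d_hidden=16, d_out=64)
out = attn(torch.randn(B, Q, 64), torch.randn(B, K, 64), torch.randn(B, K, 64))
print(out.shape)   # torch.Size([1, 10, 64])

biased = AttentionWithBias(d_in=256, d_bias=128, n_head=8, d_hidden=32)
out = biased(torch.randn(B, L, 256), torch.randn(B, L, L, 128))
print(out.shape)   # torch.Size([1, 16, 256])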
rfdiffusion/AuxiliaryPredictor.py
ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn as nn

class DistanceNetwork(nn.Module):
    def __init__(self, n_feat, p_drop=0.1):
        super(DistanceNetwork, self).__init__()
        #
        self.proj_symm = nn.Linear(n_feat, 37*2)
        self.proj_asymm = nn.Linear(n_feat, 37+19)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize linear layer for final logit prediction
        nn.init.zeros_(self.proj_symm.weight)
        nn.init.zeros_(self.proj_asymm.weight)
        nn.init.zeros_(self.proj_symm.bias)
        nn.init.zeros_(self.proj_asymm.bias)

    def forward(self, x):
        # input: pair info (B, L, L, C)

        # predict theta, phi (non-symmetric)
        logits_asymm = self.proj_asymm(x)
        logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2)
        logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2)

        # predict dist, omega
        logits_symm = self.proj_symm(x)
        logits_symm = logits_symm + logits_symm.permute(0,2,1,3)
        logits_dist = logits_symm[:,:,:,:37].permute(0,3,1,2)
        logits_omega = logits_symm[:,:,:,37:].permute(0,3,1,2)

        return logits_dist, logits_omega, logits_theta, logits_phi

class MaskedTokenNetwork(nn.Module):
    def __init__(self, n_feat):
        super(MaskedTokenNetwork, self).__init__()
        self.proj = nn.Linear(n_feat, 21)

        self.reset_parameter()

    def reset_parameter(self):
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        B, N, L = x.shape[:3]
        logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L)

        return logits

class LDDTNetwork(nn.Module):
    def __init__(self, n_feat, n_bin_lddt=50):
        super(LDDTNetwork, self).__init__()
        self.proj = nn.Linear(n_feat, n_bin_lddt)

        self.reset_parameter()

    def reset_parameter(self):
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        logits = self.proj(x) # (B, L, 50)

        return logits.permute(0,2,1)

class ExpResolvedNetwork(nn.Module):
    def __init__(self, d_msa, d_state, p_drop=0.1):
        super(ExpResolvedNetwork, self).__init__()
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_state = nn.LayerNorm(d_state)
        self.proj = nn.Linear(d_msa+d_state, 1)

        self.reset_parameter()

    def reset_parameter(self):
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, seq, state):
        B, L = seq.shape[:2]

        seq = self.norm_msa(seq)
        state = self.norm_state(state)
        feat = torch.cat((seq, state), dim=-1)
        logits = self.proj(feat)
        return logits.reshape(B, L)
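A sketch of how the auxiliary heads above are shaped, again assuming the package from this commit is importable; the bin counts (37/19 for the 6D geometry, 50 for lDDT) come from the layer definitions above, the batch and length values are arbitrary.

# Sketch: output shapes of DistanceNetwork and LDDTNetwork (dimensions assumed from the code above)
import torch
from rfdiffusion.AuxiliaryPredictor import DistanceNetwork, LDDTNetwork

B, L, C = 1, 32, 128
c6d = DistanceNetwork(n_feat=C)
dist, omega, theta, phi = c6d(torch.randn(B, L, L, C))
print(dist.shape, omega.shape, theta.shape, phi.shape)
# (1, 37, 32, 32) (1, 37, 32, 32) (1, 37, 32, 32) (1, 19, 32, 32)

lddt = LDDTNetwork(n_feat=64)
print(lddt(torch.randn(B, L, 64)).shape)   # (1, 50, 32)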
rfdiffusion/Embeddings.py
ADDED
@@ -0,0 +1,303 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from rfdiffusion.util import get_tips
from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal
from rfdiffusion.Attention_module import Attention, FeedForwardLayer, AttentionWithBias
from rfdiffusion.Track_module import PairStr2Pair
import math

# Module contains classes and functions to generate initial embeddings

class PositionalEncoding2D(nn.Module):
    # Add relative positional encoding to pair features
    def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
        super(PositionalEncoding2D, self).__init__()
        self.minpos = minpos
        self.maxpos = maxpos
        self.nbin = abs(minpos)+maxpos+1
        self.emb = nn.Embedding(self.nbin, d_model)
        self.drop = nn.Dropout(p_drop)

    def forward(self, x, idx):
        bins = torch.arange(self.minpos, self.maxpos, device=x.device)
        seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
        #
        ib = torch.bucketize(seqsep, bins).long() # (B, L, L)
        emb = self.emb(ib) # (B, L, L, d_model)
        x = x + emb # add relative positional encoding
        return self.drop(x)

class MSA_emb(nn.Module):
    # Get initial seed MSA embedding
    def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=22+22+2+2,
                 minpos=-32, maxpos=32, p_drop=0.1, input_seq_onehot=False):
        super(MSA_emb, self).__init__()
        self.emb = nn.Linear(d_init, d_msa)      # embedding for general MSA
        self.emb_q = nn.Embedding(22, d_msa)     # embedding for query sequence -- used for MSA embedding
        self.emb_left = nn.Embedding(22, d_pair) # embedding for query sequence -- used for pair embedding
        self.emb_right = nn.Embedding(22, d_pair) # embedding for query sequence -- used for pair embedding
        self.emb_state = nn.Embedding(22, d_state)
        self.drop = nn.Dropout(p_drop)
        self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos, p_drop=p_drop)

        self.input_seq_onehot = input_seq_onehot

        self.reset_parameter()

    def reset_parameter(self):
        self.emb = init_lecun_normal(self.emb)
        self.emb_q = init_lecun_normal(self.emb_q)
        self.emb_left = init_lecun_normal(self.emb_left)
        self.emb_right = init_lecun_normal(self.emb_right)
        self.emb_state = init_lecun_normal(self.emb_state)

        nn.init.zeros_(self.emb.bias)

    def forward(self, msa, seq, idx):
        # Inputs:
        #   - msa: Input MSA (B, N, L, d_init)
        #   - seq: Input Sequence (B, L)
        #   - idx: Residue index
        # Outputs:
        #   - msa: Initial MSA embedding (B, N, L, d_msa)
        #   - pair: Initial Pair embedding (B, L, L, d_pair)

        N = msa.shape[1] # number of sequences in MSA

        # msa embedding
        msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding

        # Sergey's one hot trick
        tmp = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding

        msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
        msa = self.drop(msa)

        # pair embedding
        # Sergey's one hot trick
        left = (seq @ self.emb_left.weight)[:,None]   # (B, 1, L, d_pair)
        right = (seq @ self.emb_right.weight)[:,:,None] # (B, L, 1, d_pair)

        pair = left + right        # (B, L, L, d_pair)
        pair = self.pos(pair, idx) # add relative position

        # state embedding
        # Sergey's one hot trick
        state = self.drop(seq @ self.emb_state.weight)
        return msa, pair, state

class Extra_emb(nn.Module):
    # Get initial seed MSA embedding
    def __init__(self, d_msa=256, d_init=22+1+2, p_drop=0.1, input_seq_onehot=False):
        super(Extra_emb, self).__init__()
        self.emb = nn.Linear(d_init, d_msa)  # embedding for general MSA
        self.emb_q = nn.Embedding(22, d_msa) # embedding for query sequence
        self.drop = nn.Dropout(p_drop)

        self.input_seq_onehot = input_seq_onehot

        self.reset_parameter()

    def reset_parameter(self):
        self.emb = init_lecun_normal(self.emb)
        nn.init.zeros_(self.emb.bias)

    def forward(self, msa, seq, idx):
        # Inputs:
        #   - msa: Input MSA (B, N, L, d_init)
        #   - seq: Input Sequence (B, L)
        #   - idx: Residue index
        # Outputs:
        #   - msa: Initial MSA embedding (B, N, L, d_msa)
        N = msa.shape[1] # number of sequences in MSA
        msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding

        # Sergey's one hot trick
        seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
        msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
        return self.drop(msa)

class TemplatePairStack(nn.Module):
    # process template pairwise features
    # use structure-biased attention
    def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=16, p_drop=0.25):
        super(TemplatePairStack, self).__init__()
        self.n_block = n_block
        proc_s = [PairStr2Pair(d_pair=d_templ, n_head=n_head, d_hidden=d_hidden, p_drop=p_drop) for i in range(n_block)]
        self.block = nn.ModuleList(proc_s)
        self.norm = nn.LayerNorm(d_templ)

    def forward(self, templ, rbf_feat, use_checkpoint=False):
        B, T, L = templ.shape[:3]
        templ = templ.reshape(B*T, L, L, -1)

        for i_block in range(self.n_block):
            if use_checkpoint:
                templ = checkpoint.checkpoint(create_custom_forward(self.block[i_block]), templ, rbf_feat)
            else:
                templ = self.block[i_block](templ, rbf_feat)
        return self.norm(templ).reshape(B, T, L, L, -1)

class TemplateTorsionStack(nn.Module):
    def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=16, p_drop=0.15):
        super(TemplateTorsionStack, self).__init__()
        self.n_block = n_block
        self.proj_pair = nn.Linear(d_templ+36, d_templ)
        proc_s = [AttentionWithBias(d_in=d_templ, d_bias=d_templ,
                                    n_head=n_head, d_hidden=d_hidden) for i in range(n_block)]
        self.row_attn = nn.ModuleList(proc_s)
        proc_s = [FeedForwardLayer(d_templ, 4, p_drop=p_drop) for i in range(n_block)]
        self.ff = nn.ModuleList(proc_s)
        self.norm = nn.LayerNorm(d_templ)

    def reset_parameter(self):
        self.proj_pair = init_lecun_normal(self.proj_pair)
        nn.init.zeros_(self.proj_pair.bias)

    def forward(self, tors, pair, rbf_feat, use_checkpoint=False):
        B, T, L = tors.shape[:3]
        tors = tors.reshape(B*T, L, -1)
        pair = pair.reshape(B*T, L, L, -1)
        pair = torch.cat((pair, rbf_feat), dim=-1)
        pair = self.proj_pair(pair)

        for i_block in range(self.n_block):
            if use_checkpoint:
                tors = tors + checkpoint.checkpoint(create_custom_forward(self.row_attn[i_block]), tors, pair)
            else:
                tors = tors + self.row_attn[i_block](tors, pair)
            tors = tors + self.ff[i_block](tors)
        return self.norm(tors).reshape(B, T, L, -1)

class Templ_emb(nn.Module):
    # Get template embedding
    # Features are
    # t2d:
    #   - 37 distogram bins + 6 orientations (43)
    #   - Mask (missing/unaligned) (1)
    # t1d:
    #   - tiled AA sequence (20 standard aa + gap)
    #   - confidence (1)
    #   - contacting or not (1). NB this is added for the diffusion model. Used only in complex training
    #     examples - 1 signifies that a residue in the non-diffused chain, i.e. the context, is in contact
    #     with the diffused chain.
    #
    # Added extra t1d dimension for contacting or not
    def __init__(self, d_t1d=21+1+1, d_t2d=43+1, d_tor=30, d_pair=128, d_state=32,
                 n_block=2, d_templ=64,
                 n_head=4, d_hidden=16, p_drop=0.25):
        super(Templ_emb, self).__init__()
        # process 2D features
        self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ)
        self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
                                             d_hidden=d_hidden, p_drop=p_drop)

        self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair)

        # process torsion angles
        self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
        self.proj_t1d = nn.Linear(d_templ, d_templ)
        #self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
        #                                      d_hidden=d_hidden, p_drop=p_drop)
        self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state)

        self.reset_parameter()

    def reset_parameter(self):
        self.emb = init_lecun_normal(self.emb)
        nn.init.zeros_(self.emb.bias)

        nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu')
        nn.init.zeros_(self.emb_t1d.bias)

        self.proj_t1d = init_lecun_normal(self.proj_t1d)
        nn.init.zeros_(self.proj_t1d.bias)

    def forward(self, t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=False):
        # Input
        #   - t1d: 1D template info (B, T, L, 23)
        #   - t2d: 2D template info (B, T, L, L, 44)
        B, T, L, _ = t1d.shape

        # Prepare 2D template features
        left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1)
        right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1)
        #
        templ = torch.cat((t2d, left, right), -1) # (B, T, L, L, 90)
        templ = self.emb(templ) # Template features (B, T, L, L, d_templ)
        # process each template's features
        xyz_t = xyz_t.reshape(B*T, L, -1, 3)
        rbf_feat = rbf(torch.cdist(xyz_t[:,:,1], xyz_t[:,:,1]))
        templ = self.templ_stack(templ, rbf_feat, use_checkpoint=use_checkpoint) # (B, T, L, L, d_templ)

        # Prepare 1D template torsion angle features
        t1d = torch.cat((t1d, alpha_t), dim=-1) # (B, T, L, 23+30)

        # process each template's features
        t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d)))

        # mixing query state features to template state features
        state = state.reshape(B*L, 1, -1)
        t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1)
        if use_checkpoint:
            out = checkpoint.checkpoint(create_custom_forward(self.attn_tor), state, t1d, t1d)
            out = out.reshape(B, L, -1)
        else:
            out = self.attn_tor(state, t1d, t1d).reshape(B, L, -1)
        state = state.reshape(B, L, -1)
        state = state + out

        # mixing query pair features to template information (Template pointwise attention)
        pair = pair.reshape(B*L*L, 1, -1)
        templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1)
        if use_checkpoint:
            out = checkpoint.checkpoint(create_custom_forward(self.attn), pair, templ, templ)
            out = out.reshape(B, L, L, -1)
        else:
            out = self.attn(pair, templ, templ).reshape(B, L, L, -1)
        #
        pair = pair.reshape(B, L, L, -1)
        pair = pair + out

        return pair, state

class Recycling(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, d_state=32):
        super(Recycling, self).__init__()
        self.proj_dist = nn.Linear(36+d_state*2, d_pair)
        self.norm_state = nn.LayerNorm(d_state)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_msa = nn.LayerNorm(d_msa)

        self.reset_parameter()

    def reset_parameter(self):
        self.proj_dist = init_lecun_normal(self.proj_dist)
        nn.init.zeros_(self.proj_dist.bias)

    def forward(self, seq, msa, pair, xyz, state):
        B, L = pair.shape[:2]
        state = self.norm_state(state)
        #
        left = state.unsqueeze(2).expand(-1,-1,L,-1)
        right = state.unsqueeze(1).expand(-1,L,-1,-1)

        # three anchor atoms
        N = xyz[:,:,0]
        Ca = xyz[:,:,1]
        C = xyz[:,:,2]

        # recreate Cb given N, Ca, C
        b = Ca - N
        c = C - Ca
        a = torch.cross(b, c, dim=-1)
        Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca

        dist = rbf(torch.cdist(Cb, Cb))
        dist = torch.cat((dist, left, right), dim=-1)
        dist = self.proj_dist(dist)
        pair = dist + self.norm_pair(pair)
        msa = self.norm_msa(msa)
        return msa, pair, state
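A sketch of driving MSA_emb with dummy inputs. This assumes the rfdiffusion package from this commit (and its se3-transformer dependency, pulled in through Track_module) is importable, and that seq is passed as a one-hot (B, L, 22) tensor, consistent with the `seq @ self.emb_q.weight` one-hot trick in the code above.

# Sketch: MSA_emb shape check with dummy tensors (d_init = 22+22+2+2 = 48 from the defaults above)
import torch
import torch.nn.functional as F
from rfdiffusion.Embeddings import MSA_emb

B, N, L = 1, 4, 24
emb = MSA_emb()                                               # d_msa=256, d_pair=128, d_state=32
msa = torch.randn(B, N, L, 48)                                # dummy MSA features
seq = F.one_hot(torch.randint(0, 22, (B, L)), 22).float()     # one-hot query sequence
idx = torch.arange(L)[None].expand(B, -1)                     # residue index
msa_emb, pair, state = emb(msa, seq, idx)
print(msa_emb.shape, pair.shape, state.shape)
# torch.Size([1, 4, 24, 256]) torch.Size([1, 24, 24, 128]) torch.Size([1, 24, 32])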
rfdiffusion/RoseTTAFoldModel.py
ADDED
@@ -0,0 +1,140 @@
import torch
import torch.nn as nn
from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling
from rfdiffusion.Track_module import IterativeSimulator
from rfdiffusion.AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork
from opt_einsum import contract as einsum

class RoseTTAFoldModule(nn.Module):
    def __init__(self,
                 n_extra_block,
                 n_main_block,
                 n_ref_block,
                 d_msa,
                 d_msa_full,
                 d_pair,
                 d_templ,
                 n_head_msa,
                 n_head_pair,
                 n_head_templ,
                 d_hidden,
                 d_hidden_templ,
                 p_drop,
                 d_t1d,
                 d_t2d,
                 T,                   # total timesteps (used in timestep emb)
                 use_motif_timestep,  # Whether to have a distinct emb for motif
                 freeze_track_motif,  # Whether to freeze updates to motif in track
                 SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 input_seq_onehot=False,  # For continuous vs. discrete sequence
                 ):

        super(RoseTTAFoldModule, self).__init__()

        self.freeze_track_motif = freeze_track_motif

        # Input Embeddings
        d_state = SE3_param_topk['l0_out_features']
        self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state,
                                  p_drop=p_drop, input_seq_onehot=input_seq_onehot) # Allowed to take onehot seq
        self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=25,
                                  p_drop=p_drop, input_seq_onehot=input_seq_onehot) # Allowed to take onehot seq
        self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state,
                                   n_head=n_head_templ,
                                   d_hidden=d_hidden_templ, p_drop=0.25, d_t1d=d_t1d, d_t2d=d_t2d)

        # Update inputs with outputs from previous round
        self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)
        #
        self.simulator = IterativeSimulator(n_extra_block=n_extra_block,
                                            n_main_block=n_main_block,
                                            n_ref_block=n_ref_block,
                                            d_msa=d_msa, d_msa_full=d_msa_full,
                                            d_pair=d_pair, d_hidden=d_hidden,
                                            n_head_msa=n_head_msa,
                                            n_head_pair=n_head_pair,
                                            SE3_param_full=SE3_param_full,
                                            SE3_param_topk=SE3_param_topk,
                                            p_drop=p_drop)
        ##
        self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
        self.aa_pred = MaskedTokenNetwork(d_msa)
        self.lddt_pred = LDDTNetwork(d_state)

        self.exp_pred = ExpResolvedNetwork(d_msa, d_state)

    def forward(self, msa_latent, msa_full, seq, xyz, idx, t,
                t1d=None, t2d=None, xyz_t=None, alpha_t=None,
                msa_prev=None, pair_prev=None, state_prev=None,
                return_raw=False, return_full=False, return_infer=False,
                use_checkpoint=False, motif_mask=None, i_cycle=None, n_cycle=None):

        B, N, L = msa_latent.shape[:3]
        # Get embeddings
        msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx)
        msa_full = self.full_emb(msa_full, seq, idx)

        # Do recycling
        if msa_prev is None:
            msa_prev = torch.zeros_like(msa_latent[:,0])
            pair_prev = torch.zeros_like(pair)
            state_prev = torch.zeros_like(state)
        msa_recycle, pair_recycle, state_recycle = self.recycle(seq, msa_prev, pair_prev, xyz, state_prev)
        msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
        pair = pair + pair_recycle
        state = state + state_recycle

        # Get timestep embedding (if using)
        if hasattr(self, 'timestep_embedder'):
            assert t is not None
            time_emb = self.timestep_embedder(L, t, motif_mask)
            n_tmpl = t1d.shape[1]
            t1d = torch.cat([t1d, time_emb[None,None,...].repeat(1,n_tmpl,1,1)], dim=-1)

        # add template embedding
        pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=use_checkpoint)

        # Predict coordinates from given inputs
        is_frozen_residue = motif_mask if self.freeze_track_motif else torch.zeros_like(motif_mask).bool()
        msa, pair, R, T, alpha_s, state = self.simulator(seq, msa_latent, msa_full, pair, xyz[:,:,:3],
                                                         state, idx, use_checkpoint=use_checkpoint,
                                                         motif_mask=is_frozen_residue)

        if return_raw:
            # get last structure
            xyz = einsum('bnij,bnaj->bnai', R[-1], xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T[-1].unsqueeze(-2)
            return msa[:,0], pair, xyz, state, alpha_s[-1]

        # predict masked amino acids
        logits_aa = self.aa_pred(msa)

        # Predict LDDT
        lddt = self.lddt_pred(state)

        if return_infer:
            # get last structure
            xyz = einsum('bnij,bnaj->bnai', R[-1], xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T[-1].unsqueeze(-2)

            # get scalar plddt
            nbin = lddt.shape[1]
            bin_step = 1.0 / nbin
            lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=lddt.dtype, device=lddt.device)
            pred_lddt = nn.Softmax(dim=1)(lddt)
            pred_lddt = torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1)

            return msa[:,0], pair, xyz, state, alpha_s[-1], logits_aa.permute(0,2,1), pred_lddt

        #
        # predict distogram & orientograms
        logits = self.c6d_pred(pair)

        # predict experimentally resolved or not
        logits_exp = self.exp_pred(msa[:,0], state)

        # get all intermediate bb structures
        xyz = einsum('rbnij,bnaj->rbnai', R, xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T.unsqueeze(-2)

        return logits, logits_aa, logits_exp, xyz, alpha_s, lddt
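The return_infer branch above reduces the (B, 50, L) lDDT logits to a per-residue scalar pLDDT. A standalone restatement of that reduction with a dummy tensor, with the bin layout taken from the code above, is:

# Sketch: per-residue pLDDT from (B, n_bins, L) logits, mirroring the return_infer branch above
import torch

lddt_logits = torch.randn(1, 50, 128)                   # dummy (B, 50, L) logits
nbin = lddt_logits.shape[1]
bins = torch.linspace(1.0 / nbin, 1.0, nbin)            # bin centers on (0, 1]
probs = torch.softmax(lddt_logits, dim=1)
pred_lddt = (bins[None, :, None] * probs).sum(dim=1)    # (B, L) expected lDDT per residue
print(pred_lddt.shape)                                   # torch.Size([1, 128])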
rfdiffusion/SE3_network.py
ADDED
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn

#from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias
#from equivariant_attention.modules import GConvSE3, GNormSE3
#from equivariant_attention.fibers import Fiber

from rfdiffusion.util_module import init_lecun_normal_param
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber

class SE3TransformerWrapper(nn.Module):
    """SE(3) equivariant GCN with attention"""
    def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
                 l0_in_features=32, l0_out_features=32,
                 l1_in_features=3, l1_out_features=2,
                 num_edge_features=32):
        super().__init__()
        # Build the network
        self.l1_in = l1_in_features
        #
        fiber_edge = Fiber({0: num_edge_features})
        if l1_out_features > 0:
            if l1_in_features > 0:
                fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
                fiber_hidden = Fiber.create(num_degrees, num_channels)
                fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
            else:
                fiber_in = Fiber({0: l0_in_features})
                fiber_hidden = Fiber.create(num_degrees, num_channels)
                fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
        else:
            if l1_in_features > 0:
                fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
                fiber_hidden = Fiber.create(num_degrees, num_channels)
                fiber_out = Fiber({0: l0_out_features})
            else:
                fiber_in = Fiber({0: l0_in_features})
                fiber_hidden = Fiber.create(num_degrees, num_channels)
                fiber_out = Fiber({0: l0_out_features})

        self.se3 = SE3Transformer(num_layers=num_layers,
                                  fiber_in=fiber_in,
                                  fiber_hidden=fiber_hidden,
                                  fiber_out=fiber_out,
                                  num_heads=n_heads,
                                  channels_div=div,
                                  fiber_edge=fiber_edge,
                                  use_layer_norm=True)
                                  #use_layer_norm=False)

        self.reset_parameter()

    def reset_parameter(self):
        # make sure linear layers before ReLU are initialized with kaiming_normal_
        for n, p in self.se3.named_parameters():
            if "bias" in n:
                nn.init.zeros_(p)
            elif len(p.shape) == 1:
                continue
            else:
                if "radial_func" not in n:
                    p = init_lecun_normal_param(p)
                else:
                    if "net.6" in n:
                        nn.init.zeros_(p)
                    else:
                        nn.init.kaiming_normal_(p, nonlinearity='relu')

        # make last layers to be zero-initialized
        #self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
        #self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
        nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
        nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])

    def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
        if self.l1_in > 0:
            node_features = {'0': type_0_features, '1': type_1_features}
        else:
            node_features = {'0': type_0_features}
        edge_features = {'0': edge_features}
        return self.se3(G, node_features, edge_features)
rfdiffusion/Track_module.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.utils.checkpoint as checkpoint
|
2 |
+
from rfdiffusion.util_module import *
|
3 |
+
from rfdiffusion.Attention_module import *
|
4 |
+
from rfdiffusion.SE3_network import SE3TransformerWrapper
|
5 |
+
|
6 |
+
# Components for three-track blocks
|
7 |
+
# 1. MSA -> MSA update (biased attention. bias from pair & structure)
|
8 |
+
# 2. Pair -> Pair update (biased attention. bias from structure)
|
9 |
+
# 3. MSA -> Pair update (extract coevolution signal)
|
10 |
+
# 4. Str -> Str update (node from MSA, edge from Pair)
|
11 |
+
|
# Update MSA with biased self-attention. bias from Pair & Str
class MSAPairStr2MSA(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16,
                 d_hidden=32, p_drop=0.15, use_global_attn=False):
        super(MSAPairStr2MSA, self).__init__()
        self.norm_pair = nn.LayerNorm(d_pair)
        self.proj_pair = nn.Linear(d_pair+36, d_pair)
        self.norm_state = nn.LayerNorm(d_state)
        self.proj_state = nn.Linear(d_state, d_msa)
        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
                                                n_head=n_head, d_hidden=d_hidden)
        if use_global_attn:
            self.col_attn = MSAColGlobalAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
        else:
            self.col_attn = MSAColAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
        self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop)

        # Do proper initialization
        self.reset_parameter()

    def reset_parameter(self):
        # initialize weights to normal distribution
        self.proj_pair = init_lecun_normal(self.proj_pair)
        self.proj_state = init_lecun_normal(self.proj_state)

        # initialize bias to zeros
        nn.init.zeros_(self.proj_pair.bias)
        nn.init.zeros_(self.proj_state.bias)

    def forward(self, msa, pair, rbf_feat, state):
        '''
        Inputs:
            - msa: MSA feature (B, N, L, d_msa)
            - pair: Pair feature (B, L, L, d_pair)
            - rbf_feat: Ca-Ca distance feature calculated from xyz coordinates (B, L, L, 36)
            - state: updated node features after SE(3)-Transformer layer (B, L, d_state)
        Output:
            - msa: Updated MSA feature (B, N, L, d_msa)
        '''
        B, N, L = msa.shape[:3]

        # prepare input bias feature by combining pair & coordinate info
        pair = self.norm_pair(pair)
        pair = torch.cat((pair, rbf_feat), dim=-1)
        pair = self.proj_pair(pair)  # (B, L, L, d_pair)

        # update query sequence feature (first sequence in the MSA) with feedback (state) from SE3
        state = self.norm_state(state)
        state = self.proj_state(state).reshape(B, 1, L, -1)
        msa = msa.index_add(1, torch.tensor([0,], device=state.device), state)

        # Apply row/column attention to msa & transform
        msa = msa + self.drop_row(self.row_attn(msa, pair))
        msa = msa + self.col_attn(msa)
        msa = msa + self.ff(msa)

        return msa

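Note on the state feedback in MSAPairStr2MSA.forward: index_add along dimension 1 with index 0 adds the projected SE(3) state only to the first MSA row (the query sequence) and leaves the other N-1 rows untouched. A small standalone sketch of that update with toy tensors (plain PyTorch, not part of Track_module.py):

import torch

B, N, L, d_msa = 1, 4, 5, 8            # toy sizes
msa = torch.randn(B, N, L, d_msa)       # MSA embeddings
state = torch.randn(B, 1, L, d_msa)     # projected SE(3) state, reshaped to a single "row"

# add the state only to row 0 (the query sequence)
updated = msa.index_add(1, torch.tensor([0]), state)

assert torch.allclose(updated[:, 0], msa[:, 0] + state[:, 0])   # query row updated
assert torch.allclose(updated[:, 1:], msa[:, 1:])               # other rows unchanged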
class PairStr2Pair(nn.Module):
    def __init__(self, d_pair=128, n_head=4, d_hidden=32, d_rbf=36, p_drop=0.15):
        super(PairStr2Pair, self).__init__()

        self.emb_rbf = nn.Linear(d_rbf, d_hidden)
        self.proj_rbf = nn.Linear(d_hidden, d_pair)

        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)

        self.row_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=True)
        self.col_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=False)

        self.ff = FeedForwardLayer(d_pair, 2)

        self.reset_parameter()

    def reset_parameter(self):
        nn.init.kaiming_normal_(self.emb_rbf.weight, nonlinearity='relu')
        nn.init.zeros_(self.emb_rbf.bias)

        self.proj_rbf = init_lecun_normal(self.proj_rbf)
        nn.init.zeros_(self.proj_rbf.bias)

    def forward(self, pair, rbf_feat):
        B, L = pair.shape[:2]

        rbf_feat = self.proj_rbf(F.relu_(self.emb_rbf(rbf_feat)))

        pair = pair + self.drop_row(self.row_attn(pair, rbf_feat))
        pair = pair + self.drop_col(self.col_attn(pair, rbf_feat))
        pair = pair + self.ff(pair)
        return pair

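rbf_feat is a radial-basis encoding of the Ca-Ca distance matrix with 36 channels, which is why proj_pair above takes d_pair+36 inputs and PairStr2Pair uses d_rbf=36. The encoding itself is provided by rfdiffusion.util_module.rbf; the following is only a minimal sketch of the idea, with an assumed bin range of 2 to 22 Angstrom (the library's own constants may differ):

import torch

def rbf_encode(dist, d_min=2.0, d_max=22.0, n_bins=36, sigma=None):
    # Gaussian RBF encoding of a distance matrix (assumed bin placement;
    # the package's rbf() helper may use different constants).
    centers = torch.linspace(d_min, d_max, n_bins, device=dist.device)
    sigma = sigma if sigma is not None else (d_max - d_min) / n_bins
    return torch.exp(-((dist[..., None] - centers) / sigma) ** 2)

xyz = torch.randn(1, 10, 3)          # toy Ca coordinates (B, L, 3)
dist = torch.cdist(xyz, xyz)         # (B, L, L) Ca-Ca distances
rbf_feat = rbf_encode(dist)          # (B, L, L, 36) pair bias feature
print(rbf_feat.shape)                # torch.Size([1, 10, 10, 36])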
class MSA2Pair(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, d_hidden=32, p_drop=0.15):
        super(MSA2Pair, self).__init__()
        self.norm = nn.LayerNorm(d_msa)
        self.proj_left = nn.Linear(d_msa, d_hidden)
        self.proj_right = nn.Linear(d_msa, d_hidden)
        self.proj_out = nn.Linear(d_hidden*d_hidden, d_pair)

        self.reset_parameter()

    def reset_parameter(self):
        # normal initialization
        self.proj_left = init_lecun_normal(self.proj_left)
        self.proj_right = init_lecun_normal(self.proj_right)
        nn.init.zeros_(self.proj_left.bias)
        nn.init.zeros_(self.proj_right.bias)

        # zero initialize output
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)

    def forward(self, msa, pair):
        B, N, L = msa.shape[:3]
        msa = self.norm(msa)
        left = self.proj_left(msa)
        right = self.proj_right(msa)
        right = right / float(N)
        out = einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1)
        out = self.proj_out(out)

        pair = pair + out

        return pair

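MSA2Pair is an outer-product update: for each residue pair (l, m) it sums, over the N sequences, the outer product of low-rank left/right projections (the 1/N average is folded into right) and maps the flattened d_hidden x d_hidden block back to d_pair through the zero-initialized proj_out. A standalone check of the einsum against an explicit loop, with toy sizes:

import torch
from torch import einsum

B, N, L, h = 1, 3, 4, 2
left = torch.randn(B, N, L, h)
right = torch.randn(B, N, L, h) / N            # same 1/N scaling as in MSA2Pair.forward

out = einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1)

# explicit reference: sum over sequences of per-residue-pair outer products
ref = torch.zeros(B, L, L, h, h)
for s in range(N):
    for l in range(L):
        for m in range(L):
            ref[0, l, m] += torch.outer(left[0, s, l], right[0, s, m])
assert torch.allclose(out, ref.reshape(B, L, L, -1), atol=1e-5)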
class SCPred(nn.Module):
    def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
        super(SCPred, self).__init__()
        self.norm_s0 = nn.LayerNorm(d_msa)
        self.norm_si = nn.LayerNorm(d_state)
        self.linear_s0 = nn.Linear(d_msa, d_hidden)
        self.linear_si = nn.Linear(d_state, d_hidden)

        # ResNet layers
        self.linear_1 = nn.Linear(d_hidden, d_hidden)
        self.linear_2 = nn.Linear(d_hidden, d_hidden)
        self.linear_3 = nn.Linear(d_hidden, d_hidden)
        self.linear_4 = nn.Linear(d_hidden, d_hidden)

        # Final outputs
        self.linear_out = nn.Linear(d_hidden, 20)

        self.reset_parameter()

    def reset_parameter(self):
        # normal initialization
        self.linear_s0 = init_lecun_normal(self.linear_s0)
        self.linear_si = init_lecun_normal(self.linear_si)
        self.linear_out = init_lecun_normal(self.linear_out)
        nn.init.zeros_(self.linear_s0.bias)
        nn.init.zeros_(self.linear_si.bias)
        nn.init.zeros_(self.linear_out.bias)

        # right before relu activation: He initializer (kaiming normal)
        nn.init.kaiming_normal_(self.linear_1.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear_1.bias)
        nn.init.kaiming_normal_(self.linear_3.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear_3.bias)

        # right before residual connection: zero initialize
        nn.init.zeros_(self.linear_2.weight)
        nn.init.zeros_(self.linear_2.bias)
        nn.init.zeros_(self.linear_4.weight)
        nn.init.zeros_(self.linear_4.bias)

    def forward(self, seq, state):
        '''
        Predict side-chain torsion angles along with backbone torsions
        Inputs:
            - seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
            - state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
        Outputs:
            - si: predicted torsion angles (phi, psi, omega, chi1~4 with cos/sin, Cb bend, Cb twist, CG) (B, L, 10, 2)
        '''
        B, L = seq.shape[:2]
        seq = self.norm_s0(seq)
        state = self.norm_si(state)
        si = self.linear_s0(seq) + self.linear_si(state)

        si = si + self.linear_2(F.relu_(self.linear_1(F.relu_(si))))
        si = si + self.linear_4(F.relu_(self.linear_3(F.relu_(si))))

        si = self.linear_out(F.relu_(si))
        return si.view(B, L, 10, 2)

+
class Str2Str(nn.Module):
|
202 |
+
def __init__(self, d_msa=256, d_pair=128, d_state=16,
|
203 |
+
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.1):
|
204 |
+
super(Str2Str, self).__init__()
|
205 |
+
|
206 |
+
# initial node & pair feature process
|
207 |
+
self.norm_msa = nn.LayerNorm(d_msa)
|
208 |
+
self.norm_pair = nn.LayerNorm(d_pair)
|
209 |
+
self.norm_state = nn.LayerNorm(d_state)
|
210 |
+
|
211 |
+
self.embed_x = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
|
212 |
+
self.embed_e1 = nn.Linear(d_pair, SE3_param['num_edge_features'])
|
213 |
+
self.embed_e2 = nn.Linear(SE3_param['num_edge_features']+36+1, SE3_param['num_edge_features'])
|
214 |
+
|
215 |
+
self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])
|
216 |
+
self.norm_edge1 = nn.LayerNorm(SE3_param['num_edge_features'])
|
217 |
+
self.norm_edge2 = nn.LayerNorm(SE3_param['num_edge_features'])
|
218 |
+
|
219 |
+
self.se3 = SE3TransformerWrapper(**SE3_param)
|
220 |
+
self.sc_predictor = SCPred(d_msa=d_msa, d_state=SE3_param['l0_out_features'],
|
221 |
+
p_drop=p_drop)
|
222 |
+
|
223 |
+
self.reset_parameter()
|
224 |
+
|
225 |
+
def reset_parameter(self):
|
226 |
+
# initialize weights to normal distribution
|
227 |
+
self.embed_x = init_lecun_normal(self.embed_x)
|
228 |
+
self.embed_e1 = init_lecun_normal(self.embed_e1)
|
229 |
+
self.embed_e2 = init_lecun_normal(self.embed_e2)
|
230 |
+
|
231 |
+
# initialize bias to zeros
|
232 |
+
nn.init.zeros_(self.embed_x.bias)
|
233 |
+
nn.init.zeros_(self.embed_e1.bias)
|
234 |
+
nn.init.zeros_(self.embed_e2.bias)
|
235 |
+
|
236 |
+
@torch.cuda.amp.autocast(enabled=False)
|
237 |
+
def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, top_k=64, eps=1e-5):
|
238 |
+
B, N, L = msa.shape[:3]
|
239 |
+
|
240 |
+
if motif_mask is None:
|
241 |
+
motif_mask = torch.zeros(L).bool()
|
242 |
+
|
243 |
+
# process msa & pair features
|
244 |
+
node = self.norm_msa(msa[:,0])
|
245 |
+
pair = self.norm_pair(pair)
|
246 |
+
state = self.norm_state(state)
|
247 |
+
|
248 |
+
node = torch.cat((node, state), dim=-1)
|
249 |
+
node = self.norm_node(self.embed_x(node))
|
250 |
+
pair = self.norm_edge1(self.embed_e1(pair))
|
251 |
+
|
252 |
+
neighbor = get_seqsep(idx)
|
253 |
+
rbf_feat = rbf(torch.cdist(xyz[:,:,1], xyz[:,:,1]))
|
254 |
+
pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
|
255 |
+
pair = self.norm_edge2(self.embed_e2(pair))
|
256 |
+
|
257 |
+
# define graph
|
258 |
+
if top_k != 0:
|
259 |
+
G, edge_feats = make_topk_graph(xyz[:,:,1,:], pair, idx, top_k=top_k)
|
260 |
+
else:
|
261 |
+
G, edge_feats = make_full_graph(xyz[:,:,1,:], pair, idx, top_k=top_k)
|
262 |
+
l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2)
|
263 |
+
l1_feats = l1_feats.reshape(B*L, -1, 3)
|
264 |
+
|
265 |
+
# apply SE(3) Transformer & update coordinates
|
266 |
+
shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)
|
267 |
+
|
268 |
+
state = shift['0'].reshape(B, L, -1) # (B, L, C)
|
269 |
+
|
270 |
+
offset = shift['1'].reshape(B, L, 2, 3)
|
271 |
+
offset[:,motif_mask,...] = 0 # NOTE: motif mask is all zeros if not freeezing the motif
|
272 |
+
|
273 |
+
delTi = offset[:,:,0,:] / 10.0 # translation
|
274 |
+
R = offset[:,:,1,:] / 100.0 # rotation
|
275 |
+
|
276 |
+
Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) )
|
277 |
+
qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm
|
278 |
+
|
279 |
+
delRi = torch.zeros((B,L,3,3), device=xyz.device)
|
280 |
+
delRi[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
|
281 |
+
delRi[:,:,0,1] = 2*qB*qC - 2*qA*qD
|
282 |
+
delRi[:,:,0,2] = 2*qB*qD + 2*qA*qC
|
283 |
+
delRi[:,:,1,0] = 2*qB*qC + 2*qA*qD
|
284 |
+
delRi[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
|
285 |
+
delRi[:,:,1,2] = 2*qC*qD - 2*qA*qB
|
286 |
+
delRi[:,:,2,0] = 2*qB*qD - 2*qA*qC
|
287 |
+
delRi[:,:,2,1] = 2*qC*qD + 2*qA*qB
|
288 |
+
delRi[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD
|
289 |
+
|
290 |
+
Ri = einsum('bnij,bnjk->bnik', delRi, R_in)
|
291 |
+
Ti = delTi + T_in #einsum('bnij,bnj->bni', delRi, T_in) + delTi
|
292 |
+
|
293 |
+
alpha = self.sc_predictor(msa[:,0], state)
|
294 |
+
return Ri, Ti, state, alpha
|
295 |
+
|
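Str2Str reads the two l1 vectors returned by the SE(3) transformer as a translation (scaled by 1/10) and a rotation given as the vector part of a quaternion whose real part is fixed to 1 (scaled by 1/100); normalizing (1, Rx, Ry, Rz) and expanding it yields the update matrix delRi. A standalone check (not part of the package) that this construction always gives a proper rotation:

import torch

def quat_vec_to_rotmat(R):
    # R: (..., 3) vector part; real part fixed to 1 before normalization,
    # mirroring the delRi construction in Str2Str.forward above.
    Qnorm = torch.sqrt(1 + (R * R).sum(dim=-1, keepdim=True))
    q = torch.cat([torch.ones_like(R[..., :1]), R], dim=-1) / Qnorm
    qA, qB, qC, qD = q.unbind(dim=-1)
    rot = torch.stack([
        qA*qA+qB*qB-qC*qC-qD*qD, 2*qB*qC-2*qA*qD,         2*qB*qD+2*qA*qC,
        2*qB*qC+2*qA*qD,         qA*qA-qB*qB+qC*qC-qD*qD, 2*qC*qD-2*qA*qB,
        2*qB*qD-2*qA*qC,         2*qC*qD+2*qA*qB,         qA*qA-qB*qB-qC*qC+qD*qD,
    ], dim=-1).reshape(*R.shape[:-1], 3, 3)
    return rot

R = torch.randn(4, 3) / 100.0                 # small rotation vectors, as in the model
M = quat_vec_to_rotmat(R)
assert torch.allclose(M @ M.transpose(-1, -2), torch.eye(3).expand(4, 3, 3), atol=1e-5)
assert torch.allclose(torch.linalg.det(M), torch.ones(4), atol=1e-5)   # det = +1, proper rotation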
class IterBlock(nn.Module):
    def __init__(self, d_msa=256, d_pair=128,
                 n_head_msa=8, n_head_pair=4,
                 use_global_attn=False,
                 d_hidden=32, d_hidden_msa=None, p_drop=0.15,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
        super(IterBlock, self).__init__()
        if d_hidden_msa == None:
            d_hidden_msa = d_hidden

        self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair,
                                      n_head=n_head_msa,
                                      d_state=SE3_param['l0_out_features'],
                                      use_global_attn=use_global_attn,
                                      d_hidden=d_hidden_msa, p_drop=p_drop)
        self.msa2pair = MSA2Pair(d_msa=d_msa, d_pair=d_pair,
                                 d_hidden=d_hidden//2, p_drop=p_drop)
                                 #d_hidden=d_hidden, p_drop=p_drop)
        self.pair2pair = PairStr2Pair(d_pair=d_pair, n_head=n_head_pair,
                                      d_hidden=d_hidden, p_drop=p_drop)
        self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair,
                               d_state=SE3_param['l0_out_features'],
                               SE3_param=SE3_param,
                               p_drop=p_drop)

    def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, use_checkpoint=False):
        rbf_feat = rbf(torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]))
        if use_checkpoint:
            msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
            pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
            pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat)
            R, T, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=0), msa, pair, R_in, T_in, xyz, state, idx, motif_mask)
        else:
            msa = self.msa2msa(msa, pair, rbf_feat, state)
            pair = self.msa2pair(msa, pair)
            pair = self.pair2pair(pair, rbf_feat)
            R, T, state, alpha = self.str2str(msa, pair, R_in, T_in, xyz, state, idx, motif_mask=motif_mask, top_k=0)

        return msa, pair, R, T, state, alpha

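create_custom_forward (pulled in through one of the wildcard imports at the top of the file) is used because torch.utils.checkpoint.checkpoint only forwards positional arguments, while str2str also needs top_k=0. A typical implementation of this wrapper pattern is sketched below; the package's own helper may differ in detail:

import torch
import torch.utils.checkpoint as checkpoint

def create_custom_forward(module, **kwargs):
    # Bind keyword arguments now; checkpoint() will pass only positional tensors.
    def custom_forward(*inputs):
        return module(*inputs, **kwargs)
    return custom_forward

# usage sketch with a stand-in module
layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
y = checkpoint.checkpoint(create_custom_forward(layer), x)   # activations recomputed in backward
y.sum().backward()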
class IterativeSimulator(nn.Module):
    def __init__(self, n_extra_block=4, n_main_block=12, n_ref_block=4,
                 d_msa=256, d_msa_full=64, d_pair=128, d_hidden=32,
                 n_head_msa=8, n_head_pair=4,
                 SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 p_drop=0.15):
        super(IterativeSimulator, self).__init__()
        self.n_extra_block = n_extra_block
        self.n_main_block = n_main_block
        self.n_ref_block = n_ref_block

        self.proj_state = nn.Linear(SE3_param_topk['l0_out_features'], SE3_param_full['l0_out_features'])
        # Update with extra sequences
        if n_extra_block > 0:
            self.extra_block = nn.ModuleList([IterBlock(d_msa=d_msa_full, d_pair=d_pair,
                                                        n_head_msa=n_head_msa,
                                                        n_head_pair=n_head_pair,
                                                        d_hidden_msa=8,
                                                        d_hidden=d_hidden,
                                                        p_drop=p_drop,
                                                        use_global_attn=True,
                                                        SE3_param=SE3_param_full)
                                              for i in range(n_extra_block)])

        # Update with seed sequences
        if n_main_block > 0:
            self.main_block = nn.ModuleList([IterBlock(d_msa=d_msa, d_pair=d_pair,
                                                       n_head_msa=n_head_msa,
                                                       n_head_pair=n_head_pair,
                                                       d_hidden=d_hidden,
                                                       p_drop=p_drop,
                                                       use_global_attn=False,
                                                       SE3_param=SE3_param_full)
                                             for i in range(n_main_block)])

        self.proj_state2 = nn.Linear(SE3_param_full['l0_out_features'], SE3_param_topk['l0_out_features'])
        # Final SE(3) refinement
        if n_ref_block > 0:
            self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair,
                                       d_state=SE3_param_topk['l0_out_features'],
                                       SE3_param=SE3_param_topk,
                                       p_drop=p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        self.proj_state = init_lecun_normal(self.proj_state)
        nn.init.zeros_(self.proj_state.bias)
        self.proj_state2 = init_lecun_normal(self.proj_state2)
        nn.init.zeros_(self.proj_state2.bias)

    def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=False, motif_mask=None):
        """
        input:
            seq: query sequence (B, L)
            msa: seed MSA embeddings (B, N, L, d_msa)
            msa_full: extra MSA embeddings (B, N, L, d_msa_full)
            pair: initial residue pair embeddings (B, L, L, d_pair)
            xyz_in: initial BB coordinates (B, L, n_atom, 3)
            state: initial state features containing mixture of query seq, sidechain, accuracy info (B, L, d_state)
            idx: residue index
            motif_mask: bool tensor, True if motif position that is frozen, else False (L,)
        """

        B, L = pair.shape[:2]

        if motif_mask is None:
            motif_mask = torch.zeros(L).bool()

        R_in = torch.eye(3, device=xyz_in.device).reshape(1,1,3,3).expand(B, L, -1, -1)
        T_in = xyz_in[:,:,1].clone()
        xyz_in = xyz_in - T_in.unsqueeze(-2)

        state = self.proj_state(state)

        R_s = list()
        T_s = list()
        alpha_s = list()
        for i_m in range(self.n_extra_block):
            R_in = R_in.detach()  # detach rotation (for stability)
            T_in = T_in.detach()
            # Get current BB structure
            xyz = einsum('bnij,bnaj->bnai', R_in, xyz_in) + T_in.unsqueeze(-2)

            msa_full, pair, R_in, T_in, state, alpha = self.extra_block[i_m](msa_full,
                                                                             pair,
                                                                             R_in,
                                                                             T_in,
                                                                             xyz,
                                                                             state,
                                                                             idx,
                                                                             motif_mask=motif_mask,
                                                                             use_checkpoint=use_checkpoint)
            R_s.append(R_in)
            T_s.append(T_in)
            alpha_s.append(alpha)

        for i_m in range(self.n_main_block):
            R_in = R_in.detach()
            T_in = T_in.detach()
            # Get current BB structure
            xyz = einsum('bnij,bnaj->bnai', R_in, xyz_in) + T_in.unsqueeze(-2)

            msa, pair, R_in, T_in, state, alpha = self.main_block[i_m](msa,
                                                                       pair,
                                                                       R_in,
                                                                       T_in,
                                                                       xyz,
                                                                       state,
                                                                       idx,
                                                                       motif_mask=motif_mask,
                                                                       use_checkpoint=use_checkpoint)
            R_s.append(R_in)
            T_s.append(T_in)
            alpha_s.append(alpha)

        state = self.proj_state2(state)
        for i_m in range(self.n_ref_block):
            R_in = R_in.detach()
            T_in = T_in.detach()
            xyz = einsum('bnij,bnaj->bnai', R_in, xyz_in) + T_in.unsqueeze(-2)
            R_in, T_in, state, alpha = self.str_refiner(msa,
                                                        pair,
                                                        R_in,
                                                        T_in,
                                                        xyz,
                                                        state,
                                                        idx,
                                                        top_k=64,
                                                        motif_mask=motif_mask)
            R_s.append(R_in)
            T_s.append(T_in)
            alpha_s.append(alpha)

        R_s = torch.stack(R_s, dim=0)
        T_s = torch.stack(T_s, dim=0)
        alpha_s = torch.stack(alpha_s, dim=0)

        return msa, pair, R_s, T_s, alpha_s, state

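IterativeSimulator keeps the backbone in residue-local coordinates (xyz_in centered on each Ca) and re-applies the running rigid transform before every block, xyz = R_in @ xyz_local + T_in, while each block composes Ri = delRi @ R_in and Ti = T_in + delTi. A standalone sketch of this frame bookkeeping with toy tensors (not part of the package):

import torch
from torch import einsum

B, L, n_atom = 1, 6, 3
xyz_in = torch.randn(B, L, n_atom, 3)            # toy backbone coordinates
T_in = xyz_in[:, :, 1].clone()                   # Ca positions as frame origins
xyz_local = xyz_in - T_in.unsqueeze(-2)          # residue-local coordinates
R_in = torch.eye(3).reshape(1, 1, 3, 3).expand(B, L, 3, 3)

# one "block" proposes small per-residue updates (delRi, delTi)
delRi = torch.eye(3).reshape(1, 1, 3, 3).expand(B, L, 3, 3)   # identity update here
delTi = 0.1 * torch.randn(B, L, 3)

R_new = einsum('bnij,bnjk->bnik', delRi, R_in)   # compose rotations
T_new = T_in + delTi                             # compose translations

xyz = einsum('bnij,bnaj->bnai', R_new, xyz_local) + T_new.unsqueeze(-2)
print(xyz.shape)                                 # torch.Size([1, 6, 3, 3])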
rfdiffusion/__init__.py ADDED Empty file
rfdiffusion/__pycache__/Attention_module.cpython-310.pyc ADDED Binary file (10.7 kB)
rfdiffusion/__pycache__/Attention_module.cpython-311.pyc ADDED Binary file (27 kB)
rfdiffusion/__pycache__/Attention_module.cpython-39.pyc ADDED Binary file (11.3 kB)
rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-310.pyc ADDED Binary file (3.59 kB)
rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-311.pyc ADDED Binary file (7.34 kB)
rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-39.pyc ADDED Binary file (3.76 kB)
rfdiffusion/__pycache__/Embeddings.cpython-310.pyc ADDED Binary file (9.5 kB)
rfdiffusion/__pycache__/Embeddings.cpython-311.pyc ADDED Binary file (20 kB)
rfdiffusion/__pycache__/Embeddings.cpython-39.pyc ADDED Binary file (9.54 kB)
rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-310.pyc ADDED Binary file (3.68 kB)
rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-311.pyc ADDED Binary file (7.09 kB)
rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-39.pyc ADDED Binary file (3.67 kB)
rfdiffusion/__pycache__/SE3_network.cpython-310.pyc ADDED Binary file (2.29 kB)
rfdiffusion/__pycache__/SE3_network.cpython-311.pyc ADDED Binary file (4.13 kB)
rfdiffusion/__pycache__/SE3_network.cpython-39.pyc ADDED Binary file (2.28 kB)
rfdiffusion/__pycache__/Track_module.cpython-310.pyc ADDED Binary file (14.1 kB)
rfdiffusion/__pycache__/Track_module.cpython-311.pyc ADDED Binary file (31.4 kB)
rfdiffusion/__pycache__/Track_module.cpython-39.pyc ADDED Binary file (14.1 kB)
rfdiffusion/__pycache__/__init__.cpython-310.pyc ADDED Binary file (154 Bytes)
rfdiffusion/__pycache__/__init__.cpython-311.pyc ADDED Binary file (170 Bytes)
rfdiffusion/__pycache__/__init__.cpython-39.pyc ADDED Binary file (152 Bytes)
rfdiffusion/__pycache__/chemical.cpython-310.pyc ADDED Binary file (20.5 kB)
rfdiffusion/__pycache__/chemical.cpython-311.pyc ADDED Binary file (24.6 kB)
rfdiffusion/__pycache__/chemical.cpython-39.pyc ADDED Binary file (20.5 kB)
rfdiffusion/__pycache__/contigs.cpython-310.pyc ADDED Binary file (10.2 kB)
rfdiffusion/__pycache__/contigs.cpython-311.pyc ADDED Binary file (23.2 kB)
rfdiffusion/__pycache__/contigs.cpython-39.pyc ADDED Binary file (10.5 kB)
rfdiffusion/__pycache__/diffusion.cpython-310.pyc ADDED Binary file (19.4 kB)
rfdiffusion/__pycache__/diffusion.cpython-311.pyc ADDED Binary file (31.4 kB)
rfdiffusion/__pycache__/diffusion.cpython-39.pyc ADDED Binary file (19.5 kB)
rfdiffusion/__pycache__/igso3.cpython-310.pyc ADDED Binary file (4.72 kB)
rfdiffusion/__pycache__/igso3.cpython-311.pyc ADDED Binary file (8 kB)
rfdiffusion/__pycache__/igso3.cpython-39.pyc ADDED Binary file (4.76 kB)
rfdiffusion/__pycache__/kinematics.cpython-310.pyc ADDED Binary file (9.41 kB)
rfdiffusion/__pycache__/kinematics.cpython-311.pyc ADDED Binary file (18.8 kB)
rfdiffusion/__pycache__/kinematics.cpython-39.pyc ADDED Binary file (9.42 kB)
rfdiffusion/__pycache__/model_input_logger.cpython-311.pyc ADDED Binary file (4.71 kB)
rfdiffusion/__pycache__/model_input_logger.cpython-39.pyc ADDED Binary file (2.62 kB)