GlandVergil committed
Commit 0793996
1 Parent(s): 8eb0d85

Upload 95 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. rfdiffusion.egg-info/PKG-INFO +7 -0
  2. rfdiffusion.egg-info/SOURCES.txt +34 -0
  3. rfdiffusion.egg-info/dependency_links.txt +1 -0
  4. rfdiffusion.egg-info/requires.txt +2 -0
  5. rfdiffusion.egg-info/top_level.txt +1 -0
  6. rfdiffusion/Attention_module.py +404 -0
  7. rfdiffusion/AuxiliaryPredictor.py +92 -0
  8. rfdiffusion/Embeddings.py +303 -0
  9. rfdiffusion/RoseTTAFoldModel.py +140 -0
  10. rfdiffusion/SE3_network.py +83 -0
  11. rfdiffusion/Track_module.py +474 -0
  12. rfdiffusion/__init__.py +0 -0
  13. rfdiffusion/__pycache__/Attention_module.cpython-310.pyc +0 -0
  14. rfdiffusion/__pycache__/Attention_module.cpython-311.pyc +0 -0
  15. rfdiffusion/__pycache__/Attention_module.cpython-39.pyc +0 -0
  16. rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-310.pyc +0 -0
  17. rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-311.pyc +0 -0
  18. rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-39.pyc +0 -0
  19. rfdiffusion/__pycache__/Embeddings.cpython-310.pyc +0 -0
  20. rfdiffusion/__pycache__/Embeddings.cpython-311.pyc +0 -0
  21. rfdiffusion/__pycache__/Embeddings.cpython-39.pyc +0 -0
  22. rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-310.pyc +0 -0
  23. rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-311.pyc +0 -0
  24. rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-39.pyc +0 -0
  25. rfdiffusion/__pycache__/SE3_network.cpython-310.pyc +0 -0
  26. rfdiffusion/__pycache__/SE3_network.cpython-311.pyc +0 -0
  27. rfdiffusion/__pycache__/SE3_network.cpython-39.pyc +0 -0
  28. rfdiffusion/__pycache__/Track_module.cpython-310.pyc +0 -0
  29. rfdiffusion/__pycache__/Track_module.cpython-311.pyc +0 -0
  30. rfdiffusion/__pycache__/Track_module.cpython-39.pyc +0 -0
  31. rfdiffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  32. rfdiffusion/__pycache__/__init__.cpython-311.pyc +0 -0
  33. rfdiffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  34. rfdiffusion/__pycache__/chemical.cpython-310.pyc +0 -0
  35. rfdiffusion/__pycache__/chemical.cpython-311.pyc +0 -0
  36. rfdiffusion/__pycache__/chemical.cpython-39.pyc +0 -0
  37. rfdiffusion/__pycache__/contigs.cpython-310.pyc +0 -0
  38. rfdiffusion/__pycache__/contigs.cpython-311.pyc +0 -0
  39. rfdiffusion/__pycache__/contigs.cpython-39.pyc +0 -0
  40. rfdiffusion/__pycache__/diffusion.cpython-310.pyc +0 -0
  41. rfdiffusion/__pycache__/diffusion.cpython-311.pyc +0 -0
  42. rfdiffusion/__pycache__/diffusion.cpython-39.pyc +0 -0
  43. rfdiffusion/__pycache__/igso3.cpython-310.pyc +0 -0
  44. rfdiffusion/__pycache__/igso3.cpython-311.pyc +0 -0
  45. rfdiffusion/__pycache__/igso3.cpython-39.pyc +0 -0
  46. rfdiffusion/__pycache__/kinematics.cpython-310.pyc +0 -0
  47. rfdiffusion/__pycache__/kinematics.cpython-311.pyc +0 -0
  48. rfdiffusion/__pycache__/kinematics.cpython-39.pyc +0 -0
  49. rfdiffusion/__pycache__/model_input_logger.cpython-311.pyc +0 -0
  50. rfdiffusion/__pycache__/model_input_logger.cpython-39.pyc +0 -0
rfdiffusion.egg-info/PKG-INFO ADDED
@@ -0,0 +1,7 @@
+Metadata-Version: 2.1
+Name: rfdiffusion
+Version: 1.1.0
+Summary: RFdiffusion is an open source method for protein structure generation.
+Home-page: https://github.com/RosettaCommons/RFdiffusion
+Author: Rosetta Commons
+License-File: LICENSE
rfdiffusion.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,34 @@
+LICENSE
+README.md
+setup.py
+rfdiffusion/Attention_module.py
+rfdiffusion/AuxiliaryPredictor.py
+rfdiffusion/Embeddings.py
+rfdiffusion/RoseTTAFoldModel.py
+rfdiffusion/SE3_network.py
+rfdiffusion/Track_module.py
+rfdiffusion/__init__.py
+rfdiffusion/chemical.py
+rfdiffusion/contigs.py
+rfdiffusion/coords6d.py
+rfdiffusion/diffusion.py
+rfdiffusion/igso3.py
+rfdiffusion/kinematics.py
+rfdiffusion/model_input_logger.py
+rfdiffusion/scoring.py
+rfdiffusion/util.py
+rfdiffusion/util_module.py
+rfdiffusion.egg-info/PKG-INFO
+rfdiffusion.egg-info/SOURCES.txt
+rfdiffusion.egg-info/dependency_links.txt
+rfdiffusion.egg-info/requires.txt
+rfdiffusion.egg-info/top_level.txt
+rfdiffusion/inference/__init__.py
+rfdiffusion/inference/model_runners.py
+rfdiffusion/inference/symmetry.py
+rfdiffusion/inference/utils.py
+rfdiffusion/potentials/__init__.py
+rfdiffusion/potentials/manager.py
+rfdiffusion/potentials/potentials.py
+scripts/run_inference.py
+tests/test_diffusion.py
rfdiffusion.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
rfdiffusion.egg-info/requires.txt ADDED
@@ -0,0 +1,2 @@
+torch
+se3-transformer
rfdiffusion.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+rfdiffusion
rfdiffusion/Attention_module.py ADDED
@@ -0,0 +1,404 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from opt_einsum import contract as einsum
+from rfdiffusion.util_module import init_lecun_normal
+
+class FeedForwardLayer(nn.Module):
+    def __init__(self, d_model, r_ff, p_drop=0.1):
+        super(FeedForwardLayer, self).__init__()
+        self.norm = nn.LayerNorm(d_model)
+        self.linear1 = nn.Linear(d_model, d_model*r_ff)
+        self.dropout = nn.Dropout(p_drop)
+        self.linear2 = nn.Linear(d_model*r_ff, d_model)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        # initialize linear layer right before ReLU: He initializer (kaiming normal)
+        nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
+        nn.init.zeros_(self.linear1.bias)
+
+        # initialize linear layer right before residual connection: zero initialize
+        nn.init.zeros_(self.linear2.weight)
+        nn.init.zeros_(self.linear2.bias)
+
+    def forward(self, src):
+        src = self.norm(src)
+        src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
+        return src
+
+class Attention(nn.Module):
+    # calculate multi-head attention
+    def __init__(self, d_query, d_key, n_head, d_hidden, d_out):
+        super(Attention, self).__init__()
+        self.h = n_head
+        self.dim = d_hidden
+        #
+        self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
+        self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
+        self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
+        #
+        self.to_out = nn.Linear(n_head*d_hidden, d_out)
+        self.scaling = 1/math.sqrt(d_hidden)
+        #
+        # initialize all parameters properly
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        # query/key/value projection: Glorot uniform / Xavier uniform
+        nn.init.xavier_uniform_(self.to_q.weight)
+        nn.init.xavier_uniform_(self.to_k.weight)
+        nn.init.xavier_uniform_(self.to_v.weight)
+
+        # to_out (right before the residual connection): zero initialize, so the residual update is the identity at the beginning
+        nn.init.zeros_(self.to_out.weight)
+        nn.init.zeros_(self.to_out.bias)
+
+    def forward(self, query, key, value):
+        B, Q = query.shape[:2]
+        B, K = key.shape[:2]
+        #
+        query = self.to_q(query).reshape(B, Q, self.h, self.dim)
+        key = self.to_k(key).reshape(B, K, self.h, self.dim)
+        value = self.to_v(value).reshape(B, K, self.h, self.dim)
+        #
+        query = query * self.scaling
+        attn = einsum('bqhd,bkhd->bhqk', query, key)
+        attn = F.softmax(attn, dim=-1)
+        #
+        out = einsum('bhqk,bkhd->bqhd', attn, value)
+        out = out.reshape(B, Q, self.h*self.dim)
+        #
+        out = self.to_out(out)
+
+        return out
+
+class AttentionWithBias(nn.Module):
+    def __init__(self, d_in=256, d_bias=128, n_head=8, d_hidden=32):
+        super(AttentionWithBias, self).__init__()
+        self.norm_in = nn.LayerNorm(d_in)
+        self.norm_bias = nn.LayerNorm(d_bias)
+        #
+        self.to_q = nn.Linear(d_in, n_head*d_hidden, bias=False)
+        self.to_k = nn.Linear(d_in, n_head*d_hidden, bias=False)
+        self.to_v = nn.Linear(d_in, n_head*d_hidden, bias=False)
+        self.to_b = nn.Linear(d_bias, n_head, bias=False)
+        self.to_g = nn.Linear(d_in, n_head*d_hidden)
+        self.to_out = nn.Linear(n_head*d_hidden, d_in)
+
+        self.scaling = 1/math.sqrt(d_hidden)
+        self.h = n_head
+        self.dim = d_hidden
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        # query/key/value projection: Glorot uniform / Xavier uniform
+        nn.init.xavier_uniform_(self.to_q.weight)
+        nn.init.xavier_uniform_(self.to_k.weight)
+        nn.init.xavier_uniform_(self.to_v.weight)
+
+        # bias: normal distribution
+        self.to_b = init_lecun_normal(self.to_b)
+
+        # gating: zero weights, one biases (mostly open gate at the beginning)
+        nn.init.zeros_(self.to_g.weight)
+        nn.init.ones_(self.to_g.bias)
+
+        # to_out (right before the residual connection): zero initialize, so the residual update is the identity at the beginning
+        nn.init.zeros_(self.to_out.weight)
+        nn.init.zeros_(self.to_out.bias)
+
+    def forward(self, x, bias):
+        B, L = x.shape[:2]
+        #
+        x = self.norm_in(x)
+        bias = self.norm_bias(bias)
+        #
+        query = self.to_q(x).reshape(B, L, self.h, self.dim)
+        key = self.to_k(x).reshape(B, L, self.h, self.dim)
+        value = self.to_v(x).reshape(B, L, self.h, self.dim)
+        bias = self.to_b(bias) # (B, L, L, h)
+        gate = torch.sigmoid(self.to_g(x))
+        #
+        key = key * self.scaling
+        attn = einsum('bqhd,bkhd->bqkh', query, key)
+        attn = attn + bias
+        attn = F.softmax(attn, dim=-2)
+        #
+        out = einsum('bqkh,bkhd->bqhd', attn, value).reshape(B, L, -1)
+        out = gate * out
+        #
+        out = self.to_out(out)
+        return out
+
+# MSA Attention (row/column) from the AlphaFold architecture
+class SequenceWeight(nn.Module):
+    def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
+        super(SequenceWeight, self).__init__()
+        self.h = n_head
+        self.dim = d_hidden
+        self.scale = 1.0 / math.sqrt(self.dim)
+
+        self.to_query = nn.Linear(d_msa, n_head*d_hidden)
+        self.to_key = nn.Linear(d_msa, n_head*d_hidden)
+        self.dropout = nn.Dropout(p_drop)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        # query/key projection: Glorot uniform / Xavier uniform
+        nn.init.xavier_uniform_(self.to_query.weight)
+        nn.init.xavier_uniform_(self.to_key.weight)
+
+    def forward(self, msa):
+        B, N, L = msa.shape[:3]
+
+        tar_seq = msa[:,0]
+
+        q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim)
+        k = self.to_key(msa).view(B, N, L, self.h, self.dim)
+
+        q = q * self.scale
+        attn = einsum('bqihd,bkihd->bkihq', q, k)
+        attn = F.softmax(attn, dim=1)
+        return self.dropout(attn)
+
+class MSARowAttentionWithBias(nn.Module):
+    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
+        super(MSARowAttentionWithBias, self).__init__()
+        self.norm_msa = nn.LayerNorm(d_msa)
+        self.norm_pair = nn.LayerNorm(d_pair)
+        #
+        self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
+        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+        self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+        self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+        self.to_b = nn.Linear(d_pair, n_head, bias=False)
+        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
+        self.to_out = nn.Linear(n_head*d_hidden, d_msa)
+
+        self.scaling = 1/math.sqrt(d_hidden)
+        self.h = n_head
+        self.dim = d_hidden
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        # query/key/value projection: Glorot uniform / Xavier uniform
+        nn.init.xavier_uniform_(self.to_q.weight)
+        nn.init.xavier_uniform_(self.to_k.weight)
+        nn.init.xavier_uniform_(self.to_v.weight)
+
+        # bias: normal distribution
+        self.to_b = init_lecun_normal(self.to_b)
+
+        # gating: zero weights, one biases (mostly open gate at the beginning)
+        nn.init.zeros_(self.to_g.weight)
+        nn.init.ones_(self.to_g.bias)
+
+        # to_out (right before the residual connection): zero initialize, so the residual update is the identity at the beginning
+        nn.init.zeros_(self.to_out.weight)
+        nn.init.zeros_(self.to_out.bias)
+
+    def forward(self, msa, pair): # TODO: make this tied-attention
+        B, N, L = msa.shape[:3]
+        #
+        msa = self.norm_msa(msa)
+        pair = self.norm_pair(pair)
+        #
+        seq_weight = self.seq_weight(msa) # (B, N, L, h, 1)
+        query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
+        key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
+        value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
+        bias = self.to_b(pair) # (B, L, L, h)
+        gate = torch.sigmoid(self.to_g(msa))
+        #
+        query = query * seq_weight.expand(-1, -1, -1, -1, self.dim)
+        key = key * self.scaling
+        attn = einsum('bsqhd,bskhd->bqkh', query, key)
+        attn = attn + bias
+        attn = F.softmax(attn, dim=-2)
+        #
+        out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1)
+        out = gate * out
+        #
+        out = self.to_out(out)
+        return out
+
+class MSAColAttention(nn.Module):
+    def __init__(self, d_msa=256, n_head=8, d_hidden=32):
+        super(MSAColAttention, self).__init__()
+        self.norm_msa = nn.LayerNorm(d_msa)
+        #
+        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+        self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+        self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
+        self.to_out = nn.Linear(n_head*d_hidden, d_msa)
+
+        self.scaling = 1/math.sqrt(d_hidden)
+        self.h = n_head
+        self.dim = d_hidden
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        # query/key/value projection: Glorot uniform / Xavier uniform
+        nn.init.xavier_uniform_(self.to_q.weight)
+        nn.init.xavier_uniform_(self.to_k.weight)
+        nn.init.xavier_uniform_(self.to_v.weight)
+
+        # gating: zero weights, one biases (mostly open gate at the beginning)
+        nn.init.zeros_(self.to_g.weight)
+        nn.init.ones_(self.to_g.bias)
+
+        # to_out (right before the residual connection): zero initialize, so the residual update is the identity at the beginning
+        nn.init.zeros_(self.to_out.weight)
+        nn.init.zeros_(self.to_out.bias)
+
+    def forward(self, msa):
+        B, N, L = msa.shape[:3]
+        #
+        msa = self.norm_msa(msa)
+        #
+        query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
+        key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
+        value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
+        gate = torch.sigmoid(self.to_g(msa))
+        #
+        query = query * self.scaling
+        attn = einsum('bqihd,bkihd->bihqk', query, key)
+        attn = F.softmax(attn, dim=-1)
+        #
+        out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1)
+        out = gate * out
+        #
+        out = self.to_out(out)
+        return out
+
+class MSAColGlobalAttention(nn.Module):
+    def __init__(self, d_msa=64, n_head=8, d_hidden=8):
+        super(MSAColGlobalAttention, self).__init__()
+        self.norm_msa = nn.LayerNorm(d_msa)
+        #
+        self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
+        self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
+        self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
+        self.to_g = nn.Linear(d_msa, n_head*d_hidden)
+        self.to_out = nn.Linear(n_head*d_hidden, d_msa)
+
+        self.scaling = 1/math.sqrt(d_hidden)
+        self.h = n_head
+        self.dim = d_hidden
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        # query/key/value projection: Glorot uniform / Xavier uniform
+        nn.init.xavier_uniform_(self.to_q.weight)
+        nn.init.xavier_uniform_(self.to_k.weight)
+        nn.init.xavier_uniform_(self.to_v.weight)
+
+        # gating: zero weights, one biases (mostly open gate at the beginning)
+        nn.init.zeros_(self.to_g.weight)
+        nn.init.ones_(self.to_g.bias)
+
+        # to_out (right before the residual connection): zero initialize, so the residual update is the identity at the beginning
+        nn.init.zeros_(self.to_out.weight)
+        nn.init.zeros_(self.to_out.bias)
+
+    def forward(self, msa):
+        B, N, L = msa.shape[:3]
+        #
+        msa = self.norm_msa(msa)
+        #
+        query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
+        query = query.mean(dim=1) # (B, L, h, dim)
+        key = self.to_k(msa) # (B, N, L, dim)
+        value = self.to_v(msa) # (B, N, L, dim)
+        gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim)
+        #
+        query = query * self.scaling
+        attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N)
+        attn = F.softmax(attn, dim=-1)
+        #
+        out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim)
+        out = gate * out # (B, N, L, h*dim)
+        #
+        out = self.to_out(out)
+        return out
+
+# Instead of triangle attention, use tied axial attention with bias from coordinates
+class BiasedAxialAttention(nn.Module):
+    def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
+        super(BiasedAxialAttention, self).__init__()
+        #
+        self.is_row = is_row
+        self.norm_pair = nn.LayerNorm(d_pair)
+        self.norm_bias = nn.LayerNorm(d_bias)
+
+        self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
+        self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
+        self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
+        self.to_b = nn.Linear(d_bias, n_head, bias=False)
+        self.to_g = nn.Linear(d_pair, n_head*d_hidden)
+        self.to_out = nn.Linear(n_head*d_hidden, d_pair)
+
+        self.scaling = 1/math.sqrt(d_hidden)
+        self.h = n_head
+        self.dim = d_hidden
+
+        # initialize all parameters properly
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        # query/key/value projection: Glorot uniform / Xavier uniform
+        nn.init.xavier_uniform_(self.to_q.weight)
+        nn.init.xavier_uniform_(self.to_k.weight)
+        nn.init.xavier_uniform_(self.to_v.weight)
+
+        # bias: normal distribution
+        self.to_b = init_lecun_normal(self.to_b)
+
+        # gating: zero weights, one biases (mostly open gate at the beginning)
+        nn.init.zeros_(self.to_g.weight)
+        nn.init.ones_(self.to_g.bias)
+
+        # to_out (right before the residual connection): zero initialize, so the residual update is the identity at the beginning
+        nn.init.zeros_(self.to_out.weight)
+        nn.init.zeros_(self.to_out.bias)
+
+    def forward(self, pair, bias):
+        # pair: (B, L, L, d_pair)
+        B, L = pair.shape[:2]
+
+        if self.is_row:
+            pair = pair.permute(0,2,1,3)
+            bias = bias.permute(0,2,1,3)
+
+        pair = self.norm_pair(pair)
+        bias = self.norm_bias(bias)
+
+        query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
+        key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
+        value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
+        bias = self.to_b(bias) # (B, L, L, h)
+        gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
+
+        query = query * self.scaling
+        key = key / math.sqrt(L) # normalize for tied attention
+        attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
+        attn = attn + bias # apply bias
+        attn = F.softmax(attn, dim=-2) # (B, L, L, h)
+
+        out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1)
+        out = gate * out
+
+        out = self.to_out(out)
+        if self.is_row:
+            out = out.permute(0,2,1,3)
+        return out
+
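A note on the `Attention` module above: with `to_out` zero-initialized, every block starts as an identity residual update, and the rest is standard scaled dot-product attention, `softmax(q·k / sqrt(d))·v`. A minimal dependency-free sketch of that core step (single head, no learned projections; `softmax` and `scaled_dot_attention` are hypothetical helper names, not code from this commit):

```python
import math

def softmax(xs):
    # numerically stable softmax over a list of floats
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def scaled_dot_attention(query, keys, values, d_hidden):
    # query: (d,), keys/values: (K, d) -- mirrors attn = softmax(q.k / sqrt(d)) . v
    scaling = 1 / math.sqrt(d_hidden)
    scores = [scaling * sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    d_out = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(d_out)]

# with identical keys, the softmax weights are uniform, so the output
# is the plain average of the values: [1.0, 1.0] here
out = scaled_dot_attention([1.0, 0.0],
                           [[1.0, 0.0], [1.0, 0.0]],
                           [[2.0, 0.0], [0.0, 2.0]], 2)
```

The torch version in the diff does the same thing per head via `einsum('bqhd,bkhd->bhqk', ...)`, just batched over `B`, `Q`, `K`, and `h`.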
rfdiffusion/AuxiliaryPredictor.py ADDED
@@ -0,0 +1,92 @@
+import torch
+import torch.nn as nn
+
+class DistanceNetwork(nn.Module):
+    def __init__(self, n_feat, p_drop=0.1):
+        super(DistanceNetwork, self).__init__()
+        #
+        self.proj_symm = nn.Linear(n_feat, 37*2)
+        self.proj_asymm = nn.Linear(n_feat, 37+19)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        # initialize linear layer for final logit prediction
+        nn.init.zeros_(self.proj_symm.weight)
+        nn.init.zeros_(self.proj_asymm.weight)
+        nn.init.zeros_(self.proj_symm.bias)
+        nn.init.zeros_(self.proj_asymm.bias)
+
+    def forward(self, x):
+        # input: pair info (B, L, L, C)
+
+        # predict theta, phi (non-symmetric)
+        logits_asymm = self.proj_asymm(x)
+        logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2)
+        logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2)
+
+        # predict dist, omega (symmetric)
+        logits_symm = self.proj_symm(x)
+        logits_symm = logits_symm + logits_symm.permute(0,2,1,3)
+        logits_dist = logits_symm[:,:,:,:37].permute(0,3,1,2)
+        logits_omega = logits_symm[:,:,:,37:].permute(0,3,1,2)
+
+        return logits_dist, logits_omega, logits_theta, logits_phi
+
+class MaskedTokenNetwork(nn.Module):
+    def __init__(self, n_feat):
+        super(MaskedTokenNetwork, self).__init__()
+        self.proj = nn.Linear(n_feat, 21)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        nn.init.zeros_(self.proj.weight)
+        nn.init.zeros_(self.proj.bias)
+
+    def forward(self, x):
+        B, N, L = x.shape[:3]
+        logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L)
+
+        return logits
+
+class LDDTNetwork(nn.Module):
+    def __init__(self, n_feat, n_bin_lddt=50):
+        super(LDDTNetwork, self).__init__()
+        self.proj = nn.Linear(n_feat, n_bin_lddt)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        nn.init.zeros_(self.proj.weight)
+        nn.init.zeros_(self.proj.bias)
+
+    def forward(self, x):
+        logits = self.proj(x) # (B, L, 50)
+
+        return logits.permute(0,2,1)
+
+class ExpResolvedNetwork(nn.Module):
+    def __init__(self, d_msa, d_state, p_drop=0.1):
+        super(ExpResolvedNetwork, self).__init__()
+        self.norm_msa = nn.LayerNorm(d_msa)
+        self.norm_state = nn.LayerNorm(d_state)
+        self.proj = nn.Linear(d_msa+d_state, 1)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        nn.init.zeros_(self.proj.weight)
+        nn.init.zeros_(self.proj.bias)
+
+    def forward(self, seq, state):
+        B, L = seq.shape[:2]
+
+        seq = self.norm_msa(seq)
+        state = self.norm_state(state)
+        feat = torch.cat((seq, state), dim=-1)
+        logits = self.proj(feat)
+        return logits.reshape(B, L)
+
+
+
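In `DistanceNetwork` above, distance and omega logits are symmetrized with `logits_symm + logits_symm.permute(0,2,1,3)`, i.e. the (i, j) and (j, i) entries are summed so residue pairs get identical logits in both orders, while theta and phi stay direction-dependent. A dependency-free sketch of that symmetrization for a single 2D matrix (`symmetrize` is a hypothetical helper name, not code from this commit):

```python
def symmetrize(m):
    # mirrors logits_symm = logits_symm + logits_symm.permute(0, 2, 1, 3):
    # summing the (i, j) and (j, i) entries yields a symmetric matrix
    n = len(m)
    return [[m[i][j] + m[j][i] for j in range(n)] for i in range(n)]

# the off-diagonal entries 1.0 and 3.0 both become 1.0 + 3.0 = 4.0
s = symmetrize([[0.0, 1.0], [3.0, 0.0]])
```

The torch version just applies the same operation over the channel dimension for every pair (i, j) in the (B, L, L, C) tensor.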
rfdiffusion/Embeddings.py ADDED
@@ -0,0 +1,303 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from opt_einsum import contract as einsum
+import torch.utils.checkpoint as checkpoint
+from rfdiffusion.util import get_tips
+from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal
+from rfdiffusion.Attention_module import Attention, FeedForwardLayer, AttentionWithBias
+from rfdiffusion.Track_module import PairStr2Pair
+import math
+
+# Module contains classes and functions to generate initial embeddings
+
+class PositionalEncoding2D(nn.Module):
+    # Add relative positional encoding to pair features
+    def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
+        super(PositionalEncoding2D, self).__init__()
+        self.minpos = minpos
+        self.maxpos = maxpos
+        self.nbin = abs(minpos)+maxpos+1
+        self.emb = nn.Embedding(self.nbin, d_model)
+        self.drop = nn.Dropout(p_drop)
+
+    def forward(self, x, idx):
+        bins = torch.arange(self.minpos, self.maxpos, device=x.device)
+        seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
+        #
+        ib = torch.bucketize(seqsep, bins).long() # (B, L, L)
+        emb = self.emb(ib) # (B, L, L, d_model)
+        x = x + emb # add relative positional encoding
+        return self.drop(x)
+
+class MSA_emb(nn.Module):
+    # Get initial seed MSA embedding
+    def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=22+22+2+2,
+                 minpos=-32, maxpos=32, p_drop=0.1, input_seq_onehot=False):
+        super(MSA_emb, self).__init__()
+        self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
+        self.emb_q = nn.Embedding(22, d_msa) # embedding for query sequence -- used for MSA embedding
+        self.emb_left = nn.Embedding(22, d_pair) # embedding for query sequence -- used for pair embedding
+        self.emb_right = nn.Embedding(22, d_pair) # embedding for query sequence -- used for pair embedding
+        self.emb_state = nn.Embedding(22, d_state)
+        self.drop = nn.Dropout(p_drop)
+        self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos, p_drop=p_drop)
+
+        self.input_seq_onehot = input_seq_onehot
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        self.emb = init_lecun_normal(self.emb)
+        self.emb_q = init_lecun_normal(self.emb_q)
+        self.emb_left = init_lecun_normal(self.emb_left)
+        self.emb_right = init_lecun_normal(self.emb_right)
+        self.emb_state = init_lecun_normal(self.emb_state)
+
+        nn.init.zeros_(self.emb.bias)
+
+    def forward(self, msa, seq, idx):
+        # Inputs:
+        #   - msa: Input MSA (B, N, L, d_init)
+        #   - seq: Input Sequence (B, L)
+        #   - idx: Residue index
+        # Outputs:
+        #   - msa: Initial MSA embedding (B, N, L, d_msa)
+        #   - pair: Initial Pair embedding (B, L, L, d_pair)
+
+        N = msa.shape[1] # number of sequences in MSA
+
+        # msa embedding
+        msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding
+
+        # Sergey's one-hot trick
+        tmp = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
+
+        msa = msa + tmp.expand(-1, N, -1, -1) # add query embedding to MSA
+        msa = self.drop(msa)
+
+        # pair embedding
+        # Sergey's one-hot trick
+        left = (seq @ self.emb_left.weight)[:,None] # (B, 1, L, d_pair)
+        right = (seq @ self.emb_right.weight)[:,:,None] # (B, L, 1, d_pair)
+
+        pair = left + right # (B, L, L, d_pair)
+        pair = self.pos(pair, idx) # add relative position
+
+        # state embedding
+        # Sergey's one-hot trick
+        state = self.drop(seq @ self.emb_state.weight)
+        return msa, pair, state
+
+class Extra_emb(nn.Module):
+    # Get initial seed MSA embedding
+    def __init__(self, d_msa=256, d_init=22+1+2, p_drop=0.1, input_seq_onehot=False):
+        super(Extra_emb, self).__init__()
+        self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
+        self.emb_q = nn.Embedding(22, d_msa) # embedding for query sequence
+        self.drop = nn.Dropout(p_drop)
+
+        self.input_seq_onehot = input_seq_onehot
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        self.emb = init_lecun_normal(self.emb)
+        nn.init.zeros_(self.emb.bias)
+
+    def forward(self, msa, seq, idx):
+        # Inputs:
+        #   - msa: Input MSA (B, N, L, d_init)
+        #   - seq: Input Sequence (B, L)
+        #   - idx: Residue index
+        # Outputs:
+        #   - msa: Initial MSA embedding (B, N, L, d_msa)
+        N = msa.shape[1] # number of sequences in MSA
+        msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding
+
+        # Sergey's one-hot trick
+        seq = (seq @ self.emb_q.weight).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
+        msa = msa + seq.expand(-1, N, -1, -1) # add query embedding to MSA
+        return self.drop(msa)
+
+class TemplatePairStack(nn.Module):
+    # process template pairwise features
+    # use structure-biased attention
+    def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=16, p_drop=0.25):
+        super(TemplatePairStack, self).__init__()
+        self.n_block = n_block
+        proc_s = [PairStr2Pair(d_pair=d_templ, n_head=n_head, d_hidden=d_hidden, p_drop=p_drop) for i in range(n_block)]
+        self.block = nn.ModuleList(proc_s)
+        self.norm = nn.LayerNorm(d_templ)
+    def forward(self, templ, rbf_feat, use_checkpoint=False):
+        B, T, L = templ.shape[:3]
+        templ = templ.reshape(B*T, L, L, -1)
+
+        for i_block in range(self.n_block):
+            if use_checkpoint:
+                templ = checkpoint.checkpoint(create_custom_forward(self.block[i_block]), templ, rbf_feat)
+            else:
+                templ = self.block[i_block](templ, rbf_feat)
+        return self.norm(templ).reshape(B, T, L, L, -1)
+
+class TemplateTorsionStack(nn.Module):
+    def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=16, p_drop=0.15):
+        super(TemplateTorsionStack, self).__init__()
+        self.n_block = n_block
+        self.proj_pair = nn.Linear(d_templ+36, d_templ)
+        proc_s = [AttentionWithBias(d_in=d_templ, d_bias=d_templ,
+                                    n_head=n_head, d_hidden=d_hidden) for i in range(n_block)]
+        self.row_attn = nn.ModuleList(proc_s)
+        proc_s = [FeedForwardLayer(d_templ, 4, p_drop=p_drop) for i in range(n_block)]
+        self.ff = nn.ModuleList(proc_s)
+        self.norm = nn.LayerNorm(d_templ)
+
+    def reset_parameter(self):
+        self.proj_pair = init_lecun_normal(self.proj_pair)
+        nn.init.zeros_(self.proj_pair.bias)
+
+    def forward(self, tors, pair, rbf_feat, use_checkpoint=False):
+        B, T, L = tors.shape[:3]
+        tors = tors.reshape(B*T, L, -1)
+        pair = pair.reshape(B*T, L, L, -1)
+        pair = torch.cat((pair, rbf_feat), dim=-1)
+        pair = self.proj_pair(pair)
+
+        for i_block in range(self.n_block):
+            if use_checkpoint:
+                tors = tors + checkpoint.checkpoint(create_custom_forward(self.row_attn[i_block]), tors, pair)
+            else:
+                tors = tors + self.row_attn[i_block](tors, pair)
+            tors = tors + self.ff[i_block](tors)
+        return self.norm(tors).reshape(B, T, L, -1)
+
+class Templ_emb(nn.Module):
+    # Get template embedding
+    # Features are
+    #   t2d:
+    #     - 37 distogram bins + 6 orientations (43)
+    #     - mask (missing/unaligned) (1)
+    #   t1d:
+    #     - tiled AA sequence (20 standard aa + gap)
+    #     - confidence (1)
+    #     - contacting or not (1). NB this is added for the diffusion model. Used only in complex training examples -- 1 signifies that a residue in the non-diffused chain,
+    #       i.e. the context, is in contact with the diffused chain.
+    #
+    # Added extra t1d dimension for contacting or not
+    def __init__(self, d_t1d=21+1+1, d_t2d=43+1, d_tor=30, d_pair=128, d_state=32,
+                 n_block=2, d_templ=64,
+                 n_head=4, d_hidden=16, p_drop=0.25):
+        super(Templ_emb, self).__init__()
+        # process 2D features
+        self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ)
+        self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
+                                             d_hidden=d_hidden, p_drop=p_drop)
+
+        self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair)
+
+        # process torsion angles
+        self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
+        self.proj_t1d = nn.Linear(d_templ, d_templ)
+        #self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
+        #                                      d_hidden=d_hidden, p_drop=p_drop)
+        self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        self.emb = init_lecun_normal(self.emb)
+        nn.init.zeros_(self.emb.bias)
+
+        nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu')
+        nn.init.zeros_(self.emb_t1d.bias)
+
+        self.proj_t1d = init_lecun_normal(self.proj_t1d)
+        nn.init.zeros_(self.proj_t1d.bias)
+
+    def forward(self, t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=False):
+        # Input
+        #   - t1d: 1D template info (B, T, L, 23)
+        #   - t2d: 2D template info (B, T, L, L, 44)
+        B, T, L, _ = t1d.shape
+
+        # Prepare 2D template features
+        left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1)
+        right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1)
+        #
+        templ = torch.cat((t2d, left, right), -1) # (B, T, L, L, 90)
+        templ = self.emb(templ) # template features (B, T, L, L, d_templ)
+        # process each template's features
+        xyz_t = xyz_t.reshape(B*T, L, -1, 3)
+        rbf_feat = rbf(torch.cdist(xyz_t[:,:,1], xyz_t[:,:,1]))
232
+ templ = self.templ_stack(templ, rbf_feat, use_checkpoint=use_checkpoint) # (B, T, L,L, d_templ)
233
+
234
+ # Prepare 1D template torsion angle features
235
+ t1d = torch.cat((t1d, alpha_t), dim=-1) # (B, T, L, 23+30)
236
+
237
+ # process each template features
238
+ t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d)))
239
+
240
+ # mixing query state features to template state features
241
+ state = state.reshape(B*L, 1, -1)
242
+ t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1)
243
+ if use_checkpoint:
244
+ out = checkpoint.checkpoint(create_custom_forward(self.attn_tor), state, t1d, t1d)
245
+ out = out.reshape(B, L, -1)
246
+ else:
247
+ out = self.attn_tor(state, t1d, t1d).reshape(B, L, -1)
248
+ state = state.reshape(B, L, -1)
249
+ state = state + out
250
+
251
+ # mixing query pair features to template information (Template pointwise attention)
252
+ pair = pair.reshape(B*L*L, 1, -1)
253
+ templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1)
254
+ if use_checkpoint:
255
+ out = checkpoint.checkpoint(create_custom_forward(self.attn), pair, templ, templ)
256
+ out = out.reshape(B, L, L, -1)
257
+ else:
258
+ out = self.attn(pair, templ, templ).reshape(B, L, L, -1)
259
+ #
260
+ pair = pair.reshape(B, L, L, -1)
261
+ pair = pair + out
262
+
263
+ return pair, state
264
+
265
+ class Recycling(nn.Module):
266
+ def __init__(self, d_msa=256, d_pair=128, d_state=32):
267
+ super(Recycling, self).__init__()
268
+ self.proj_dist = nn.Linear(36+d_state*2, d_pair)
269
+ self.norm_state = nn.LayerNorm(d_state)
270
+ self.norm_pair = nn.LayerNorm(d_pair)
271
+ self.norm_msa = nn.LayerNorm(d_msa)
272
+
273
+ self.reset_parameter()
274
+
275
+ def reset_parameter(self):
276
+ self.proj_dist = init_lecun_normal(self.proj_dist)
277
+ nn.init.zeros_(self.proj_dist.bias)
278
+
279
+ def forward(self, seq, msa, pair, xyz, state):
280
+ B, L = pair.shape[:2]
281
+ state = self.norm_state(state)
282
+ #
283
+ left = state.unsqueeze(2).expand(-1,-1,L,-1)
284
+ right = state.unsqueeze(1).expand(-1,L,-1,-1)
285
+
286
+ # three anchor atoms
287
+ N = xyz[:,:,0]
288
+ Ca = xyz[:,:,1]
289
+ C = xyz[:,:,2]
290
+
291
+ # recreate Cb given N,Ca,C
292
+ b = Ca - N
293
+ c = C - Ca
294
+ a = torch.cross(b, c, dim=-1)
295
+ Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
296
+
297
+ dist = rbf(torch.cdist(Cb, Cb))
298
+ dist = torch.cat((dist, left, right), dim=-1)
299
+ dist = self.proj_dist(dist)
300
+ pair = dist + self.norm_pair(pair)
301
+ msa = self.norm_msa(msa)
302
+ return msa, pair, state
303
+
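`Recycling.forward` above rebuilds a virtual Cβ position from the three backbone anchors N, Cα, and C before computing Cβ-Cβ distance RBF features. A minimal pure-Python sketch of that reconstruction, using the same fixed coefficients (the helper name `virtual_cb` is mine, not from the source):

```python
# Rebuild a virtual C-beta from backbone N, CA, C, mirroring Recycling.forward:
#   b = CA - N, c = C - CA, a = b x c, Cb = -0.5827*a + 0.5680*b - 0.5407*c + CA
def cross(u, v):
    return (u[1]*v[2] - u[2]*v[1],
            u[2]*v[0] - u[0]*v[2],
            u[0]*v[1] - u[1]*v[0])

def virtual_cb(n, ca, c):
    b = tuple(ca[i] - n[i] for i in range(3))   # CA - N
    cc = tuple(c[i] - ca[i] for i in range(3))  # C - CA
    a = cross(b, cc)
    return tuple(-0.58273431*a[i] + 0.56802827*b[i] - 0.54067466*cc[i] + ca[i]
                 for i in range(3))

# Idealized local frame: CA at origin, C along +x, N in the xy-plane.
cb = virtual_cb((-0.572, 1.337, 0.0), (0.0, 0.0, 0.0), (1.517, 0.0, 0.0))
bond = sum(x*x for x in cb) ** 0.5  # CA-CB distance
```

With idealized backbone geometry this places the virtual Cβ roughly 1.5 Å from Cα, consistent with the standard Cα-Cβ bond length, and out of the N-Cα-C plane.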
rfdiffusion/RoseTTAFoldModel.py ADDED
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+from rfdiffusion.Embeddings import MSA_emb, Extra_emb, Templ_emb, Recycling
+from rfdiffusion.Track_module import IterativeSimulator
+from rfdiffusion.AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, ExpResolvedNetwork, LDDTNetwork
+from opt_einsum import contract as einsum
+
+class RoseTTAFoldModule(nn.Module):
+    def __init__(self,
+                 n_extra_block,
+                 n_main_block,
+                 n_ref_block,
+                 d_msa,
+                 d_msa_full,
+                 d_pair,
+                 d_templ,
+                 n_head_msa,
+                 n_head_pair,
+                 n_head_templ,
+                 d_hidden,
+                 d_hidden_templ,
+                 p_drop,
+                 d_t1d,
+                 d_t2d,
+                 T,                   # total timesteps (used in timestep emb)
+                 use_motif_timestep,  # Whether to have a distinct emb for motif
+                 freeze_track_motif,  # Whether to freeze updates to motif in track
+                 SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
+                 SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
+                 input_seq_onehot=False,  # For continuous vs. discrete sequence
+                 ):
+
+        super(RoseTTAFoldModule, self).__init__()
+
+        self.freeze_track_motif = freeze_track_motif
+
+        # Input Embeddings
+        d_state = SE3_param_topk['l0_out_features']
+        self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state,
+                                  p_drop=p_drop, input_seq_onehot=input_seq_onehot)  # Allowed to take onehot seq
+        self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=25,
+                                  p_drop=p_drop, input_seq_onehot=input_seq_onehot)  # Allowed to take onehot seq
+        self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state,
+                                   n_head=n_head_templ,
+                                   d_hidden=d_hidden_templ, p_drop=0.25, d_t1d=d_t1d, d_t2d=d_t2d)
+
+        # Update inputs with outputs from previous round
+        self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state)
+        #
+        self.simulator = IterativeSimulator(n_extra_block=n_extra_block,
+                                            n_main_block=n_main_block,
+                                            n_ref_block=n_ref_block,
+                                            d_msa=d_msa, d_msa_full=d_msa_full,
+                                            d_pair=d_pair, d_hidden=d_hidden,
+                                            n_head_msa=n_head_msa,
+                                            n_head_pair=n_head_pair,
+                                            SE3_param_full=SE3_param_full,
+                                            SE3_param_topk=SE3_param_topk,
+                                            p_drop=p_drop)
+        ##
+        self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
+        self.aa_pred = MaskedTokenNetwork(d_msa)
+        self.lddt_pred = LDDTNetwork(d_state)
+
+        self.exp_pred = ExpResolvedNetwork(d_msa, d_state)
+
+    def forward(self, msa_latent, msa_full, seq, xyz, idx, t,
+                t1d=None, t2d=None, xyz_t=None, alpha_t=None,
+                msa_prev=None, pair_prev=None, state_prev=None,
+                return_raw=False, return_full=False, return_infer=False,
+                use_checkpoint=False, motif_mask=None, i_cycle=None, n_cycle=None):
+
+        B, N, L = msa_latent.shape[:3]
+        # Get embeddings
+        msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx)
+        msa_full = self.full_emb(msa_full, seq, idx)
+
+        # Do recycling
+        if msa_prev is None:
+            msa_prev = torch.zeros_like(msa_latent[:,0])
+            pair_prev = torch.zeros_like(pair)
+            state_prev = torch.zeros_like(state)
+        msa_recycle, pair_recycle, state_recycle = self.recycle(seq, msa_prev, pair_prev, xyz, state_prev)
+        msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
+        pair = pair + pair_recycle
+        state = state + state_recycle
+
+        # Get timestep embedding (if using)
+        if hasattr(self, 'timestep_embedder'):
+            assert t is not None
+            time_emb = self.timestep_embedder(L, t, motif_mask)
+            n_tmpl = t1d.shape[1]
+            t1d = torch.cat([t1d, time_emb[None,None,...].repeat(1,n_tmpl,1,1)], dim=-1)
+
+        # add template embedding
+        pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=use_checkpoint)
+
+        # Predict coordinates from given inputs
+        is_frozen_residue = motif_mask if self.freeze_track_motif else torch.zeros_like(motif_mask).bool()
+        msa, pair, R, T, alpha_s, state = self.simulator(seq, msa_latent, msa_full, pair, xyz[:,:,:3],
+                                                         state, idx, use_checkpoint=use_checkpoint,
+                                                         motif_mask=is_frozen_residue)
+
+        if return_raw:
+            # get last structure
+            xyz = einsum('bnij,bnaj->bnai', R[-1], xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T[-1].unsqueeze(-2)
+            return msa[:,0], pair, xyz, state, alpha_s[-1]
+
+        # predict masked amino acids
+        logits_aa = self.aa_pred(msa)
+
+        # Predict LDDT
+        lddt = self.lddt_pred(state)
+
+        if return_infer:
+            # get last structure
+            xyz = einsum('bnij,bnaj->bnai', R[-1], xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T[-1].unsqueeze(-2)
+
+            # get scalar plddt
+            nbin = lddt.shape[1]
+            bin_step = 1.0 / nbin
+            lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=lddt.dtype, device=lddt.device)
+            pred_lddt = nn.Softmax(dim=1)(lddt)
+            pred_lddt = torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1)
+
+            return msa[:,0], pair, xyz, state, alpha_s[-1], logits_aa.permute(0,2,1), pred_lddt
+
+        #
+        # predict distogram & orientograms
+        logits = self.c6d_pred(pair)
+
+        # predict experimentally resolved or not
+        logits_exp = self.exp_pred(msa[:,0], state)
+
+        # get all intermediate bb structures
+        xyz = einsum('rbnij,bnaj->rbnai', R, xyz[:,:,:3]-xyz[:,:,1].unsqueeze(-2)) + T.unsqueeze(-2)
+
+        return logits, logits_aa, logits_exp, xyz, alpha_s, lddt
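The `return_infer` branch above collapses the per-bin LDDT logits to a scalar pLDDT: a softmax over `nbin` bins, then the expectation against bin centers at `1/nbin, 2/nbin, ..., 1.0`. A small self-contained sketch of that bin-expectation step (plain Python, no torch; the function name `plddt_expectation` is mine):

```python
import math

def plddt_expectation(logits):
    """Expected lDDT from per-bin logits, using the same bin scheme as above:
    bin centers at 1/nbin, 2/nbin, ..., 1.0."""
    nbin = len(logits)
    m = max(logits)                              # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]                # softmax over bins
    bins = [(i + 1) / nbin for i in range(nbin)] # bin centers
    return sum(p * b for p, b in zip(probs, bins))

# Uniform logits -> expectation is the mean of the bin centers: (1/nbin + 1)/2.
val = plddt_expectation([0.0] * 50)  # 0.51 for 50 bins
```

A sharply peaked logit vector instead pushes the expectation toward that bin's center, which is what makes the scalar pLDDT behave like a confidence score.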
rfdiffusion/SE3_network.py ADDED
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+
+#from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias
+#from equivariant_attention.modules import GConvSE3, GNormSE3
+#from equivariant_attention.fibers import Fiber
+
+from rfdiffusion.util_module import init_lecun_normal_param
+from se3_transformer.model import SE3Transformer
+from se3_transformer.model.fiber import Fiber
+
+class SE3TransformerWrapper(nn.Module):
+    """SE(3) equivariant GCN with attention"""
+    def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
+                 l0_in_features=32, l0_out_features=32,
+                 l1_in_features=3, l1_out_features=2,
+                 num_edge_features=32):
+        super().__init__()
+        # Build the network
+        self.l1_in = l1_in_features
+        #
+        fiber_edge = Fiber({0: num_edge_features})
+        if l1_out_features > 0:
+            if l1_in_features > 0:
+                fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
+                fiber_hidden = Fiber.create(num_degrees, num_channels)
+                fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
+            else:
+                fiber_in = Fiber({0: l0_in_features})
+                fiber_hidden = Fiber.create(num_degrees, num_channels)
+                fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
+        else:
+            if l1_in_features > 0:
+                fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
+                fiber_hidden = Fiber.create(num_degrees, num_channels)
+                fiber_out = Fiber({0: l0_out_features})
+            else:
+                fiber_in = Fiber({0: l0_in_features})
+                fiber_hidden = Fiber.create(num_degrees, num_channels)
+                fiber_out = Fiber({0: l0_out_features})
+
+        self.se3 = SE3Transformer(num_layers=num_layers,
+                                  fiber_in=fiber_in,
+                                  fiber_hidden=fiber_hidden,
+                                  fiber_out=fiber_out,
+                                  num_heads=n_heads,
+                                  channels_div=div,
+                                  fiber_edge=fiber_edge,
+                                  use_layer_norm=True)
+                                  #use_layer_norm=False)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        # make sure linear layers before ReLU are initialized with kaiming_normal_
+        for n, p in self.se3.named_parameters():
+            if "bias" in n:
+                nn.init.zeros_(p)
+            elif len(p.shape) == 1:
+                continue
+            else:
+                if "radial_func" not in n:
+                    p = init_lecun_normal_param(p)
+                else:
+                    if "net.6" in n:
+                        nn.init.zeros_(p)
+                    else:
+                        nn.init.kaiming_normal_(p, nonlinearity='relu')
+
+        # make last layers zero-initialized
+        #self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
+        #self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
+        nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
+        nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
+
+    def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
+        if self.l1_in > 0:
+            node_features = {'0': type_0_features, '1': type_1_features}
+        else:
+            node_features = {'0': type_0_features}
+        edge_features = {'0': edge_features}
+        return self.se3(G, node_features, edge_features)
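The nested branching in `__init__` above only decides whether degree-1 (vector) multiplicities appear in the input and output fibers; the hidden fiber is the same in all four branches. The same decision table, written as plain degree-to-multiplicity dicts (illustrative only; the real `Fiber` objects come from `se3_transformer`, and `fiber_spec` is my name):

```python
def fiber_spec(l0_in, l0_out, l1_in, l1_out):
    """Return (fiber_in, fiber_out) degree->multiplicity dicts, mirroring
    the branching in SE3TransformerWrapper.__init__ above: degree 1 is
    included only when its feature count is positive."""
    fiber_in = {0: l0_in}
    if l1_in > 0:
        fiber_in[1] = l1_in
    fiber_out = {0: l0_out}
    if l1_out > 0:
        fiber_out[1] = l1_out
    return fiber_in, fiber_out

# As used by the structure track: scalar + vector features in,
# scalars (state) + 2 vectors (translation, rotation) out.
fin, fout = fiber_spec(32, 16, 3, 2)
```

Collapsing the four branches this way makes it clear the wrapper is just toggling vector channels on and off around a fixed hidden fiber.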
rfdiffusion/Track_module.py ADDED
@@ -0,0 +1,474 @@
1
+ import torch.utils.checkpoint as checkpoint
2
+ from rfdiffusion.util_module import *
3
+ from rfdiffusion.Attention_module import *
4
+ from rfdiffusion.SE3_network import SE3TransformerWrapper
5
+
6
+ # Components for three-track blocks
7
+ # 1. MSA -> MSA update (biased attention. bias from pair & structure)
8
+ # 2. Pair -> Pair update (biased attention. bias from structure)
9
+ # 3. MSA -> Pair update (extract coevolution signal)
10
+ # 4. Str -> Str update (node from MSA, edge from Pair)
11
+
12
+ # Update MSA with biased self-attention. bias from Pair & Str
13
+ class MSAPairStr2MSA(nn.Module):
14
+ def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16,
15
+ d_hidden=32, p_drop=0.15, use_global_attn=False):
16
+ super(MSAPairStr2MSA, self).__init__()
17
+ self.norm_pair = nn.LayerNorm(d_pair)
18
+ self.proj_pair = nn.Linear(d_pair+36, d_pair)
19
+ self.norm_state = nn.LayerNorm(d_state)
20
+ self.proj_state = nn.Linear(d_state, d_msa)
21
+ self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
22
+ self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
23
+ n_head=n_head, d_hidden=d_hidden)
24
+ if use_global_attn:
25
+ self.col_attn = MSAColGlobalAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
26
+ else:
27
+ self.col_attn = MSAColAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
28
+ self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop)
29
+
30
+ # Do proper initialization
31
+ self.reset_parameter()
32
+
33
+ def reset_parameter(self):
34
+ # initialize weights to normal distrib
35
+ self.proj_pair = init_lecun_normal(self.proj_pair)
36
+ self.proj_state = init_lecun_normal(self.proj_state)
37
+
38
+ # initialize bias to zeros
39
+ nn.init.zeros_(self.proj_pair.bias)
40
+ nn.init.zeros_(self.proj_state.bias)
41
+
42
+ def forward(self, msa, pair, rbf_feat, state):
43
+ '''
44
+ Inputs:
45
+ - msa: MSA feature (B, N, L, d_msa)
46
+ - pair: Pair feature (B, L, L, d_pair)
47
+ - rbf_feat: Ca-Ca distance feature calculated from xyz coordinates (B, L, L, 36)
48
+ - xyz: xyz coordinates (B, L, n_atom, 3)
49
+ - state: updated node features after SE(3)-Transformer layer (B, L, d_state)
50
+ Output:
51
+ - msa: Updated MSA feature (B, N, L, d_msa)
52
+ '''
53
+ B, N, L = msa.shape[:3]
54
+
55
+ # prepare input bias feature by combining pair & coordinate info
56
+ pair = self.norm_pair(pair)
57
+ pair = torch.cat((pair, rbf_feat), dim=-1)
58
+ pair = self.proj_pair(pair) # (B, L, L, d_pair)
59
+ #
60
+ # update query sequence feature (first sequence in the MSA) with feedbacks (state) from SE3
61
+ state = self.norm_state(state)
62
+ state = self.proj_state(state).reshape(B, 1, L, -1)
63
+ msa = msa.index_add(1, torch.tensor([0,], device=state.device), state)
64
+ #
65
+ # Apply row/column attention to msa & transform
66
+ msa = msa + self.drop_row(self.row_attn(msa, pair))
67
+ msa = msa + self.col_attn(msa)
68
+ msa = msa + self.ff(msa)
69
+
70
+ return msa
71
+
72
+ class PairStr2Pair(nn.Module):
73
+ def __init__(self, d_pair=128, n_head=4, d_hidden=32, d_rbf=36, p_drop=0.15):
74
+ super(PairStr2Pair, self).__init__()
75
+
76
+ self.emb_rbf = nn.Linear(d_rbf, d_hidden)
77
+ self.proj_rbf = nn.Linear(d_hidden, d_pair)
78
+
79
+ self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
80
+ self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)
81
+
82
+ self.row_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=True)
83
+ self.col_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=False)
84
+
85
+ self.ff = FeedForwardLayer(d_pair, 2)
86
+
87
+ self.reset_parameter()
88
+
89
+ def reset_parameter(self):
90
+ nn.init.kaiming_normal_(self.emb_rbf.weight, nonlinearity='relu')
91
+ nn.init.zeros_(self.emb_rbf.bias)
92
+
93
+ self.proj_rbf = init_lecun_normal(self.proj_rbf)
94
+ nn.init.zeros_(self.proj_rbf.bias)
95
+
96
+ def forward(self, pair, rbf_feat):
97
+ B, L = pair.shape[:2]
98
+
99
+ rbf_feat = self.proj_rbf(F.relu_(self.emb_rbf(rbf_feat)))
100
+
101
+ pair = pair + self.drop_row(self.row_attn(pair, rbf_feat))
102
+ pair = pair + self.drop_col(self.col_attn(pair, rbf_feat))
103
+ pair = pair + self.ff(pair)
104
+ return pair
105
+
106
+ class MSA2Pair(nn.Module):
107
+ def __init__(self, d_msa=256, d_pair=128, d_hidden=32, p_drop=0.15):
108
+ super(MSA2Pair, self).__init__()
109
+ self.norm = nn.LayerNorm(d_msa)
110
+ self.proj_left = nn.Linear(d_msa, d_hidden)
111
+ self.proj_right = nn.Linear(d_msa, d_hidden)
112
+ self.proj_out = nn.Linear(d_hidden*d_hidden, d_pair)
113
+
114
+ self.reset_parameter()
115
+
116
+ def reset_parameter(self):
117
+ # normal initialization
118
+ self.proj_left = init_lecun_normal(self.proj_left)
119
+ self.proj_right = init_lecun_normal(self.proj_right)
120
+ nn.init.zeros_(self.proj_left.bias)
121
+ nn.init.zeros_(self.proj_right.bias)
122
+
123
+ # zero initialize output
124
+ nn.init.zeros_(self.proj_out.weight)
125
+ nn.init.zeros_(self.proj_out.bias)
126
+
127
+ def forward(self, msa, pair):
128
+ B, N, L = msa.shape[:3]
129
+ msa = self.norm(msa)
130
+ left = self.proj_left(msa)
131
+ right = self.proj_right(msa)
132
+ right = right / float(N)
133
+ out = einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1)
134
+ out = self.proj_out(out)
135
+
136
+ pair = pair + out
137
+
138
+ return pair
139
+
140
+ class SCPred(nn.Module):
141
+ def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
142
+ super(SCPred, self).__init__()
143
+ self.norm_s0 = nn.LayerNorm(d_msa)
144
+ self.norm_si = nn.LayerNorm(d_state)
145
+ self.linear_s0 = nn.Linear(d_msa, d_hidden)
146
+ self.linear_si = nn.Linear(d_state, d_hidden)
147
+
148
+ # ResNet layers
149
+ self.linear_1 = nn.Linear(d_hidden, d_hidden)
150
+ self.linear_2 = nn.Linear(d_hidden, d_hidden)
151
+ self.linear_3 = nn.Linear(d_hidden, d_hidden)
152
+ self.linear_4 = nn.Linear(d_hidden, d_hidden)
153
+
154
+ # Final outputs
155
+ self.linear_out = nn.Linear(d_hidden, 20)
156
+
157
+ self.reset_parameter()
158
+
159
+ def reset_parameter(self):
160
+ # normal initialization
161
+ self.linear_s0 = init_lecun_normal(self.linear_s0)
162
+ self.linear_si = init_lecun_normal(self.linear_si)
163
+ self.linear_out = init_lecun_normal(self.linear_out)
164
+ nn.init.zeros_(self.linear_s0.bias)
165
+ nn.init.zeros_(self.linear_si.bias)
166
+ nn.init.zeros_(self.linear_out.bias)
167
+
168
+ # right before relu activation: He initializer (kaiming normal)
169
+ nn.init.kaiming_normal_(self.linear_1.weight, nonlinearity='relu')
170
+ nn.init.zeros_(self.linear_1.bias)
171
+ nn.init.kaiming_normal_(self.linear_3.weight, nonlinearity='relu')
172
+ nn.init.zeros_(self.linear_3.bias)
173
+
174
+ # right before residual connection: zero initialize
175
+ nn.init.zeros_(self.linear_2.weight)
176
+ nn.init.zeros_(self.linear_2.bias)
177
+ nn.init.zeros_(self.linear_4.weight)
178
+ nn.init.zeros_(self.linear_4.bias)
179
+
180
+ def forward(self, seq, state):
181
+ '''
182
+ Predict side-chain torsion angles along with backbone torsions
183
+ Inputs:
184
+ - seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
185
+ - state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
186
+ Outputs:
187
+ - si: predicted torsion angles (phi, psi, omega, chi1~4 with cos/sin, Cb bend, Cb twist, CG) (B, L, 10, 2)
188
+ '''
189
+ B, L = seq.shape[:2]
190
+ seq = self.norm_s0(seq)
191
+ state = self.norm_si(state)
192
+ si = self.linear_s0(seq) + self.linear_si(state)
193
+
194
+ si = si + self.linear_2(F.relu_(self.linear_1(F.relu_(si))))
195
+ si = si + self.linear_4(F.relu_(self.linear_3(F.relu_(si))))
196
+
197
+ si = self.linear_out(F.relu_(si))
198
+ return si.view(B, L, 10, 2)
199
+
200
+
201
+ class Str2Str(nn.Module):
202
+ def __init__(self, d_msa=256, d_pair=128, d_state=16,
203
+ SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.1):
204
+ super(Str2Str, self).__init__()
205
+
206
+ # initial node & pair feature process
207
+ self.norm_msa = nn.LayerNorm(d_msa)
208
+ self.norm_pair = nn.LayerNorm(d_pair)
209
+ self.norm_state = nn.LayerNorm(d_state)
210
+
211
+ self.embed_x = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
212
+ self.embed_e1 = nn.Linear(d_pair, SE3_param['num_edge_features'])
213
+ self.embed_e2 = nn.Linear(SE3_param['num_edge_features']+36+1, SE3_param['num_edge_features'])
214
+
215
+ self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])
216
+ self.norm_edge1 = nn.LayerNorm(SE3_param['num_edge_features'])
217
+ self.norm_edge2 = nn.LayerNorm(SE3_param['num_edge_features'])
218
+
219
+ self.se3 = SE3TransformerWrapper(**SE3_param)
220
+ self.sc_predictor = SCPred(d_msa=d_msa, d_state=SE3_param['l0_out_features'],
221
+ p_drop=p_drop)
222
+
223
+ self.reset_parameter()
224
+
225
+ def reset_parameter(self):
226
+ # initialize weights to normal distribution
227
+ self.embed_x = init_lecun_normal(self.embed_x)
228
+ self.embed_e1 = init_lecun_normal(self.embed_e1)
229
+ self.embed_e2 = init_lecun_normal(self.embed_e2)
230
+
231
+ # initialize bias to zeros
232
+ nn.init.zeros_(self.embed_x.bias)
233
+ nn.init.zeros_(self.embed_e1.bias)
234
+ nn.init.zeros_(self.embed_e2.bias)
235
+
236
+ @torch.cuda.amp.autocast(enabled=False)
237
+ def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, top_k=64, eps=1e-5):
238
+ B, N, L = msa.shape[:3]
239
+
240
+ if motif_mask is None:
241
+ motif_mask = torch.zeros(L).bool()
242
+
243
+ # process msa & pair features
244
+ node = self.norm_msa(msa[:,0])
245
+ pair = self.norm_pair(pair)
246
+ state = self.norm_state(state)
247
+
248
+ node = torch.cat((node, state), dim=-1)
249
+ node = self.norm_node(self.embed_x(node))
250
+ pair = self.norm_edge1(self.embed_e1(pair))
251
+
252
+ neighbor = get_seqsep(idx)
253
+ rbf_feat = rbf(torch.cdist(xyz[:,:,1], xyz[:,:,1]))
254
+ pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
255
+ pair = self.norm_edge2(self.embed_e2(pair))
256
+
257
+ # define graph
258
+ if top_k != 0:
259
+ G, edge_feats = make_topk_graph(xyz[:,:,1,:], pair, idx, top_k=top_k)
260
+ else:
261
+ G, edge_feats = make_full_graph(xyz[:,:,1,:], pair, idx, top_k=top_k)
262
+ l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2)
263
+ l1_feats = l1_feats.reshape(B*L, -1, 3)
264
+
265
+ # apply SE(3) Transformer & update coordinates
266
+ shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)
267
+
268
+ state = shift['0'].reshape(B, L, -1) # (B, L, C)
269
+
270
+ offset = shift['1'].reshape(B, L, 2, 3)
271
+ offset[:,motif_mask,...] = 0 # NOTE: motif mask is all zeros if not freeezing the motif
272
+
273
+ delTi = offset[:,:,0,:] / 10.0 # translation
274
+ R = offset[:,:,1,:] / 100.0 # rotation
275
+
276
+ Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) )
277
+ qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm
278
+
279
+ delRi = torch.zeros((B,L,3,3), device=xyz.device)
280
+ delRi[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
281
+ delRi[:,:,0,1] = 2*qB*qC - 2*qA*qD
282
+ delRi[:,:,0,2] = 2*qB*qD + 2*qA*qC
283
+ delRi[:,:,1,0] = 2*qB*qC + 2*qA*qD
284
+ delRi[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
285
+ delRi[:,:,1,2] = 2*qC*qD - 2*qA*qB
286
+ delRi[:,:,2,0] = 2*qB*qD - 2*qA*qC
287
+ delRi[:,:,2,1] = 2*qC*qD + 2*qA*qB
288
+ delRi[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD
289
+
290
+ Ri = einsum('bnij,bnjk->bnik', delRi, R_in)
291
+ Ti = delTi + T_in #einsum('bnij,bnj->bni', delRi, T_in) + delTi
292
+
293
+ alpha = self.sc_predictor(msa[:,0], state)
294
+ return Ri, Ti, state, alpha
295
+
296
+ class IterBlock(nn.Module):
297
+ def __init__(self, d_msa=256, d_pair=128,
298
+ n_head_msa=8, n_head_pair=4,
299
+ use_global_attn=False,
300
+ d_hidden=32, d_hidden_msa=None, p_drop=0.15,
301
+ SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
302
+ super(IterBlock, self).__init__()
303
+ if d_hidden_msa == None:
304
+ d_hidden_msa = d_hidden
305
+
306
+ self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair,
307
+ n_head=n_head_msa,
308
+ d_state=SE3_param['l0_out_features'],
309
+ use_global_attn=use_global_attn,
310
+ d_hidden=d_hidden_msa, p_drop=p_drop)
311
+ self.msa2pair = MSA2Pair(d_msa=d_msa, d_pair=d_pair,
312
+ d_hidden=d_hidden//2, p_drop=p_drop)
313
+ #d_hidden=d_hidden, p_drop=p_drop)
314
+ self.pair2pair = PairStr2Pair(d_pair=d_pair, n_head=n_head_pair,
315
+ d_hidden=d_hidden, p_drop=p_drop)
316
+ self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair,
317
+ d_state=SE3_param['l0_out_features'],
318
+ SE3_param=SE3_param,
319
+ p_drop=p_drop)
320
+
321
+ def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, use_checkpoint=False):
322
+ rbf_feat = rbf(torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]))
323
+ if use_checkpoint:
324
+ msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
325
+ pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
326
+ pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat)
327
+ R, T, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=0), msa, pair, R_in, T_in, xyz, state, idx, motif_mask)
328
+ else:
329
+ msa = self.msa2msa(msa, pair, rbf_feat, state)
330
+ pair = self.msa2pair(msa, pair)
331
+ pair = self.pair2pair(pair, rbf_feat)
332
+ R, T, state, alpha = self.str2str(msa, pair, R_in, T_in, xyz, state, idx, motif_mask=motif_mask, top_k=0)
333
+
334
+ return msa, pair, R, T, state, alpha
335
+
336
+ class IterativeSimulator(nn.Module):
337
+ def __init__(self, n_extra_block=4, n_main_block=12, n_ref_block=4,
338
+ d_msa=256, d_msa_full=64, d_pair=128, d_hidden=32,
339
+ n_head_msa=8, n_head_pair=4,
340
+ SE3_param_full={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
341
+ SE3_param_topk={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
342
+ p_drop=0.15):
343
+ super(IterativeSimulator, self).__init__()
344
+ self.n_extra_block = n_extra_block
345
+ self.n_main_block = n_main_block
346
+ self.n_ref_block = n_ref_block
347
+
348
+ self.proj_state = nn.Linear(SE3_param_topk['l0_out_features'], SE3_param_full['l0_out_features'])
349
+ # Update with extra sequences
350
+ if n_extra_block > 0:
351
+ self.extra_block = nn.ModuleList([IterBlock(d_msa=d_msa_full, d_pair=d_pair,
352
+ n_head_msa=n_head_msa,
353
+ n_head_pair=n_head_pair,
354
+ d_hidden_msa=8,
355
+ d_hidden=d_hidden,
356
+ p_drop=p_drop,
357
+ use_global_attn=True,
358
+ SE3_param=SE3_param_full)
359
+ for i in range(n_extra_block)])
360
+
361
+ # Update with seed sequences
362
+ if n_main_block > 0:
363
+ self.main_block = nn.ModuleList([IterBlock(d_msa=d_msa, d_pair=d_pair,
364
+ n_head_msa=n_head_msa,
365
+ n_head_pair=n_head_pair,
366
+ d_hidden=d_hidden,
367
+ p_drop=p_drop,
368
+ use_global_attn=False,
369
+ SE3_param=SE3_param_full)
370
+ for i in range(n_main_block)])
371
+
372
+ self.proj_state2 = nn.Linear(SE3_param_full['l0_out_features'], SE3_param_topk['l0_out_features'])
373
+ # Final SE(3) refinement
374
+ if n_ref_block > 0:
375
+ self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair,
376
+ d_state=SE3_param_topk['l0_out_features'],
377
+ SE3_param=SE3_param_topk,
378
+ p_drop=p_drop)
379
+
380
+ self.reset_parameter()
+    def reset_parameter(self):
+        self.proj_state = init_lecun_normal(self.proj_state)
+        nn.init.zeros_(self.proj_state.bias)
+        self.proj_state2 = init_lecun_normal(self.proj_state2)
+        nn.init.zeros_(self.proj_state2.bias)
+
+    def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=False, motif_mask=None):
+        """
+        input:
+            seq: query sequence (B, L)
+            msa: seed MSA embeddings (B, N, L, d_msa)
+            msa_full: extra MSA embeddings (B, N, L, d_msa_full)
+            pair: initial residue pair embeddings (B, L, L, d_pair)
+            xyz_in: initial BB coordinates (B, L, n_atom, 3)
+            state: initial state features containing a mixture of query sequence, sidechain, and accuracy info (B, L, d_state)
+            idx: residue index
+            motif_mask: bool tensor (L,), True at frozen motif positions, False elsewhere
+        """
+
+        B, L = pair.shape[:2]
+
+        if motif_mask is None:
+            motif_mask = torch.zeros(L).bool()
+
+        R_in = torch.eye(3, device=xyz_in.device).reshape(1,1,3,3).expand(B, L, -1, -1)
+        T_in = xyz_in[:,:,1].clone()
+        xyz_in = xyz_in - T_in.unsqueeze(-2)
+
+        state = self.proj_state(state)
+
+        R_s = list()
+        T_s = list()
+        alpha_s = list()
+        for i_m in range(self.n_extra_block):
+            R_in = R_in.detach() # detach rotation (for stability)
+            T_in = T_in.detach()
+            # Get current BB structure
+            xyz = einsum('bnij,bnaj->bnai', R_in, xyz_in) + T_in.unsqueeze(-2)
+
+            msa_full, pair, R_in, T_in, state, alpha = self.extra_block[i_m](msa_full,
+                                                                            pair,
+                                                                            R_in,
+                                                                            T_in,
+                                                                            xyz,
+                                                                            state,
+                                                                            idx,
+                                                                            motif_mask=motif_mask,
+                                                                            use_checkpoint=use_checkpoint)
+            R_s.append(R_in)
+            T_s.append(T_in)
+            alpha_s.append(alpha)
+
+        for i_m in range(self.n_main_block):
+            R_in = R_in.detach()
+            T_in = T_in.detach()
+            # Get current BB structure
+            xyz = einsum('bnij,bnaj->bnai', R_in, xyz_in) + T_in.unsqueeze(-2)
+
+            msa, pair, R_in, T_in, state, alpha = self.main_block[i_m](msa,
+                                                                      pair,
+                                                                      R_in,
+                                                                      T_in,
+                                                                      xyz,
+                                                                      state,
+                                                                      idx,
+                                                                      motif_mask=motif_mask,
+                                                                      use_checkpoint=use_checkpoint)
+            R_s.append(R_in)
+            T_s.append(T_in)
+            alpha_s.append(alpha)
+
+        state = self.proj_state2(state)
+        for i_m in range(self.n_ref_block):
+            R_in = R_in.detach()
+            T_in = T_in.detach()
+            xyz = einsum('bnij,bnaj->bnai', R_in, xyz_in) + T_in.unsqueeze(-2)
+            R_in, T_in, state, alpha = self.str_refiner(msa,
+                                                        pair,
+                                                        R_in,
+                                                        T_in,
+                                                        xyz,
+                                                        state,
+                                                        idx,
+                                                        top_k=64,
+                                                        motif_mask=motif_mask)
+            R_s.append(R_in)
+            T_s.append(T_in)
+            alpha_s.append(alpha)
+
+        R_s = torch.stack(R_s, dim=0)
+        T_s = torch.stack(T_s, dim=0)
+        alpha_s = torch.stack(alpha_s, dim=0)
+
+        return msa, pair, R_s, T_s, alpha_s, state
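
Aside (not part of the commit): each iteration above maps frame-local backbone coordinates into global space with `einsum('bnij,bnaj->bnai', R_in, xyz_in) + T_in.unsqueeze(-2)`, i.e. a per-residue rotation followed by a per-residue translation. A minimal sketch of just that frame update, with assumed toy shapes:

```python
import torch
from torch import einsum

B, L, A = 1, 5, 3                    # batch, residues, atoms per residue (assumed)
xyz_local = torch.randn(B, L, A, 3)  # coordinates centered on each residue's CA
R = torch.eye(3).reshape(1, 1, 3, 3).expand(B, L, -1, -1)  # identity rotations
T = torch.randn(B, L, 3)             # per-residue translations

# 'bnij,bnaj->bnai': rotate each atom j-vector by its residue's 3x3 frame,
# then broadcast the translation over the atom dimension.
xyz_global = einsum('bnij,bnaj->bnai', R, xyz_local) + T.unsqueeze(-2)

# With identity rotations the update reduces to a pure translation.
assert torch.allclose(xyz_global, xyz_local + T.unsqueeze(-2))
```

This is why the simulator can keep `xyz_in` fixed in local frames and only carry `R_in`/`T_in` between blocks: regenerating `xyz` each iteration is a cheap rigid transform.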
rfdiffusion/__init__.py ADDED
File without changes
rfdiffusion/__pycache__/Attention_module.cpython-310.pyc ADDED
Binary file (10.7 kB)
rfdiffusion/__pycache__/Attention_module.cpython-311.pyc ADDED
Binary file (27 kB)
rfdiffusion/__pycache__/Attention_module.cpython-39.pyc ADDED
Binary file (11.3 kB)
rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-310.pyc ADDED
Binary file (3.59 kB)
rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-311.pyc ADDED
Binary file (7.34 kB)
rfdiffusion/__pycache__/AuxiliaryPredictor.cpython-39.pyc ADDED
Binary file (3.76 kB)
rfdiffusion/__pycache__/Embeddings.cpython-310.pyc ADDED
Binary file (9.5 kB)
rfdiffusion/__pycache__/Embeddings.cpython-311.pyc ADDED
Binary file (20 kB)
rfdiffusion/__pycache__/Embeddings.cpython-39.pyc ADDED
Binary file (9.54 kB)
rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-310.pyc ADDED
Binary file (3.68 kB)
rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-311.pyc ADDED
Binary file (7.09 kB)
rfdiffusion/__pycache__/RoseTTAFoldModel.cpython-39.pyc ADDED
Binary file (3.67 kB)
rfdiffusion/__pycache__/SE3_network.cpython-310.pyc ADDED
Binary file (2.29 kB)
rfdiffusion/__pycache__/SE3_network.cpython-311.pyc ADDED
Binary file (4.13 kB)
rfdiffusion/__pycache__/SE3_network.cpython-39.pyc ADDED
Binary file (2.28 kB)
rfdiffusion/__pycache__/Track_module.cpython-310.pyc ADDED
Binary file (14.1 kB)
rfdiffusion/__pycache__/Track_module.cpython-311.pyc ADDED
Binary file (31.4 kB)
rfdiffusion/__pycache__/Track_module.cpython-39.pyc ADDED
Binary file (14.1 kB)
rfdiffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes)
rfdiffusion/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (170 Bytes)
rfdiffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (152 Bytes)
rfdiffusion/__pycache__/chemical.cpython-310.pyc ADDED
Binary file (20.5 kB)
rfdiffusion/__pycache__/chemical.cpython-311.pyc ADDED
Binary file (24.6 kB)
rfdiffusion/__pycache__/chemical.cpython-39.pyc ADDED
Binary file (20.5 kB)
rfdiffusion/__pycache__/contigs.cpython-310.pyc ADDED
Binary file (10.2 kB)
rfdiffusion/__pycache__/contigs.cpython-311.pyc ADDED
Binary file (23.2 kB)
rfdiffusion/__pycache__/contigs.cpython-39.pyc ADDED
Binary file (10.5 kB)
rfdiffusion/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (19.4 kB)
rfdiffusion/__pycache__/diffusion.cpython-311.pyc ADDED
Binary file (31.4 kB)
rfdiffusion/__pycache__/diffusion.cpython-39.pyc ADDED
Binary file (19.5 kB)
rfdiffusion/__pycache__/igso3.cpython-310.pyc ADDED
Binary file (4.72 kB)
rfdiffusion/__pycache__/igso3.cpython-311.pyc ADDED
Binary file (8 kB)
rfdiffusion/__pycache__/igso3.cpython-39.pyc ADDED
Binary file (4.76 kB)
rfdiffusion/__pycache__/kinematics.cpython-310.pyc ADDED
Binary file (9.41 kB)
rfdiffusion/__pycache__/kinematics.cpython-311.pyc ADDED
Binary file (18.8 kB)
rfdiffusion/__pycache__/kinematics.cpython-39.pyc ADDED
Binary file (9.42 kB)
rfdiffusion/__pycache__/model_input_logger.cpython-311.pyc ADDED
Binary file (4.71 kB)
rfdiffusion/__pycache__/model_input_logger.cpython-39.pyc ADDED
Binary file (2.62 kB)