GlandVergil commited on
Commit
4184384
1 Parent(s): 551c2bb

Create app.py

Browse files
Files changed (1) hide show
  1. RFdiffusion/app.py +180 -0
RFdiffusion/app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os, time, pickle
3
+ import torch
4
+ from omegaconf import OmegaConf
5
+ import hydra
6
+ import logging
7
+ from rfdiffusion.util import writepdb_multi, writepdb
8
+ from rfdiffusion.inference import utils as iu
9
+ from hydra.core.hydra_config import HydraConfig
10
+ import numpy as np
11
+ import random
12
+ import glob
13
+ import gradio as gr
14
+ def greet(mtf):
15
+ return "Hello " + name + "!!"
16
+ def make_deterministic(seed=0):
17
+ torch.manual_seed(seed)
18
+ np.random.seed(seed)
19
+ random.seed(seed)
20
+
21
+
22
+ @hydra.main(version_base=None, config_path="../config/inference", config_name="base")
23
+ def main(conf: HydraConfig) -> None:
24
+ log = logging.getLogger(__name__)
25
+ if conf.inference.deterministic:
26
+ make_deterministic()
27
+
28
+ # Check for available GPU and print result of check
29
+ if torch.cuda.is_available():
30
+ device_name = torch.cuda.get_device_name(torch.cuda.current_device())
31
+ log.info(f"Found GPU with device_name {device_name}. Will run RFdiffusion on {device_name}")
32
+ else:
33
+ log.info("////////////////////////////////////////////////")
34
+ log.info("///// NO GPU DETECTED! Falling back to CPU /////")
35
+ log.info("////////////////////////////////////////////////")
36
+
37
+ # Initialize sampler and target/contig.
38
+ sampler = iu.sampler_selector(conf)
39
+
40
+ # Loop over number of designs to sample.
41
+ design_startnum = sampler.inf_conf.design_startnum
42
+ if sampler.inf_conf.design_startnum == -1:
43
+ existing = glob.glob(sampler.inf_conf.output_prefix + "*.pdb")
44
+ indices = [-1]
45
+ for e in existing:
46
+ print(e)
47
+ m = re.match(".*_(\d+)\.pdb$", e)
48
+ print(m)
49
+ if not m:
50
+ continue
51
+ m = m.groups()[0]
52
+ indices.append(int(m))
53
+ design_startnum = max(indices) + 1
54
+
55
+ for i_des in range(design_startnum, design_startnum + sampler.inf_conf.num_designs):
56
+ if conf.inference.deterministic:
57
+ make_deterministic(i_des)
58
+
59
+ start_time = time.time()
60
+ out_prefix = f"{sampler.inf_conf.output_prefix}_{i_des}"
61
+ log.info(f"Making design {out_prefix}")
62
+ if sampler.inf_conf.cautious and os.path.exists(out_prefix + ".pdb"):
63
+ log.info(
64
+ f"(cautious mode) Skipping this design because {out_prefix}.pdb already exists."
65
+ )
66
+ continue
67
+
68
+ x_init, seq_init = sampler.sample_init()
69
+ denoised_xyz_stack = []
70
+ px0_xyz_stack = []
71
+ seq_stack = []
72
+ plddt_stack = []
73
+
74
+ x_t = torch.clone(x_init)
75
+ seq_t = torch.clone(seq_init)
76
+ # Loop over number of reverse diffusion time steps.
77
+ for t in range(int(sampler.t_step_input), sampler.inf_conf.final_step - 1, -1):
78
+ px0, x_t, seq_t, plddt = sampler.sample_step(
79
+ t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step
80
+ )
81
+ px0_xyz_stack.append(px0)
82
+ denoised_xyz_stack.append(x_t)
83
+ seq_stack.append(seq_t)
84
+ plddt_stack.append(plddt[0]) # remove singleton leading dimension
85
+
86
+ # Flip order for better visualization in pymol
87
+ denoised_xyz_stack = torch.stack(denoised_xyz_stack)
88
+ denoised_xyz_stack = torch.flip(
89
+ denoised_xyz_stack,
90
+ [
91
+ 0,
92
+ ],
93
+ )
94
+ px0_xyz_stack = torch.stack(px0_xyz_stack)
95
+ px0_xyz_stack = torch.flip(
96
+ px0_xyz_stack,
97
+ [
98
+ 0,
99
+ ],
100
+ )
101
+
102
+ # For logging -- don't flip
103
+ plddt_stack = torch.stack(plddt_stack)
104
+
105
+ # Save outputs
106
+ os.makedirs(os.path.dirname(out_prefix), exist_ok=True)
107
+ final_seq = seq_stack[-1]
108
+
109
+ # Output glycines, except for motif region
110
+ final_seq = torch.where(
111
+ torch.argmax(seq_init, dim=-1) == 21, 7, torch.argmax(seq_init, dim=-1)
112
+ ) # 7 is glycine
113
+
114
+ bfacts = torch.ones_like(final_seq.squeeze())
115
+ # make bfact=0 for diffused coordinates
116
+ bfacts[torch.where(torch.argmax(seq_init, dim=-1) == 21, True, False)] = 0
117
+ # pX0 last step
118
+ out = f"{out_prefix}.pdb"
119
+
120
+ # Now don't output sidechains
121
+ writepdb(
122
+ out,
123
+ denoised_xyz_stack[0, :, :4],
124
+ final_seq,
125
+ sampler.binderlen,
126
+ chain_idx=sampler.chain_idx,
127
+ bfacts=bfacts,
128
+ )
129
+
130
+ # run metadata
131
+ trb = dict(
132
+ config=OmegaConf.to_container(sampler._conf, resolve=True),
133
+ plddt=plddt_stack.cpu().numpy(),
134
+ device=torch.cuda.get_device_name(torch.cuda.current_device())
135
+ if torch.cuda.is_available()
136
+ else "CPU",
137
+ time=time.time() - start_time,
138
+ )
139
+ if hasattr(sampler, "contig_map"):
140
+ for key, value in sampler.contig_map.get_mappings().items():
141
+ trb[key] = value
142
+ with open(f"{out_prefix}.trb", "wb") as f_out:
143
+ pickle.dump(trb, f_out)
144
+
145
+ if sampler.inf_conf.write_trajectory:
146
+ # trajectory pdbs
147
+ traj_prefix = (
148
+ os.path.dirname(out_prefix) + "/traj/" + os.path.basename(out_prefix)
149
+ )
150
+ os.makedirs(os.path.dirname(traj_prefix), exist_ok=True)
151
+
152
+ out = f"{traj_prefix}_Xt-1_traj.pdb"
153
+ writepdb_multi(
154
+ out,
155
+ denoised_xyz_stack,
156
+ bfacts,
157
+ final_seq.squeeze(),
158
+ use_hydrogens=False,
159
+ backbone_only=False,
160
+ chain_ids=sampler.chain_idx,
161
+ )
162
+
163
+ out = f"{traj_prefix}_pX0_traj.pdb"
164
+ writepdb_multi(
165
+ out,
166
+ px0_xyz_stack,
167
+ bfacts,
168
+ final_seq.squeeze(),
169
+ use_hydrogens=False,
170
+ backbone_only=False,
171
+ chain_ids=sampler.chain_idx,
172
+ )
173
+
174
+ log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes")
175
+
176
+
177
+ if __name__ == "__main__":
178
+ main()
179
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
180
+ iface.launch()