GlandVergil committed
Commit 4d6b5c1 · verified · 1 Parent(s): 4184384

Delete app.py

Files changed (1)
  1. app.py +0 -180
app.py DELETED
@@ -1,180 +0,0 @@
- import re
- import os, time, pickle
- import torch
- from omegaconf import OmegaConf
- import hydra
- import logging
- from rfdiffusion.util import writepdb_multi, writepdb
- from rfdiffusion.inference import utils as iu
- from hydra.core.hydra_config import HydraConfig
- import numpy as np
- import random
- import glob
- import gradio as gr
- def greet(name):
-     return "Hello " + name + "!!"
- def make_deterministic(seed=0):
-     torch.manual_seed(seed)
-     np.random.seed(seed)
-     random.seed(seed)
-
-
- @hydra.main(version_base=None, config_path="../config/inference", config_name="base")
- def main(conf: HydraConfig) -> None:
-     log = logging.getLogger(__name__)
-     if conf.inference.deterministic:
-         make_deterministic()
-
-     # Check for available GPU and print result of check
-     if torch.cuda.is_available():
-         device_name = torch.cuda.get_device_name(torch.cuda.current_device())
-         log.info(f"Found GPU with device_name {device_name}. Will run RFdiffusion on {device_name}")
-     else:
-         log.info("////////////////////////////////////////////////")
-         log.info("///// NO GPU DETECTED! Falling back to CPU /////")
-         log.info("////////////////////////////////////////////////")
-
-     # Initialize sampler and target/contig.
-     sampler = iu.sampler_selector(conf)
-
-     # Loop over number of designs to sample.
-     design_startnum = sampler.inf_conf.design_startnum
-     if sampler.inf_conf.design_startnum == -1:
-         existing = glob.glob(sampler.inf_conf.output_prefix + "*.pdb")
-         indices = [-1]
-         for e in existing:
-             print(e)
-             m = re.match(r".*_(\d+)\.pdb$", e)
-             print(m)
-             if not m:
-                 continue
-             m = m.groups()[0]
-             indices.append(int(m))
-         design_startnum = max(indices) + 1
-
-     for i_des in range(design_startnum, design_startnum + sampler.inf_conf.num_designs):
-         if conf.inference.deterministic:
-             make_deterministic(i_des)
-
-         start_time = time.time()
-         out_prefix = f"{sampler.inf_conf.output_prefix}_{i_des}"
-         log.info(f"Making design {out_prefix}")
-         if sampler.inf_conf.cautious and os.path.exists(out_prefix + ".pdb"):
-             log.info(
-                 f"(cautious mode) Skipping this design because {out_prefix}.pdb already exists."
-             )
-             continue
-
-         x_init, seq_init = sampler.sample_init()
-         denoised_xyz_stack = []
-         px0_xyz_stack = []
-         seq_stack = []
-         plddt_stack = []
-
-         x_t = torch.clone(x_init)
-         seq_t = torch.clone(seq_init)
-         # Loop over number of reverse diffusion time steps.
-         for t in range(int(sampler.t_step_input), sampler.inf_conf.final_step - 1, -1):
-             px0, x_t, seq_t, plddt = sampler.sample_step(
-                 t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step
-             )
-             px0_xyz_stack.append(px0)
-             denoised_xyz_stack.append(x_t)
-             seq_stack.append(seq_t)
-             plddt_stack.append(plddt[0])  # remove singleton leading dimension
-
-         # Flip order for better visualization in pymol
-         denoised_xyz_stack = torch.stack(denoised_xyz_stack)
-         denoised_xyz_stack = torch.flip(
-             denoised_xyz_stack,
-             [
-                 0,
-             ],
-         )
-         px0_xyz_stack = torch.stack(px0_xyz_stack)
-         px0_xyz_stack = torch.flip(
-             px0_xyz_stack,
-             [
-                 0,
-             ],
-         )
-
-         # For logging -- don't flip
-         plddt_stack = torch.stack(plddt_stack)
-
-         # Save outputs
-         os.makedirs(os.path.dirname(out_prefix), exist_ok=True)
-         final_seq = seq_stack[-1]
-
-         # Output glycines, except for motif region
-         final_seq = torch.where(
-             torch.argmax(seq_init, dim=-1) == 21, 7, torch.argmax(seq_init, dim=-1)
-         )  # 7 is glycine
-
-         bfacts = torch.ones_like(final_seq.squeeze())
-         # make bfact=0 for diffused coordinates
-         bfacts[torch.where(torch.argmax(seq_init, dim=-1) == 21, True, False)] = 0
-         # pX0 last step
-         out = f"{out_prefix}.pdb"
-
-         # Now don't output sidechains
-         writepdb(
-             out,
-             denoised_xyz_stack[0, :, :4],
-             final_seq,
-             sampler.binderlen,
-             chain_idx=sampler.chain_idx,
-             bfacts=bfacts,
-         )
-
-         # run metadata
-         trb = dict(
-             config=OmegaConf.to_container(sampler._conf, resolve=True),
-             plddt=plddt_stack.cpu().numpy(),
-             device=torch.cuda.get_device_name(torch.cuda.current_device())
-             if torch.cuda.is_available()
-             else "CPU",
-             time=time.time() - start_time,
-         )
-         if hasattr(sampler, "contig_map"):
-             for key, value in sampler.contig_map.get_mappings().items():
-                 trb[key] = value
-         with open(f"{out_prefix}.trb", "wb") as f_out:
-             pickle.dump(trb, f_out)
-
-         if sampler.inf_conf.write_trajectory:
-             # trajectory pdbs
-             traj_prefix = (
-                 os.path.dirname(out_prefix) + "/traj/" + os.path.basename(out_prefix)
-             )
-             os.makedirs(os.path.dirname(traj_prefix), exist_ok=True)
-
-             out = f"{traj_prefix}_Xt-1_traj.pdb"
-             writepdb_multi(
-                 out,
-                 denoised_xyz_stack,
-                 bfacts,
-                 final_seq.squeeze(),
-                 use_hydrogens=False,
-                 backbone_only=False,
-                 chain_ids=sampler.chain_idx,
-             )
-
-             out = f"{traj_prefix}_pX0_traj.pdb"
-             writepdb_multi(
-                 out,
-                 px0_xyz_stack,
-                 bfacts,
-                 final_seq.squeeze(),
-                 use_hydrogens=False,
-                 backbone_only=False,
-                 chain_ids=sampler.chain_idx,
-             )
-
-         log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes")
-
-
- if __name__ == "__main__":
-     main()
-     iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-     iface.launch()
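
For reference, the Gradio piece of the deleted file is just the stock text demo wired to greet. A minimal standalone sketch of that part (assuming only the gradio package is installed, with none of the RFdiffusion/Hydra machinery) would be:

import gradio as gr

def greet(name):
    # Same toy callback the deleted app.py registered with gr.Interface.
    return "Hello " + name + "!!"

# Text-in/text-out interface, as in the deleted file; launch() starts the local server.
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()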