Spaces:
Runtime error
Runtime error
GlandVergil
commited on
Commit
•
b14983e
1
Parent(s):
0c65f51
Upload 693 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- RFdiffusion/.github/CODEOWNERS +2 -0
- RFdiffusion/.gitignore +16 -0
- RFdiffusion/.ipynb_checkpoints/untitled-checkpoint.py +0 -0
- RFdiffusion/.rosetta-ci/.gitignore +3 -0
- RFdiffusion/.rosetta-ci/benchmark.py +410 -0
- RFdiffusion/.rosetta-ci/benchmark.template.ini +40 -0
- RFdiffusion/.rosetta-ci/hpc_drivers/__init__.py +5 -0
- RFdiffusion/.rosetta-ci/hpc_drivers/base.py +210 -0
- RFdiffusion/.rosetta-ci/hpc_drivers/multicore.py +184 -0
- RFdiffusion/.rosetta-ci/hpc_drivers/slurm.py +176 -0
- RFdiffusion/.rosetta-ci/test-sets.yaml +65 -0
- RFdiffusion/.rosetta-ci/tests/__init__.py +765 -0
- RFdiffusion/.rosetta-ci/tests/rfd.py +111 -0
- RFdiffusion/.rosetta-ci/tests/self.md +6 -0
- RFdiffusion/.rosetta-ci/tests/self.py +209 -0
- RFdiffusion/END +7 -0
- RFdiffusion/LICENSE +30 -0
- RFdiffusion/appverifUI.dll +0 -0
- RFdiffusion/config/inference/base.yaml +136 -0
- RFdiffusion/config/inference/symmetry.yaml +26 -0
- RFdiffusion/docker/Dockerfile +50 -0
- RFdiffusion/env/SE3Transformer/.dockerignore +123 -0
- RFdiffusion/env/SE3Transformer/.gitignore +121 -0
- RFdiffusion/env/SE3Transformer/Dockerfile +58 -0
- RFdiffusion/env/SE3Transformer/LICENSE +7 -0
- RFdiffusion/env/SE3Transformer/NOTICE +7 -0
- RFdiffusion/env/SE3Transformer/README.md +580 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/__init__.py +0 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/data_loading/__init__.py +1 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/data_loading/data_module.py +63 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/data_loading/qm9.py +173 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/__init__.py +2 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/basis.py +178 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/fiber.py +144 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/__init__.py +5 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/attention.py +180 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/convolution.py +336 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/linear.py +59 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/norm.py +83 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/pooling.py +53 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/transformer.py +222 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/__init__.py +0 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/arguments.py +70 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/callbacks.py +160 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/gpu_affinity.py +325 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/inference.py +131 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/loggers.py +134 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/metrics.py +83 -0
- RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/training.py +238 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
RFdiffusion/env/SE3Transformer/images/se3-transformer.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
RFdiffusion/img/diffusion_protein_gradient_2.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
RFdiffusion/pyrosetta-2023.14+release.7132bdc754a-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
RFdiffusion/.github/CODEOWNERS
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Benchmark scripts
|
2 |
+
/.rosetta-ci @lyskov
|
RFdiffusion/.gitignore
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.py[cod]
|
2 |
+
rfdiffusion.egg-info
|
3 |
+
|
4 |
+
models/
|
5 |
+
schedules/
|
6 |
+
|
7 |
+
examples/ppi_scaffolds
|
8 |
+
|
9 |
+
tests/.results.json
|
10 |
+
tests/input_pdbs
|
11 |
+
tests/outputs
|
12 |
+
tests/ppi_scaffolds
|
13 |
+
tests/reference_outputs/
|
14 |
+
tests/target_folds
|
15 |
+
tests/tim_barrel_scaffold
|
16 |
+
tests/tests_*
|
RFdiffusion/.ipynb_checkpoints/untitled-checkpoint.py
ADDED
File without changes
|
RFdiffusion/.rosetta-ci/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
*.pyc
|
2 |
+
results/
|
3 |
+
benchmark.ubuntu.ini
|
RFdiffusion/.rosetta-ci/benchmark.py
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# :noTabs=true:
|
4 |
+
|
5 |
+
# (c) Copyright Rosetta Commons Member Institutions.
|
6 |
+
# (c) This file is part of the Rosetta software suite and is made available under license.
|
7 |
+
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
|
8 |
+
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
|
9 |
+
# (c) addressed to University of Washington CoMotion, email: license@uw.edu.
|
10 |
+
|
11 |
+
## @file benchmark.py
|
12 |
+
## @brief Run arbitrary Rosetta testing script
|
13 |
+
## @author Sergey Lyskov
|
14 |
+
|
15 |
+
from __future__ import print_function
|
16 |
+
|
17 |
+
import os, os.path, sys, shutil, json, platform, re
|
18 |
+
import codecs
|
19 |
+
|
20 |
+
from importlib.machinery import SourceFileLoader
|
21 |
+
|
22 |
+
from configparser import ConfigParser, ExtendedInterpolation
|
23 |
+
import argparse
|
24 |
+
|
25 |
+
from tests import * # execute, Tests states and key names
|
26 |
+
from hpc_drivers import *
|
27 |
+
|
28 |
+
|
29 |
+
# Calculating value of Platform dict
|
30 |
+
Platform = {}
|
31 |
+
if sys.platform.startswith("linux"):
|
32 |
+
Platform['os'] = 'ubuntu' if os.path.isfile('/etc/lsb-release') and 'Ubuntu' in open('/etc/lsb-release').read() else 'linux' # can be linux1, linux2, etc
|
33 |
+
elif sys.platform == "darwin" : Platform['os'] = 'mac'
|
34 |
+
elif sys.platform == "cygwin" : Platform['os'] = 'cygwin'
|
35 |
+
elif sys.platform == "win32" : Platform['os'] = 'windows'
|
36 |
+
else: Platform['os'] = 'unknown'
|
37 |
+
|
38 |
+
#Platform['arch'] = platform.architecture()[0][:2] # PlatformBits
|
39 |
+
Platform['compiler'] = 'gcc' if Platform['os'] == 'linux' else 'clang'
|
40 |
+
|
41 |
+
Platform['python'] = sys.executable
|
42 |
+
|
43 |
+
|
44 |
+
def load_python_source_from_file(module_name, module_path):
|
45 |
+
''' replacment for deprecated imp.load_source
|
46 |
+
'''
|
47 |
+
return SourceFileLoader(module_name, module_path).load_module()
|
48 |
+
|
49 |
+
|
50 |
+
class Setup(object):
|
51 |
+
__slots__ = 'test working_dir platform config compare debug'.split() # version daemon path_to_previous_test
|
52 |
+
def __init__(self, **attrs):
|
53 |
+
#self.daemon = True
|
54 |
+
for k, v in attrs.items():
|
55 |
+
if k in self.__slots__: setattr(self, k, v)
|
56 |
+
|
57 |
+
|
58 |
+
def setup_from_options(options):
|
59 |
+
''' Create Setup object based on user supplied options, config files and auto-detection
|
60 |
+
'''
|
61 |
+
platform = dict(Platform)
|
62 |
+
|
63 |
+
if options.suffix: options.suffix = '.' + options.suffix
|
64 |
+
|
65 |
+
platform['extras'] = options.extras.split(',') if options.extras else []
|
66 |
+
platform['python'] = options.python
|
67 |
+
#platform['options'] = json.loads( options.options ) if options.options else {}
|
68 |
+
|
69 |
+
if options.memory: memory = options.memory
|
70 |
+
elif platform['os'] in ['linux', 'ubuntu']: memory = int( execute('Getting memory info...', 'free -m', terminate_on_failure=False, silent=True, silence_output_on_errors=True, return_='output').split('\n')[1].split()[1]) // 1024
|
71 |
+
elif platform['os'] == 'mac': memory = int( execute('Getting memory info...', 'sysctl -a | grep hw.memsize', terminate_on_failure=False, silent=True, silence_output_on_errors=True, return_='output').split()[1]) // 1024 // 1024 // 1024
|
72 |
+
|
73 |
+
platform['compiler'] = options.compiler
|
74 |
+
|
75 |
+
if os.path.isfile(options.config):
|
76 |
+
with open(options.config) as f:
|
77 |
+
if '%(here)s' in f.read():
|
78 |
+
print(f"\n\n>>> ERROR file `{options.config}` seems to be in outdated format! Please use benchmark.template.ini to update it.")
|
79 |
+
sys.exit(1)
|
80 |
+
|
81 |
+
user_config = ConfigParser(
|
82 |
+
dict(
|
83 |
+
_here_ = os.path.abspath('./'),
|
84 |
+
_user_home_ = os.environ['HOME']
|
85 |
+
),
|
86 |
+
interpolation = ExtendedInterpolation()
|
87 |
+
)
|
88 |
+
|
89 |
+
with open(options.config) as f: user_config.readfp(f)
|
90 |
+
|
91 |
+
else:
|
92 |
+
print(f"\n\n>>> Config file `{options.config}` not found. You may want to manually copy `benchmark.ini.template` to `{options.config}` and edit the settings\n\n")
|
93 |
+
user_config = ConfigParser()
|
94 |
+
user_config.set('main', 'cpu_count', '1')
|
95 |
+
user_config.set('main', 'hpc_driver', 'MultiCore')
|
96 |
+
user_config.set('main', 'branch', 'unknown')
|
97 |
+
user_config.set('main', 'revision', '42')
|
98 |
+
user_config.set('main', 'user_name', 'Jane Roe')
|
99 |
+
user_config.set('main', 'user_email', 'jane.roe@university.edu')
|
100 |
+
user_config.add_section('main')
|
101 |
+
|
102 |
+
if options.jobs: user_config.set('main', 'cpu_count', str(options.jobs) )
|
103 |
+
user_config.set('main', 'memory', str(memory) )
|
104 |
+
|
105 |
+
if options.mount:
|
106 |
+
for m in options.mount:
|
107 |
+
key, _, path = m.partition(':')
|
108 |
+
user_config.set('mount', key, path)
|
109 |
+
|
110 |
+
#config = Config.items('config')
|
111 |
+
#for section in config.sections(): print('Config section: ', section, dict(config.items(section)))
|
112 |
+
#config = { section: dict(Config.items(section)) for section in Config.sections() }
|
113 |
+
|
114 |
+
config = { k : d for k, d in user_config['main'].items() if k not in user_config[user_config.default_section] }
|
115 |
+
config['mounts'] = { k : d for k, d in user_config['mount'].items() if k not in user_config[user_config.default_section] }
|
116 |
+
|
117 |
+
#print(json.dumps(config, sort_keys=True, indent=2)); sys.exit(1)
|
118 |
+
|
119 |
+
#config.update( config.pop('config').items() )
|
120 |
+
|
121 |
+
config = dict(config,
|
122 |
+
cpu_count = user_config.getint('main', 'cpu_count'),
|
123 |
+
memory = memory,
|
124 |
+
revision = user_config.getint('main', 'revision'),
|
125 |
+
emulation=True,
|
126 |
+
) # debug=options.debug,
|
127 |
+
|
128 |
+
if 'results_root' not in config: config['results_root'] = os.path.abspath('./results/')
|
129 |
+
|
130 |
+
if 'prefix' in config:
|
131 |
+
assert os.path.isabs( config['prefix'] ), f'ERROR: `prefix` path must be absolute! Got: {config["prefix"]}'
|
132 |
+
|
133 |
+
else: config['prefix'] = os.path.abspath( config['results_root'] + '/prefix')
|
134 |
+
|
135 |
+
config['merge_head'] = options.merge_head
|
136 |
+
config['merge_base'] = options.merge_base
|
137 |
+
|
138 |
+
if options.skip_compile is not None: config['skip_compile'] = options.skip_compile
|
139 |
+
|
140 |
+
#print(f'Results path: {config["results_root"]}')
|
141 |
+
#print('Config:{}, Platform:{}'.format(json.dumps(config, sort_keys=True, indent=2), Platform))
|
142 |
+
|
143 |
+
if options.compare: print('Comparing tests {} with suffixes: {}'.format(options.args, options.compare) )
|
144 |
+
else: print('Running tests: {}'.format(options.args) )
|
145 |
+
|
146 |
+
if len(options.args) != 1: print('Error: Single test-name-to-run should be supplied!'); sys.exit(1)
|
147 |
+
else:
|
148 |
+
test = options.args[0]
|
149 |
+
if test.startswith('tests/'): test = test.partition('tests/')[2][:-3] # removing dir prefix and .py suffix
|
150 |
+
|
151 |
+
if options.compare:
|
152 |
+
compare = options.compare[0], options.compare[1] # (this test suffix, previous test suffix)
|
153 |
+
working_dir = os.path.abspath( config['results_root'] + f'/{platform["os"]}.{test}' ) # will be a root dir with sub-dirs (options.compare[0], options.compare[1])
|
154 |
+
else:
|
155 |
+
compare = None
|
156 |
+
working_dir = os.path.abspath( config['results_root'] + f'/{platform["os"]}.{test}{options.suffix}' )
|
157 |
+
|
158 |
+
|
159 |
+
if os.path.isdir(working_dir): shutil.rmtree(working_dir); #print('Removing old job dir %s...' % working_dir) # remove old dir if any
|
160 |
+
os.makedirs(working_dir)
|
161 |
+
|
162 |
+
setup = Setup(
|
163 |
+
test = test,
|
164 |
+
working_dir = working_dir,
|
165 |
+
platform = platform,
|
166 |
+
config = config,
|
167 |
+
compare = compare,
|
168 |
+
debug = options.debug,
|
169 |
+
#daemon = False,
|
170 |
+
)
|
171 |
+
|
172 |
+
setup_as_json = json.dumps( { k : getattr(setup, k) for k in setup.__slots__}, sort_keys=True, indent=2)
|
173 |
+
with open(working_dir + '/.setup.json', 'w') as f: f.write(setup_as_json)
|
174 |
+
|
175 |
+
#print(f'Detected hardware platform: {Platform}')
|
176 |
+
print(f'Setup: {setup_as_json}')
|
177 |
+
return setup
|
178 |
+
|
179 |
+
|
180 |
+
def truncate_log(log):
|
181 |
+
_max_log_size_ = 1024*1024*1
|
182 |
+
_max_line_size_ = _max_log_size_ // 2
|
183 |
+
|
184 |
+
if len(log) > _max_log_size_:
|
185 |
+
new = log
|
186 |
+
lines = log.split('\n')
|
187 |
+
|
188 |
+
if len(lines) > 256:
|
189 |
+
new_lines = lines[:32] + ['...truncated...'] + lines[-128:]
|
190 |
+
new = '\n'.join(new_lines)
|
191 |
+
|
192 |
+
if len(new) > _max_log_size_: # special case for Ninja logs that does not use \n
|
193 |
+
lines = re.split(r'[\r\n]*', log) #t.log.split('\r')
|
194 |
+
if len(lines) > 256: new = '\n'.join( lines[:32] + ['...truncated...'] + lines[-128:] )
|
195 |
+
|
196 |
+
if len(new) > _max_log_size_: # going to try to truncate each individual line...
|
197 |
+
print(f'Trying to truncate log line-by-line...')
|
198 |
+
new = '\n'.join( (
|
199 |
+
( line[:_max_line_size_//3] + '...truncated...' + line[-_max_line_size_//3:] ) if line > _max_line_size_ else line
|
200 |
+
for line in new_lines ) )
|
201 |
+
|
202 |
+
if len(new) > _max_log_size_: # fall-back strategy in case all of the above failed...
|
203 |
+
print(f'WARNING: could not truncate log line-by-line, falling back to raw truncate...')
|
204 |
+
new = 'WARNING: could not truncate test log line-by-line, falling back to raw truncate!\n...truncated...\n' + ( '\n'.join(lines) )[-_max_log_size_+256:]
|
205 |
+
|
206 |
+
print( 'Trunacting test output log: {0}MiB --> {1}MiB'.format(len(log)/1024/1024, len(new)/1024/1024) )
|
207 |
+
|
208 |
+
log = new
|
209 |
+
|
210 |
+
return log
|
211 |
+
|
212 |
+
def truncate_results_logs(results):
|
213 |
+
results[_LogKey_] = truncate_log( results[_LogKey_] )
|
214 |
+
if _ResultsKey_ in results and _TestsKey_ in results[_ResultsKey_]:
|
215 |
+
tests = results[_ResultsKey_][_TestsKey_]
|
216 |
+
for test in tests:
|
217 |
+
tests[test][_LogKey_] = truncate_log( tests[test][_LogKey_] )
|
218 |
+
|
219 |
+
|
220 |
+
def find_test_description(test_name, test_script_file_name):
|
221 |
+
''' return content of test-description file if any or None if no description was found
|
222 |
+
'''
|
223 |
+
|
224 |
+
def find_description_file(prefix, test_name):
|
225 |
+
fname = prefix + test_name + '.md'
|
226 |
+
if os.path.isfile(fname): return fname
|
227 |
+
return prefix + 'md'
|
228 |
+
|
229 |
+
description_file_name = find_description_file( test_script_file_name[:-len('command.py')] + 'description.', test_name) if test_script_file_name.endswith('/command.py') else find_description_file(test_script_file_name[:-len('py')], test_name)
|
230 |
+
|
231 |
+
if description_file_name and os.path.isfile(description_file_name):
|
232 |
+
print(f'Found test suite description in file: {description_file_name!r}')
|
233 |
+
with open(description_file_name, encoding='utf-8', errors='backslashreplace') as f: description = f.read()
|
234 |
+
return description
|
235 |
+
|
236 |
+
else: return None
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
def run_test(setup):
|
241 |
+
#print(f'{setup!r}')
|
242 |
+
suite, rest = setup.test.split('.'), []
|
243 |
+
while suite:
|
244 |
+
#print( f'suite: {suite}, test: {rest}' )
|
245 |
+
|
246 |
+
file_name = '/'.join( ['tests'] + suite ) + '.py'
|
247 |
+
if os.path.isfile(file_name): break
|
248 |
+
|
249 |
+
file_name = '/'.join( ['tests'] + suite ) + '/command.py'
|
250 |
+
if os.path.isfile(file_name): break
|
251 |
+
|
252 |
+
rest.insert(0, suite.pop())
|
253 |
+
|
254 |
+
|
255 |
+
test = '.'.join( suite + rest )
|
256 |
+
test_name = '.'.join(rest)
|
257 |
+
|
258 |
+
print( f'Loading test from: {file_name}, suite+test: {test!r}, test: {test_name!r}' )
|
259 |
+
#test_suite = imp.load_source('test_suite', file_name)
|
260 |
+
test_suite = load_python_source_from_file('test_suite', file_name)
|
261 |
+
|
262 |
+
test_description = find_test_description(test_name, file_name)
|
263 |
+
|
264 |
+
if setup.compare:
|
265 |
+
#working_dir_1 = os.path.abspath( config['results_root'] + f'/{Platform["os"]}.{test}.{Options.compare[0]}' )
|
266 |
+
working_dir_1 = setup.working_dir + f'/{setup.compare[0]}'
|
267 |
+
|
268 |
+
working_dir_2 = setup.compare[1] and ( setup.working_dir + f'/{setup.compare[1]}' )
|
269 |
+
res_2_json_file_path = setup.compare[1] and f'{working_dir_2}/.execution.results.json'
|
270 |
+
|
271 |
+
with open(working_dir_1 + '/.execution.results.json') as f: res_1 = json.load(f).get(_ResultsKey_)
|
272 |
+
|
273 |
+
if setup.compare[1] and ( not os.path.isfile(res_2_json_file_path) ):
|
274 |
+
setup.compare[1] = None
|
275 |
+
state_override = _S_failed_
|
276 |
+
else:
|
277 |
+
state_override = None
|
278 |
+
|
279 |
+
if setup.compare[1] == None: res_2, working_dir_2 = None, None
|
280 |
+
else:
|
281 |
+
with open(res_2_json_file_path) as f: res_2 = json.load(f).get(_ResultsKey_)
|
282 |
+
|
283 |
+
res = test_suite.compare(test, res_1, working_dir_1, res_2, working_dir_2)
|
284 |
+
|
285 |
+
if state_override:
|
286 |
+
log_prefix = \
|
287 |
+
f'WARNING: Previous test results does not have `.execution.results.json` file, so comparision with None was performed instead!\n' \
|
288 |
+
f'WARNING: Overriding calcualted test state `{res[_StateKey_]}` → `{_S_failed_}`...\n\n'
|
289 |
+
|
290 |
+
res[_LogKey_] = log_prefix + res[_LogKey_]
|
291 |
+
res[_StateKey_] = _S_failed_
|
292 |
+
|
293 |
+
|
294 |
+
# # Caution! Some of the strings in the result object may be unicode. Be robust to unicode in the log messages.
|
295 |
+
# with codecs.open(setup.working_dir+'/.comparison.log.txt', 'w', encoding='utf-8', errors='replace') as f: f.write( truncate_log( res[_LogKey_] ) )
|
296 |
+
# res[_LogKey_] = truncate_log( res[_LogKey_] )
|
297 |
+
|
298 |
+
# # Caution! Some of the strings in the result object may be unicode. Be robust to unicode in the log messages.
|
299 |
+
with codecs.open(setup.working_dir+'/.comparison.log.txt', 'w', encoding='utf-8', errors='replace') as f: f.write(res[_LogKey_])
|
300 |
+
truncate_results_logs(res)
|
301 |
+
|
302 |
+
print( 'Comparison finished with output:\n{}'.format( res[_LogKey_] ) )
|
303 |
+
|
304 |
+
with open(setup.working_dir+'/.comparison.results.json', 'w') as f: json.dump(res, f, sort_keys=True, indent=2)
|
305 |
+
|
306 |
+
#print( 'Comparison finished with results:\n{}'.format( json.dumps(res, sort_keys=True, indent=2) ) )
|
307 |
+
if 'summary' in res: print('Summary section:\n{}'.format( json.dumps(res['summary'], sort_keys=True, indent=2) ) )
|
308 |
+
|
309 |
+
print( f'Output results of this comparison saved to {working_dir_1}/.comparison.results.json\nComparison log saved into {working_dir_1}/.comparison.log.txt' )
|
310 |
+
|
311 |
+
|
312 |
+
else:
|
313 |
+
working_dir = setup.working_dir #os.path.abspath( setup.config['results_root'] + f'/{platform["os"]}.{test}{options.suffix}' )
|
314 |
+
|
315 |
+
hpc_driver_name = setup.config['hpc_driver']
|
316 |
+
hpc_driver = None if hpc_driver_name in ['', 'none'] else eval(hpc_driver_name + '_HPC_Driver')(working_dir, setup.config, tracer=print, set_daemon_message=lambda x:None)
|
317 |
+
|
318 |
+
api_version = test_suite._api_version_ if hasattr(test_suite, '_api_version_') else ''
|
319 |
+
|
320 |
+
# if api_version < '1.0':
|
321 |
+
# res = test_suite.run(test=test_name, rosetta_dir=os.path.abspath('../..'), working_dir=working_dir, platform=dict(Platform), jobs=Config.cpu_count, verbose=True, debug=Options.debug)
|
322 |
+
# else:
|
323 |
+
|
324 |
+
if api_version == '1.0': res = test_suite.run(test=test_name, repository_root=os.path.abspath('./..'), working_dir=working_dir, platform=dict(setup.platform), config=setup.config, hpc_driver=hpc_driver, verbose=True, debug=setup.debug)
|
325 |
+
else:
|
326 |
+
print(f'Test benchmark api_version={api_version} is not supported!'); sys.exit(1)
|
327 |
+
|
328 |
+
if not isinstance(res, dict): print(f'Test returned result of type {type(res)} while dict-like object was expected, please check that test-script have correct `return` statment! Terminating...'); sys.exit(1)
|
329 |
+
|
330 |
+
# Caution! Some of the strings in the result object may be unicode. Be robust to unicode in the log messages
|
331 |
+
with codecs.open(working_dir+'/.execution.log.txt', 'w', encoding='utf-8', errors='replace') as f: f.write( res[_LogKey_] )
|
332 |
+
|
333 |
+
# res[_LogKey_] = truncate_log( res[_LogKey_] )
|
334 |
+
truncate_results_logs(res)
|
335 |
+
|
336 |
+
if _DescriptionKey_ not in res: res[_DescriptionKey_] = test_description
|
337 |
+
|
338 |
+
if res[_StateKey_] not in _S_Values_: print( 'Warning!!! Test {} failed with unknow result code: {}'.format(test_name, res[_StateKey_]) )
|
339 |
+
else: print( f'Test {test} finished with output:\n{res[_LogKey_]}\n----------------------------------------------------------------\nState: {res[_StateKey_]!r} | ', end='')
|
340 |
+
|
341 |
+
# JSON by default serializes to an ascii-encoded format
|
342 |
+
with open(working_dir+'/.execution.results.json', 'w') as f: json.dump(res, f, sort_keys=True, indent=2)
|
343 |
+
|
344 |
+
print( f'Output and full log of this test saved to:\n{working_dir}/.execution.results.json\n{working_dir}/.execution.log.txt' )
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
|
349 |
+
|
350 |
+
|
351 |
+
def main(args):
|
352 |
+
''' Script to Run arbitrary Rosetta test
|
353 |
+
'''
|
354 |
+
parser = argparse.ArgumentParser(usage="Main testing script to run tests in the tests directory. "
|
355 |
+
"Use the --skip-compile to skip the build phase when testing locally. "
|
356 |
+
"Example Command: /benchmark.py -j2 integration.valgrind")
|
357 |
+
|
358 |
+
parser.add_argument('-j', '--jobs', default=0, type=int, help="Number of processors to use on when building. (default: use value from config file or 1)")
|
359 |
+
|
360 |
+
parser.add_argument('-m', '--memory', default=0, type=int, help="Amount of memory to use (default: use 2Gb per job")
|
361 |
+
|
362 |
+
parser.add_argument('--compiler', default=Platform['compiler'], help="Compiler to use")
|
363 |
+
|
364 |
+
#parser.add_argument('--python', default=('3.9' if Platform['os'] == 'mac' else '3.6'), help="Python interpreter to use")
|
365 |
+
parser.add_argument('--python', default=f'{sys.version_info.major}.{sys.version_info.minor}.s', help="Specify version of Python interpreter to use, for example '3.9'. If '.s' added to end of version string then use the same interpreter that was used to start this script. Default: '?.?.s'")
|
366 |
+
|
367 |
+
parser.add_argument("--extras", default='', help="Specify scons extras separated by ',': like --extras=mpi,static" )
|
368 |
+
|
369 |
+
parser.add_argument("--debug", action="store_true", dest="debug", default=False, help="Run specified test in debug mode (not with debug build!) this mean different things and depend on the test. Could be: skip the build phase, skip some of the test phases and so on. [off by default]" )
|
370 |
+
|
371 |
+
parser.add_argument("--suffix", default='', help="Specify ending suffix for test output dir. This is useful when you want to save test results in different dir for later comparison." )
|
372 |
+
|
373 |
+
parser.add_argument("--compare", nargs=2, help="Do not run the tests but instead compare previous results. Use --compare suffix1 suffix2" )
|
374 |
+
|
375 |
+
parser.add_argument("--config", default='benchmark.{os}.ini'.format(os=Platform['os']), action="store", help="Location of .ini file with additional options configuration. Optional.")
|
376 |
+
|
377 |
+
parser.add_argument("--skip-compile", dest='skip_compile', default=None, action="store_true", help="Skip the compilation phase. Assumes the binaries are already compiled locally.")
|
378 |
+
|
379 |
+
#parser.add_argument("--results-root", default=None, action="store", help="Location of `results` dir default is to use `./results`")
|
380 |
+
|
381 |
+
parser.add_argument("--setup", default=None, help="Specify JSON file with setup information. When this option supplied all other config and commandline options is ignored and auto-detection disable. Test, platform info will be gathered from provided JSON file. This option is designed to be used in daemon mode." )
|
382 |
+
|
383 |
+
parser.add_argument("--merge-head", default='HEAD', help="Specify SHA1/branch-name that will be used for `merge-head` value when simulating PR testing" )
|
384 |
+
|
385 |
+
parser.add_argument("--merge-base", default='origin/master', help="Specify SHA1/branch-name that will be used for `merge-base` value when simulating PR testing" )
|
386 |
+
|
387 |
+
parser.add_argument("--mount", action="append", help="Specify one of the mount points, like: --mount release_root:/some/path. This option could be used multiple times if needed" )
|
388 |
+
|
389 |
+
|
390 |
+
parser.add_argument('args', nargs=argparse.REMAINDER)
|
391 |
+
|
392 |
+
options = parser.parse_args(args=args[1:])
|
393 |
+
|
394 |
+
if any( [a.startswith('-') for a in options.args] ) :
|
395 |
+
print( '\nWARNING WARNING WARNING WARNING\n' )
|
396 |
+
print( '\tInterpreting', ' '.join(["'"+a+"'" for a in options.args if a.startswith('-')]), 'as test name(s), rather than as option(s).' )
|
397 |
+
print( "\tTry moving it before any test name, if that's not what you want." )
|
398 |
+
print( '\nWARNING WARNING WARNING WARNING\n' )
|
399 |
+
|
400 |
+
|
401 |
+
if options.setup:
|
402 |
+
with open(options.setup) as f: setup = Setup( **json.load(f) )
|
403 |
+
|
404 |
+
else:
|
405 |
+
setup = setup_from_options(options)
|
406 |
+
|
407 |
+
run_test(setup)
|
408 |
+
|
409 |
+
|
410 |
+
if __name__ == "__main__": main(sys.argv)
|
RFdiffusion/.rosetta-ci/benchmark.template.ini
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Benchmark script configuration file. Some of the tests require some system specific options to run. Please see benchmark.ini.template for list of available options.
|
3 |
+
#
|
4 |
+
|
5 |
+
[DEFAULT]
|
6 |
+
|
7 |
+
[main] # additional config-options for various tests. All this fields will be pass as keys in 'config' function argument
|
8 |
+
|
9 |
+
# how many jobs daemon can run on host machine (this is not related to HPC jobs)
|
10 |
+
cpu_count = 24
|
11 |
+
|
12 |
+
# how many memory in GB daemon can use on host machine (approximation, float)
|
13 |
+
memory = 64
|
14 |
+
|
15 |
+
# user name and email for user who submitted this test
|
16 |
+
user_name = Jane Roe
|
17 |
+
user_email = jane.roe@university.edu
|
18 |
+
|
19 |
+
# HPC Driver, might have one of the following values: MultiCore, Condor, Slurm or none if no HPC Driver should be configured
|
20 |
+
hpc_driver = MultiCore
|
21 |
+
|
22 |
+
# when running by daemons branch:revision will be set to appropriate values to represent currently checked version of main repository
|
23 |
+
branch = unknown
|
24 |
+
revision = 42
|
25 |
+
|
26 |
+
# path to directory where test results will be stored
|
27 |
+
results_root = ${_here_}/results
|
28 |
+
|
29 |
+
release_root = ./results/_release_
|
30 |
+
|
31 |
+
[slurm]
|
32 |
+
# head-node host name, if specified will be used to submit jobs
|
33 |
+
head_node =
|
34 |
+
|
35 |
+
|
36 |
+
[mount]
|
37 |
+
# list of key:path pairs that will be avalible as config.mounts during test run
|
38 |
+
|
39 |
+
# path to releases, leave empty if release production should not be supported by this daemon
|
40 |
+
release_root = ${_here_}/release
|
RFdiffusion/.rosetta-ci/hpc_drivers/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# :noTabs=true:
|
3 |
+
|
4 |
+
from .multicore import MultiCore_HPC_Driver
|
5 |
+
from .slurm import Slurm_HPC_Driver
|
RFdiffusion/.rosetta-ci/hpc_drivers/base.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# :noTabs=true:
|
3 |
+
|
4 |
+
import os, sys, subprocess, stat
|
5 |
+
import time as time_module
|
6 |
+
import signal as signal_module
|
7 |
+
|
8 |
+
class NT: # named tuple
|
9 |
+
def __init__(self, **entries): self.__dict__.update(entries)
|
10 |
+
def __repr__(self):
|
11 |
+
r = 'NT: |'
|
12 |
+
for i in dir(self):
|
13 |
+
if not i.startswith('__') and not isinstance(getattr(self, i), types.MethodType): r += '{} --> {}, '.format(i, getattr(self, i))
|
14 |
+
return r[:-2]+'|'
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
class HPC_Exception(Exception):
|
19 |
+
def __init__(self, value): self.value = value
|
20 |
+
def __str__(self): return self.value
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
def execute(message, command_line, return_='status', until_successes=False, terminate_on_failure=True, silent=False, silence_output=False, tracer=print):
|
25 |
+
if not silent: tracer(message); tracer(command_line); sys.stdout.flush();
|
26 |
+
while True:
|
27 |
+
|
28 |
+
p = subprocess.Popen(command_line, bufsize=0, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
29 |
+
output, errors = p.communicate()
|
30 |
+
|
31 |
+
output = output + errors
|
32 |
+
|
33 |
+
output = output.decode(encoding="utf-8", errors="replace")
|
34 |
+
|
35 |
+
exit_code = p.returncode
|
36 |
+
|
37 |
+
if exit_code and not (silent or silence_output): tracer(output); sys.stdout.flush();
|
38 |
+
|
39 |
+
if exit_code and until_successes: pass # Thats right - redability COUNT!
|
40 |
+
else: break
|
41 |
+
|
42 |
+
tracer( "Error while executing {}: {}\n".format(message, output) )
|
43 |
+
tracer("Sleeping 60s... then I will retry...")
|
44 |
+
sys.stdout.flush();
|
45 |
+
time.sleep(60)
|
46 |
+
|
47 |
+
if return_ == 'tuple': return(exit_code, output)
|
48 |
+
|
49 |
+
if exit_code and terminate_on_failure:
|
50 |
+
tracer("\nEncounter error while executing: " + command_line)
|
51 |
+
if return_==True: return True
|
52 |
+
else: print("\nEncounter error while executing: " + command_line + '\n' + output); sys.exit(1)
|
53 |
+
|
54 |
+
if return_ == 'output': return output
|
55 |
+
else: return False
|
56 |
+
|
57 |
+
|
58 |
+
def Sleep(time_, message, dict_={}):
|
59 |
+
''' Fancy sleep function '''
|
60 |
+
len_ = 0
|
61 |
+
for i in range(time_, 0, -1):
|
62 |
+
#print "Waiting for a new revision:%s... Sleeping...%d \r" % (sc.revision, i),
|
63 |
+
msg = message.format( **dict(dict_, time_left=i) )
|
64 |
+
print( msg, end='' )
|
65 |
+
len_ = max(len_, len(msg))
|
66 |
+
sys.stdout.flush()
|
67 |
+
time_module.sleep(1)
|
68 |
+
|
69 |
+
print( ' '*len_ + '\r', end='' ) # erazing sleep message
|
70 |
+
|
71 |
+
|
72 |
+
# Abstract class for HPC job submission
|
73 |
+
class HPC_Driver:
|
74 |
+
def __init__(self, working_dir, config, tracer=lambda x:None, set_daemon_message=lambda x:None):
|
75 |
+
self.working_dir = working_dir
|
76 |
+
self.config = config
|
77 |
+
self.cpu_usage = 0.0 # cummulative cpu usage in hours
|
78 |
+
self.tracer = tracer
|
79 |
+
self.set_daemon_message = set_daemon_message
|
80 |
+
|
81 |
+
self.cpu_count = self.config['cpu_count'] if type(config) == dict else self.config.getint('DEFAULT', 'cpu_count')
|
82 |
+
|
83 |
+
self.jobs = [] # list of all jobs currently running by this driver, Job class is driver depended, could be just int or something more complex
|
84 |
+
|
85 |
+
self.install_signal_handler()
|
86 |
+
|
87 |
+
|
88 |
+
def __del__(self):
|
89 |
+
self.remove_signal_handler()
|
90 |
+
|
91 |
+
|
92 |
+
def execute(self, executable, arguments, working_dir, log_dir=None, name='_no_name_', memory=256, time=24, shell_wrapper=False, block=True):
|
93 |
+
''' Execute given command line on HPC cluster, must accumulate cpu hours in self.cpu_usage '''
|
94 |
+
if log_dir==None: log_dir=self.working_dir
|
95 |
+
|
96 |
+
if shell_wrapper:
|
97 |
+
shell_wrapper_sh = os.path.abspath(self.working_dir + '/hpc.{}.shell_wrapper.sh'.format(name))
|
98 |
+
with file(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
|
99 |
+
executable, arguments = shell_wrapper_sh, ''
|
100 |
+
|
101 |
+
return self.submit_serial_hpc_job(name=name, executable=executable, arguments=arguments, working_dir=working_dir, log_dir=log_dir, jobs_to_queue=1, memory=memory, time=time, block=block, shell_wrapper=shell_wrapper)
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
@property
|
106 |
+
def number_of_cpu_per_node(self):
|
107 |
+
must_be_implemented_in_inherited_classes
|
108 |
+
|
109 |
+
@property
|
110 |
+
def maximum_number_of_mpi_cpu(self):
|
111 |
+
must_be_implemented_in_inherited_classes
|
112 |
+
|
113 |
+
|
114 |
+
def submit_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
115 |
+
print('submit_hpc_job is DEPRECATED and will be removed in near future, please use submit_serial_hpc_job instead!')
|
116 |
+
must_be_implemented_in_inherited_classes
|
117 |
+
|
118 |
+
|
119 |
+
def submit_serial_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
120 |
+
must_be_implemented_in_inherited_classes
|
121 |
+
|
122 |
+
|
123 |
+
def submit_mpi_hpc_job(self, name, executable, arguments, working_dir, log_dir, memory=512, time=12, block=True, process_coefficient="1", requested_nodes=1, requested_processes_per_node=1):
|
124 |
+
''' submit jobs as MPI job
|
125 |
+
process_coefficient should be string representing fraction of process to launch on each node, for example '3 / 4' will start only 75% of MPI process's on each node
|
126 |
+
'''
|
127 |
+
must_be_implemented_in_inherited_classes
|
128 |
+
|
129 |
+
|
130 |
+
def cancel_all_jobs(self):
|
131 |
+
''' Cancel all HPC jobs known to this driver, use this as signal handler for script termination '''
|
132 |
+
for j in self.jobs: self.cancel_job(j)
|
133 |
+
|
134 |
+
def block_until(self, silent, fn, *args, **kwargs):
|
135 |
+
'''
|
136 |
+
**fn must have the driver as the first argument**
|
137 |
+
example:
|
138 |
+
def fn(driver):
|
139 |
+
jobs = list(driver.jobs)
|
140 |
+
jobs = [job for job in jobs if not driver.complete(job)]
|
141 |
+
if len(jobs) <= 8:
|
142 |
+
return False # stops sleeping
|
143 |
+
return True # continues sleeping
|
144 |
+
|
145 |
+
for x in range(100):
|
146 |
+
hpc_driver.submit_hpc_job(...)
|
147 |
+
hpc_driver.block_until(False, fn)
|
148 |
+
'''
|
149 |
+
while fn(self, *args, **kwargs):
|
150 |
+
sys.stdout.flush()
|
151 |
+
time_module.sleep(60)
|
152 |
+
if not silent:
|
153 |
+
Sleep(1, '"Waiting for HPC job(s) to finish, sleeping {time_left}s\r')
|
154 |
+
|
155 |
+
def wait_until_complete(self, jobs=None, callback=None, silent=False):
|
156 |
+
''' Helper function, wait until given jobs list is finished, if no argument is given waits until all jobs known by driver is finished '''
|
157 |
+
jobs = jobs if jobs else self.jobs
|
158 |
+
|
159 |
+
while jobs:
|
160 |
+
for j in jobs[:]:
|
161 |
+
if self.complete(j): jobs.remove(j)
|
162 |
+
|
163 |
+
if jobs:
|
164 |
+
#total_cpu_queued = sum( [j.jobs_queued for j in jobs] )
|
165 |
+
#total_cpu_running = sum( [j.cpu_running for j in jobs] )
|
166 |
+
#self.set_daemon_message("Waiting for HPC job(s) to finish... [{} process(es) in queue, {} process(es) running]".format(total_cpu_queued, total_cpu_running) )
|
167 |
+
#self.tracer("Waiting for HPC job(s) [{} process(es) in queue, {} process(es) running]... \r".format(total_cpu_queued, total_cpu_running), end='')
|
168 |
+
#print "Waiting for {} HPC jobs to finish... [{} jobs in queue, {} jobs running]... Sleeping 32s... \r".format(total_cpu_queued, cpu_queued+cpu_running, cpu_running),
|
169 |
+
|
170 |
+
self.set_daemon_message("Waiting for HPC {} job(s) to finish...".format( len(jobs) ) )
|
171 |
+
#self.tracer("Waiting for HPC {} job(s) to finish...".format( len(jobs) ) )
|
172 |
+
|
173 |
+
sys.stdout.flush()
|
174 |
+
|
175 |
+
if callback: callback()
|
176 |
+
|
177 |
+
if silent: time_module.sleep(64*1)
|
178 |
+
else: Sleep(64, '"Waiting for HPC {n_jobs} job(s) to finish, sleeping {time_left}s \r', dict(n_jobs=len(jobs)))
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
_signals_ = [signal_module.SIGINT, signal_module.SIGTERM, signal_module.SIGABRT]
|
183 |
+
def install_signal_handler(self):
|
184 |
+
def signal_handler(signal_, frame):
|
185 |
+
self.tracer('Recieved signal:{}... Canceling HPC jobs...'.format(signal_) )
|
186 |
+
self.cancel_all_jobs()
|
187 |
+
self.set_daemon_message( 'Remote daemon got terminated with signal:{}'.format(signal_) )
|
188 |
+
sys.exit(1)
|
189 |
+
|
190 |
+
for s in self._signals_: signal_module.signal(s, signal_handler)
|
191 |
+
|
192 |
+
|
193 |
+
def remove_signal_handler(self): # do we really need this???
|
194 |
+
try:
|
195 |
+
for s in self._signals_: signal_module.signal(s, signal_module.SIG_DFL)
|
196 |
+
#print('remove_signal_handler: done!')
|
197 |
+
|
198 |
+
except TypeError:
|
199 |
+
#print('remove_signal_handler: interpreted terminating, skipping remove_signal_handler...')
|
200 |
+
pass
|
201 |
+
|
202 |
+
|
203 |
+
def cancel_job(self, job_id):
|
204 |
+
must_be_implemented_in_inherited_classes
|
205 |
+
|
206 |
+
|
207 |
+
def complete(self, job_id):
|
208 |
+
''' Return job completion status. Return True if job complered and False otherwise
|
209 |
+
'''
|
210 |
+
must_be_implemented_in_inherited_classes
|
RFdiffusion/.rosetta-ci/hpc_drivers/multicore.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# :noTabs=true:
|
3 |
+
|
4 |
+
import time as time_module
|
5 |
+
import codecs
|
6 |
+
import signal
|
7 |
+
|
8 |
+
import os, sys
|
9 |
+
|
10 |
+
try:
|
11 |
+
from .base import *
|
12 |
+
|
13 |
+
except ImportError: # workaround for B2 back-end's
|
14 |
+
import imp
|
15 |
+
imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/base.py') # A bit of Python magic here, what we trying to say is this: from base import *, but path to base is calculated from our source location # from base import HPC_Driver, execute, NT
|
16 |
+
|
17 |
+
|
18 |
+
class MultiCore_HPC_Driver(HPC_Driver):
|
19 |
+
|
20 |
+
class JobID:
|
21 |
+
def __init__(self, pids=None):
|
22 |
+
self.pids = pids if pids else []
|
23 |
+
|
24 |
+
|
25 |
+
def __bool__(self): return bool(self.pids)
|
26 |
+
|
27 |
+
|
28 |
+
def __len__(self): return len(self.pids)
|
29 |
+
|
30 |
+
|
31 |
+
def add_pid(self, pid): self.pids.append(pid)
|
32 |
+
|
33 |
+
|
34 |
+
def remove_completed_pids(self):
|
35 |
+
for pid in self.pids[:]:
|
36 |
+
try:
|
37 |
+
r = os.waitpid(pid, os.WNOHANG)
|
38 |
+
if r == (pid, 0): self.pids.remove(pid) # process have ended without error
|
39 |
+
elif r[0] == pid : # process ended but with error, special case we will have to wait for all process to terminate and call system exit.
|
40 |
+
#self.cancel_job()
|
41 |
+
#sys.exit(1)
|
42 |
+
self.pids.remove(pid)
|
43 |
+
print('ERROR: Some of the HPC jobs terminated abnormally! Please see HPC logs for details.')
|
44 |
+
|
45 |
+
except ChildProcessError: self.pids.remove(pid)
|
46 |
+
|
47 |
+
|
48 |
+
def cancel(self):
|
49 |
+
for pid in self.pids:
|
50 |
+
try:
|
51 |
+
os.killpg(os.getpgid(pid), signal.SIGKILL)
|
52 |
+
except ChildProcessError: pass
|
53 |
+
|
54 |
+
self.pids = []
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
def __init__(self, *args, **kwds):
|
59 |
+
HPC_Driver.__init__(self, *args, **kwds)
|
60 |
+
#print(f'MultiCore_HPC_Driver: cpu_count: {self.cpu_count}')
|
61 |
+
|
62 |
+
|
63 |
+
def remove_completed_jobs(self):
|
64 |
+
for job in self.jobs[:]: # Need to make a copy so we don't modify a list we're iterating over
|
65 |
+
job.remove_completed_pids()
|
66 |
+
if not job: self.jobs.remove(job)
|
67 |
+
|
68 |
+
|
69 |
+
@property
|
70 |
+
def process_count(self):
|
71 |
+
''' return number of processes that currently ran by this driver instance
|
72 |
+
'''
|
73 |
+
return sum( map(len, self.jobs) )
|
74 |
+
|
75 |
+
|
76 |
+
def submit_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
77 |
+
print('submit_hpc_job is DEPRECATED and will be removed in near future, please use submit_serial_hpc_job instead!')
|
78 |
+
return self.submit_serial_hpc_job(name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory, time, block, shell_wrapper)
|
79 |
+
|
80 |
+
|
81 |
+
def submit_serial_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
82 |
+
cpu_usage = -time_module.time()/60./60.
|
83 |
+
|
84 |
+
if shell_wrapper:
|
85 |
+
shell_wrapper_sh = os.path.abspath(self.working_dir + f'/hpc.{name}.shell_wrapper.sh')
|
86 |
+
with open(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
|
87 |
+
executable, arguments = shell_wrapper_sh, ''
|
88 |
+
|
89 |
+
def mfork():
|
90 |
+
''' Check if number of child process is below cpu_count. And if it is - fork the new pocees and return its pid.
|
91 |
+
'''
|
92 |
+
while self.process_count >= self.cpu_count:
|
93 |
+
self.remove_completed_jobs()
|
94 |
+
if self.process_count >= self.cpu_count: time_module.sleep(.5)
|
95 |
+
|
96 |
+
sys.stdout.flush()
|
97 |
+
pid = os.fork()
|
98 |
+
# appending at caller level insted if pid: self.jobs.append(pid) # We are parent!
|
99 |
+
return pid
|
100 |
+
|
101 |
+
current_job = self.JobID()
|
102 |
+
process = 0
|
103 |
+
for i in range(jobs_to_queue):
|
104 |
+
|
105 |
+
pid = mfork()
|
106 |
+
if not pid: # we are child process
|
107 |
+
command_line = 'cd {} && {} {}'.format(working_dir, executable, arguments.format(process=process) )
|
108 |
+
exit_code, log = execute('Running job {}.{}...'.format(name, i), command_line, tracer=self.tracer, return_='tuple')
|
109 |
+
with codecs.open(log_dir+'/.hpc.{name}.{i:02d}.log'.format(**vars()), 'w', encoding='utf-8', errors='replace') as f:
|
110 |
+
f.write(command_line+'\n'+log)
|
111 |
+
if exit_code:
|
112 |
+
error_report = f'\n\n{command_line}\nERROR: PROCESS {name}.{i:02d} TERMINATED WITH NON-ZERO-EXIT-CODE {exit_code}!\n'
|
113 |
+
f.write(error_report)
|
114 |
+
print(log, error_report)
|
115 |
+
|
116 |
+
sys.exit(0)
|
117 |
+
|
118 |
+
else: # we are parent!
|
119 |
+
current_job.add_pid(pid)
|
120 |
+
# Need to potentially re-add to list, as remove_completed_jobs() might trim it.
|
121 |
+
if current_job not in self.jobs: self.jobs.append(current_job)
|
122 |
+
|
123 |
+
process += 1
|
124 |
+
|
125 |
+
if block:
|
126 |
+
#for p in all_queued_jobs: os.waitpid(p, 0) # waiting for all child process to termintate...
|
127 |
+
|
128 |
+
self.wait_until_complete(current_job)
|
129 |
+
self.remove_completed_jobs()
|
130 |
+
|
131 |
+
cpu_usage += time_module.time()/60./60.
|
132 |
+
self.cpu_usage += cpu_usage * jobs_to_queue # approximation...
|
133 |
+
|
134 |
+
current_job = self.JobID()
|
135 |
+
|
136 |
+
return current_job
|
137 |
+
|
138 |
+
|
139 |
+
@property
|
140 |
+
def number_of_cpu_per_node(self): return self.cpu_count
|
141 |
+
|
142 |
+
|
143 |
+
@property
|
144 |
+
def maximum_number_of_mpi_cpu(self): return self.cpu_count
|
145 |
+
|
146 |
+
|
147 |
+
def submit_mpi_hpc_job(self, name, executable, arguments, working_dir, log_dir, memory=512, time=12, block=True, process_coefficient="1", requested_nodes=1, requested_processes_per_node=1):
|
148 |
+
|
149 |
+
if requested_nodes > 1:
|
150 |
+
print( "WARNING: " + str( requested_nodes ) + " nodes were requested, but we're running locally, so only 1 node will be used." )
|
151 |
+
|
152 |
+
if requested_processes_per_node > self.cpu_count:
|
153 |
+
print( "WARNING: " + str(requested_processes_per_node) + " processes were requested, but I only have " + str(self.cpu_count) + " CPUs. Will launch " + str(self.cpu_count) + " processes." )
|
154 |
+
actual_processes = min( requested_processes_per_node, self.cpu_count )
|
155 |
+
|
156 |
+
cpu_usage = -time_module.time()/60./60.
|
157 |
+
|
158 |
+
arguments = arguments.format(process=0)
|
159 |
+
|
160 |
+
command_line = f'cd {working_dir} && mpirun -np {actual_processes} {executable} {arguments}'
|
161 |
+
log = execute(f'Running job {name}...', command_line, tracer=self.tracer, return_='output')
|
162 |
+
with codecs.open(log_dir+'/.hpc.{name}.log'.format(**vars()), 'w', encoding='utf-8', errors='replace') as f: f.write(command_line+'\n'+log)
|
163 |
+
|
164 |
+
cpu_usage += time_module.time()/60./60.
|
165 |
+
self.cpu_usage += cpu_usage * actual_processes # approximation...
|
166 |
+
|
167 |
+
# return None - we do not return anything from this version of submit which imply returning None which in turn will be treated as job-id for already finished job
|
168 |
+
|
169 |
+
|
170 |
+
def complete(self, job_id):
|
171 |
+
''' Return job completion status. Return True if job completed and False otherwise
|
172 |
+
'''
|
173 |
+
self.remove_completed_jobs()
|
174 |
+
return job_id not in self.jobs
|
175 |
+
|
176 |
+
|
177 |
+
def cancel_job(self, job):
|
178 |
+
job.cancel();
|
179 |
+
if job in self.jobs:
|
180 |
+
self.jobs.remove(job)
|
181 |
+
|
182 |
+
|
183 |
+
def __repr__(self):
|
184 |
+
return 'MultiCore_HPC_Driver<>'
|
RFdiffusion/.rosetta-ci/hpc_drivers/slurm.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# :noTabs=true:
|
3 |
+
|
4 |
+
import os, sys, time, collections, math
|
5 |
+
import stat as stat_module
|
6 |
+
|
7 |
+
|
8 |
+
try:
|
9 |
+
from .base import *
|
10 |
+
|
11 |
+
except ImportError: # workaround for B2 back-end's
|
12 |
+
import imp
|
13 |
+
imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/base.py') # A bit of Python magic here, what we trying to say is this: from base import *, but path to base is calculated from our source location # from base import HPC_Driver, execute, NT
|
14 |
+
|
15 |
+
|
16 |
+
_T_slurm_array_job_template_ = '''\
|
17 |
+
#!/bin/bash
|
18 |
+
#
|
19 |
+
#SBATCH --job-name={name}
|
20 |
+
#SBATCH --output={log_dir}/.hpc.%x.%a.output
|
21 |
+
#
|
22 |
+
#SBATCH --time={time}:00
|
23 |
+
#SBATCH --mem-per-cpu={memory}M
|
24 |
+
#SBATCH --chdir={working_dir}
|
25 |
+
#
|
26 |
+
#SBATCH --array=1-{jobs_to_queue}
|
27 |
+
|
28 |
+
srun {executable} {arguments}
|
29 |
+
'''
|
30 |
+
|
31 |
+
_T_slurm_mpi_job_template_ = '''\
|
32 |
+
#!/bin/bash
|
33 |
+
#
|
34 |
+
#SBATCH --job-name={name}
|
35 |
+
#SBATCH --output={log_dir}/.hpc.%x.output
|
36 |
+
#
|
37 |
+
#SBATCH --time={time}:00
|
38 |
+
#SBATCH --mem-per-cpu={memory}M
|
39 |
+
#SBATCH --chdir={working_dir}
|
40 |
+
#
|
41 |
+
#SBATCH --ntasks={ntasks}
|
42 |
+
|
43 |
+
mpirun {executable} {arguments}
|
44 |
+
'''
|
45 |
+
|
46 |
+
class Slurm_HPC_Driver(HPC_Driver):
|
47 |
+
def head_node_execute(self, message, command_line, *args, **kwargs):
|
48 |
+
head_node = self.config['slurm'].get('head_node')
|
49 |
+
|
50 |
+
command_line, host = (f"ssh {head_node} cd `pwd` '&& {command_line}'", head_node) if head_node else (command_line, 'localhost')
|
51 |
+
return execute(f'Executiong on {host}: {message}' if message else '', command_line, *args, **kwargs)
|
52 |
+
|
53 |
+
|
54 |
+
# NodeGroup = collections.namedtuple('NodeGroup', 'nodes cores')
|
55 |
+
|
56 |
+
# @property
|
57 |
+
# def mpi_topology(self):
|
58 |
+
# ''' return list of NodeGroup's
|
59 |
+
# '''
|
60 |
+
# pass
|
61 |
+
|
62 |
+
|
63 |
+
# @property
|
64 |
+
# def number_of_cpu_per_node(self): return int( self.config['condor']['mpi_cpu_per_node'] )
|
65 |
+
|
66 |
+
# @property
|
67 |
+
# def maximum_number_of_mpi_cpu(self):
|
68 |
+
# return self.number_of_cpu_per_node * int( self.config['condor']['mpi_maximum_number_of_nodes'] )
|
69 |
+
|
70 |
+
|
71 |
+
# def complete(self, condor_job_id):
|
72 |
+
# ''' Return job completion status. Note that single hpc_job may contatin inner list of individual HPC jobs, True should be return if they all run in to completion.
|
73 |
+
# '''
|
74 |
+
|
75 |
+
# execute('Releasing condor jobs...', 'condor_release $USER', return_='tuple')
|
76 |
+
|
77 |
+
# s = execute('', 'condor_q $USER | grep $USER | grep {}'.format(condor_job_id), return_='output', terminate_on_failure=False).replace(' ', '').replace('\n', '')
|
78 |
+
# if s: return False
|
79 |
+
|
80 |
+
# # #setDaemonStatusAndPing('[Job #%s] Running... %s condor job(s) in queue...' % (self.id, len(s.split('\n') ) ) )
|
81 |
+
# # n_jobs = len(s.split('\n'))
|
82 |
+
# # s, o = execute('', 'condor_userprio -all | grep $USER@', return_='tuple')
|
83 |
+
# # if s == 0:
|
84 |
+
# # jobs_running = o.split()
|
85 |
+
# # jobs_running = 'XX' if len(jobs_running) < 4 else jobs_running[4]
|
86 |
+
# # self.set_daemon_message("Waiting for condor to finish HPC jobs... [{} jobs in HPC-Queue, {} CPU's used]".format(n_jobs, jobs_running) )
|
87 |
+
# # print "{} condor jobs in queue... Sleeping 32s... \r".format(n_jobs),
|
88 |
+
# # sys.stdout.flush()
|
89 |
+
# # time.sleep(32)
|
90 |
+
# else:
|
91 |
+
|
92 |
+
# #self.tracer('Waiting for condor to finish the jobs... DONE')
|
93 |
+
# self.jobs.remove(condor_job_id)
|
94 |
+
# self.cpu_usage += self.get_condor_accumulated_usage()
|
95 |
+
# return True # jobs already finished, we return empty list to prevent double counting of cpu_usage
|
96 |
+
|
97 |
+
|
98 |
+
def complete(self, slurm_job_id):
|
99 |
+
''' Return True if job with given id is complete
|
100 |
+
'''
|
101 |
+
|
102 |
+
s = self.head_node_execute('', f'squeue -j {slurm_job_id} --noheader', return_='output', terminate_on_failure=False, silent=True)
|
103 |
+
if s: return False
|
104 |
+
else:
|
105 |
+
#self.tracer('Waiting for condor to finish the jobs... DONE')
|
106 |
+
self.jobs.remove(slurm_job_id)
|
107 |
+
return True # jobs already finished, we return empty list to prevent double counting of cpu_usage
|
108 |
+
|
109 |
+
|
110 |
+
def cancel_job(self, slurm_job_id):
|
111 |
+
self.head_node_execute(f'Slurm_HPC_Driver.canceling job {slurm_job_id}...', f'scancel {slurm_job_id}', terminate_on_failure=False)
|
112 |
+
|
113 |
+
|
114 |
+
# def submit_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
115 |
+
# print('submit_hpc_job is DEPRECATED and will be removed in near future, please use submit_serial_hpc_job instead!')
|
116 |
+
# return self.submit_serial_hpc_job(name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory, time, block, shell_wrapper)
|
117 |
+
|
118 |
+
|
119 |
+
def submit_serial_hpc_job(self, name, executable, arguments, working_dir, jobs_to_queue, log_dir, memory=512, time=12, block=True, shell_wrapper=False):
|
120 |
+
|
121 |
+
arguments = arguments.format(process='%a') # %a is SLURM array index
|
122 |
+
time = int( math.ceil(time*60) )
|
123 |
+
|
124 |
+
if shell_wrapper:
|
125 |
+
shell_wrapper_sh = os.path.abspath(self.working_dir + f'/hpc.{name}.shell_wrapper.sh')
|
126 |
+
with open(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
|
127 |
+
executable, arguments = shell_wrapper_sh, ''
|
128 |
+
|
129 |
+
slurm_file = working_dir + f'/.hpc.{name}.slurm'
|
130 |
+
|
131 |
+
with open(slurm_file, 'w') as f: f.write( _T_slurm_array_job_template_.format( **vars() ) )
|
132 |
+
|
133 |
+
|
134 |
+
slurm_job_id = self.head_node_execute('Submitting SLURM array job...', f'cd {self.working_dir} && sbatch {slurm_file}',
|
135 |
+
tracer=self.tracer, return_='output'
|
136 |
+
).split()[-1] # expecting something like `Submitted batch job 6122` in output
|
137 |
+
|
138 |
+
|
139 |
+
self.jobs.append(slurm_job_id)
|
140 |
+
|
141 |
+
if block:
|
142 |
+
self.wait_until_complete( [slurm_job_id] )
|
143 |
+
return None
|
144 |
+
|
145 |
+
else: return slurm_job_id
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
def submit_mpi_hpc_job(self, name, executable, arguments, working_dir, log_dir, ntasks, memory=512, time=12, block=True, shell_wrapper=False):
|
152 |
+
''' submit jobs as MPI job
|
153 |
+
'''
|
154 |
+
arguments = arguments.format(process='0')
|
155 |
+
time = int( math.ceil(time*60) )
|
156 |
+
|
157 |
+
if shell_wrapper:
|
158 |
+
shell_wrapper_sh = os.path.abspath(self.working_dir + f'/hpc.{name}.shell_wrapper.sh')
|
159 |
+
with open(shell_wrapper_sh, 'w') as f: f.write('#!/bin/bash\n{} {}\n'.format(executable, arguments)); os.fchmod(f.fileno(), stat.S_IEXEC | stat.S_IREAD | stat.S_IWRITE)
|
160 |
+
executable, arguments = shell_wrapper_sh, ''
|
161 |
+
|
162 |
+
slurm_file = working_dir + f'/.hpc.{name}.slurm'
|
163 |
+
|
164 |
+
with open(slurm_file, 'w') as f: f.write( _T_slurm_mpi_job_template_.format( **vars() ) )
|
165 |
+
|
166 |
+
slurm_job_id = self.head_node_execute('Submitting SLURM mpi job...', f'cd {self.working_dir} && sbatch {slurm_file}',
|
167 |
+
tracer=self.tracer, return_='output'
|
168 |
+
).split()[-1] # expecting something like `Submitted batch job 6122` in output
|
169 |
+
|
170 |
+
self.jobs.append(slurm_job_id)
|
171 |
+
|
172 |
+
if block:
|
173 |
+
self.wait_until_complete( [slurm_job_id] )
|
174 |
+
return None
|
175 |
+
|
176 |
+
else: return slurm_job_id
|
RFdiffusion/.rosetta-ci/test-sets.yaml
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# map platform-string → platform definiton
|
2 |
+
platforms:
|
3 |
+
ubuntu-20.04.gcc:
|
4 |
+
os: ubuntu-20.04
|
5 |
+
compiler: gcc
|
6 |
+
python: '3.9'
|
7 |
+
|
8 |
+
ubuntu-20.04.clang:
|
9 |
+
os: ubuntu-20.04
|
10 |
+
compiler: clang
|
11 |
+
python: '3.9'
|
12 |
+
|
13 |
+
|
14 |
+
# map of test-set-name → tests
|
15 |
+
test-sets:
|
16 |
+
main:
|
17 |
+
- ubuntu-20.04.clang.rfd
|
18 |
+
|
19 |
+
python:
|
20 |
+
- ubuntu-20.04.gcc.self.python
|
21 |
+
- ubuntu-20.04.clang.self.python
|
22 |
+
|
23 |
+
self:
|
24 |
+
- ubuntu-20.04.gcc.self.state
|
25 |
+
- ubuntu-20.04.gcc.self.subtests
|
26 |
+
- ubuntu-20.04.gcc.self.release
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
# map of GitHub-label → [test-set]
|
31 |
+
github-label-test-sets:
|
32 |
+
00 main: [main]
|
33 |
+
10 self: [self]
|
34 |
+
16 python: [python]
|
35 |
+
|
36 |
+
|
37 |
+
# map of submit-page-category → tests
|
38 |
+
# tests that does not get assigned will be automatically displayed in 'other' category
|
39 |
+
category-tests:
|
40 |
+
main:
|
41 |
+
- rfd
|
42 |
+
|
43 |
+
self:
|
44 |
+
- self.state
|
45 |
+
- self.subtests
|
46 |
+
- self.release
|
47 |
+
- self.python
|
48 |
+
|
49 |
+
|
50 |
+
# map branch → test-set to
|
51 |
+
# specify list of tests that should be applied by-default during testing of each new commits to specific branch
|
52 |
+
branch-test-sets:
|
53 |
+
main: [main]
|
54 |
+
benchmark: [main, python]
|
55 |
+
|
56 |
+
|
57 |
+
# map branch → test-sets for pull-request's
|
58 |
+
# specify which test-sets should be scheduled for PR's by-default (ie in addition to GH labels applied)
|
59 |
+
# use empty branch name to specify defult value for (ie any branch not explicitly listed)
|
60 |
+
pull-request-branch-test-sets:
|
61 |
+
# specific test sets for benchmark branch
|
62 |
+
benchmark: ['main', 'python']
|
63 |
+
|
64 |
+
# default, will apply to PR's to any other branch
|
65 |
+
'': ['main']
|
RFdiffusion/.rosetta-ci/tests/__init__.py
ADDED
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# :noTabs=true:
|
4 |
+
|
5 |
+
# (c) Copyright Rosetta Commons Member Institutions.
|
6 |
+
# (c) This file is part of the Rosetta software suite and is made available under license.
|
7 |
+
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
|
8 |
+
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
|
9 |
+
# (c) addressed to University of Washington CoMotion, email: license@uw.edu.
|
10 |
+
|
11 |
+
## @file tests/__init__.py
|
12 |
+
## @brief Common constats and types for all test types
|
13 |
+
## @author Sergey Lyskov
|
14 |
+
|
15 |
+
import os, time, sys, shutil, codecs, urllib.request, imp, subprocess, json, hashlib # urllib.error, urllib.parse,
|
16 |
+
import platform as platform_module
|
17 |
+
import types as types_module
|
18 |
+
|
19 |
+
# ⚔ do not change wording below, it have to stay in sync with upstream (up to benchmark-model).
|
20 |
+
# Copied from benchmark-model, standard state code's for tests results.
|
21 |
+
|
22 |
+
__all__ = ['execute',
|
23 |
+
'_S_Values_', '_S_draft_', '_S_queued_', '_S_running_', '_S_passed_', '_S_failed_', '_S_build_failed_', '_S_script_failed_',
|
24 |
+
'_StateKey_', '_ResultsKey_', '_LogKey_', '_DescriptionKey_', '_TestsKey_',
|
25 |
+
'_multi_step_config_', '_multi_step_error_', '_multi_step_result_',
|
26 |
+
'to_bytes',
|
27 |
+
]
|
28 |
+
|
29 |
+
_S_draft_ = 'draft'
|
30 |
+
_S_queued_ = 'queued'
|
31 |
+
_S_running_ = 'running'
|
32 |
+
_S_passed_ = 'passed'
|
33 |
+
_S_failed_ = 'failed'
|
34 |
+
_S_build_failed_ = 'build failed'
|
35 |
+
_S_script_failed_ = 'script failed'
|
36 |
+
_S_queued_for_comparison_ = 'queued for comparison'
|
37 |
+
|
38 |
+
_S_Values_ = [_S_draft_, _S_queued_, _S_running_, _S_passed_, _S_failed_, _S_build_failed_, _S_script_failed_, _S_queued_for_comparison_]
|
39 |
+
|
40 |
+
_IgnoreKey_ = 'ignore'
|
41 |
+
_StateKey_ = 'state'
|
42 |
+
_ResultsKey_ = 'results'
|
43 |
+
_LogKey_ = 'log'
|
44 |
+
_DescriptionKey_ = 'description'
|
45 |
+
_TestsKey_ = 'tests'
|
46 |
+
_SummaryKey_ = 'summary'
|
47 |
+
_FailedKey_ = 'failed'
|
48 |
+
_TotalKey_ = 'total'
|
49 |
+
_PlotsKey_ = 'plots'
|
50 |
+
_FailedTestsKey_ = 'failed_tests'
|
51 |
+
_HtmlKey_ = 'html'
|
52 |
+
|
53 |
+
# file names for multi-step test files
|
54 |
+
_multi_step_config_ = 'config.json'
|
55 |
+
_multi_step_error_ = 'error.json'
|
56 |
+
_multi_step_result_ = 'result.json'
|
57 |
+
|
58 |
+
PyRosetta_unix_memory_requirement_per_cpu = 6 # Memory per sub-process in Gb's
|
59 |
+
PyRosetta_unix_unit_test_memory_requirement_per_cpu = 3.0 # Memory per sub-process in Gb's for running PyRosetta unit tests
|
60 |
+
|
61 |
+
# Commands to run all the scripts needed for setting up Rosetta compiles. (Run from main/source directory)
|
62 |
+
PRE_COMPILE_SETUP_SCRIPTS = [ "./update_options.sh", "./update_submodules.sh", "./update_ResidueType_enum_files.sh", "python version.py" ]
|
63 |
+
|
64 |
+
DEFAULT_PYTHON_VERSION='3.9'
|
65 |
+
|
66 |
+
# Standard funtions and classes below ---------------------------------------------------------------------------------
|
67 |
+
|
68 |
+
class BenchmarkError(Exception):
|
69 |
+
def __init__(self, value): self.value = value
|
70 |
+
def __repr__(self): return self.value
|
71 |
+
def __str__(self): return self.value
|
72 |
+
|
73 |
+
|
74 |
+
class NT: # named tuple
|
75 |
+
def __init__(self, **entries): self.__dict__.update(entries)
|
76 |
+
def __repr__(self):
|
77 |
+
r = 'NT: |'
|
78 |
+
for i in dir(self):
|
79 |
+
print(i)
|
80 |
+
if not i.startswith('__') and i != '_as_dict' and not isinstance(getattr(self, i), types_module.MethodType): r += '%s --> %s, ' % (i, getattr(self, i))
|
81 |
+
return r[:-2]+'|'
|
82 |
+
|
83 |
+
@property
|
84 |
+
def _as_dict(self):
|
85 |
+
return { a: getattr(self, a) for a in dir(self) if not a.startswith('__') and a != '_as_dict' and not isinstance(getattr(self, a), types_module.MethodType)}
|
86 |
+
|
87 |
+
|
88 |
+
def Tracer(verbose=False):
|
89 |
+
return print if verbose else lambda x: None
|
90 |
+
|
91 |
+
|
92 |
+
def to_unicode(b):
|
93 |
+
''' Conver bytes to string and handle the errors. If argument is already in string - do nothing
|
94 |
+
'''
|
95 |
+
#return b if type(b) == unicode else unicode(b, 'utf-8', errors='replace')
|
96 |
+
return b if type(b) == str else str(b, 'utf-8', errors='backslashreplace')
|
97 |
+
|
98 |
+
|
99 |
+
def to_bytes(u):
|
100 |
+
''' Conver string to bytes and handle the errors. If argument is already of type bytes - do nothing
|
101 |
+
'''
|
102 |
+
return u if type(u) == bytes else u.encode('utf-8', errors='backslashreplace')
|
103 |
+
|
104 |
+
|
105 |
+
''' Python-2 version
|
106 |
+
def execute(message, commandline, return_=False, until_successes=False, terminate_on_failure=True, add_message_and_command_line_to_output=False):
|
107 |
+
message, commandline = to_unicode(message), to_unicode(commandline)
|
108 |
+
|
109 |
+
TR = Tracer()
|
110 |
+
TR(message); TR(commandline)
|
111 |
+
while True:
|
112 |
+
(res, output) = commands.getstatusoutput(commandline)
|
113 |
+
# Subprocess results will always be a bytes-string.
|
114 |
+
# Probably ASCII, but may have some Unicode characters.
|
115 |
+
# A UTF-8 decode will probably get decent results 99% of the time
|
116 |
+
# and the replace option will gracefully handle the rest.
|
117 |
+
output = to_unicode(output)
|
118 |
+
|
119 |
+
TR(output)
|
120 |
+
|
121 |
+
if res and until_successes: pass # Thats right - redability COUNT!
|
122 |
+
else: break
|
123 |
+
|
124 |
+
print( "Error while executing %s: %s\n" % (message, output) )
|
125 |
+
print( "Sleeping 60s... then I will retry..." )
|
126 |
+
time.sleep(60)
|
127 |
+
|
128 |
+
if add_message_and_command_line_to_output: output = message + '\nCommand line: ' + commandline + '\n' + output
|
129 |
+
|
130 |
+
if return_ == 'tuple': return(res, output)
|
131 |
+
|
132 |
+
if res and terminate_on_failure:
|
133 |
+
TR("\nEncounter error while executing: " + commandline)
|
134 |
+
if return_==True: return res
|
135 |
+
else:
|
136 |
+
print("\nEncounter error while executing: " + commandline + '\n' + output)
|
137 |
+
raise BenchmarkError("\nEncounter error while executing: " + commandline + '\n' + output)
|
138 |
+
|
139 |
+
if return_ == 'output': return output
|
140 |
+
else: return res
|
141 |
+
'''
|
142 |
+
|
143 |
+
def execute_through_subprocess(command_line):
|
144 |
+
# exit_code, output = subprocess.getstatusoutput(command_line)
|
145 |
+
|
146 |
+
# p = subprocess.Popen(command_line, bufsize=0, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
147 |
+
# output, errors = p.communicate()
|
148 |
+
# output = (output + errors).decode(encoding='utf-8', errors='backslashreplace')
|
149 |
+
# exit_code = p.returncode
|
150 |
+
|
151 |
+
# previous 'main' version based on subprocess module. Main issue that output of segfaults will not be captured since they generated by shell
|
152 |
+
p = subprocess.Popen(command_line, bufsize=0, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
153 |
+
output, errors = p.communicate()
|
154 |
+
# output = output + errors # ← we redirected stderr into same pipe as stdcout so errors is None, - no need to concatenate
|
155 |
+
output = output.decode(encoding='utf-8', errors='backslashreplace')
|
156 |
+
exit_code = p.returncode
|
157 |
+
|
158 |
+
return exit_code, output
|
159 |
+
|
160 |
+
|
161 |
+
def execute_through_pexpect(command_line):
|
162 |
+
import pexpect
|
163 |
+
|
164 |
+
child = pexpect.spawn('/bin/bash', ['-c', command_line])
|
165 |
+
child.expect(pexpect.EOF)
|
166 |
+
output = child.before.decode(encoding='utf-8', errors='backslashreplace')
|
167 |
+
child.close()
|
168 |
+
exit_code = child.signalstatus or child.exitstatus
|
169 |
+
|
170 |
+
return exit_code, output
|
171 |
+
|
172 |
+
|
173 |
+
def execute_through_pty(command_line):
|
174 |
+
import pty, select
|
175 |
+
|
176 |
+
if sys.platform == "darwin":
|
177 |
+
|
178 |
+
master, slave = pty.openpty()
|
179 |
+
p = subprocess.Popen(command_line, shell=True, stdout=slave, stdin=slave,
|
180 |
+
stderr=subprocess.STDOUT, close_fds=True)
|
181 |
+
|
182 |
+
buffer = []
|
183 |
+
while True:
|
184 |
+
try:
|
185 |
+
if select.select([master], [], [], 0.2)[0]: # has something to read
|
186 |
+
data = os.read(master, 1 << 22)
|
187 |
+
if data: buffer.append(data)
|
188 |
+
|
189 |
+
elif (p.poll() is not None) and (not select.select([master], [], [], 0.2)[0] ): break # process is finished and output buffer if fully read
|
190 |
+
|
191 |
+
except OSError: break # OSError will be raised when child process close PTY descriptior
|
192 |
+
|
193 |
+
output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
|
194 |
+
|
195 |
+
os.close(master)
|
196 |
+
os.close(slave)
|
197 |
+
|
198 |
+
p.wait()
|
199 |
+
exit_code = p.returncode
|
200 |
+
|
201 |
+
'''
|
202 |
+
buffer = []
|
203 |
+
while True:
|
204 |
+
if select.select([master], [], [], 0.2)[0]: # has something to read
|
205 |
+
data = os.read(master, 1 << 22)
|
206 |
+
if data: buffer.append(data)
|
207 |
+
# else: break # # EOF - well, technically process _should_ be finished here...
|
208 |
+
|
209 |
+
# elif time.sleep(1) or (p.poll() is not None): # process is finished (sleep here is intentional to trigger race condition, see solution for this on the next few lines)
|
210 |
+
# assert not select.select([master], [], [], 0.2)[0] # should be nothing left to read...
|
211 |
+
# break
|
212 |
+
|
213 |
+
elif (p.poll() is not None) and (not select.select([master], [], [], 0.2)[0] ): break # process is finished and output buffer if fully read
|
214 |
+
|
215 |
+
assert not select.select([master], [], [], 0.2)[0] # should be nothing left to read...
|
216 |
+
|
217 |
+
os.close(slave)
|
218 |
+
os.close(master)
|
219 |
+
|
220 |
+
output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
|
221 |
+
exit_code = p.returncode
|
222 |
+
'''
|
223 |
+
|
224 |
+
else:
|
225 |
+
|
226 |
+
master, slave = pty.openpty()
|
227 |
+
p = subprocess.Popen(command_line, shell=True, stdout=slave, stdin=slave,
|
228 |
+
stderr=subprocess.STDOUT, close_fds=True)
|
229 |
+
|
230 |
+
os.close(slave)
|
231 |
+
|
232 |
+
buffer = []
|
233 |
+
while True:
|
234 |
+
try:
|
235 |
+
data = os.read(master, 1 << 22)
|
236 |
+
if data: buffer.append(data)
|
237 |
+
except OSError: break # OSError will be raised when child process close PTY descriptior
|
238 |
+
|
239 |
+
output = b''.join(buffer).decode(encoding='utf-8', errors='backslashreplace')
|
240 |
+
|
241 |
+
os.close(master)
|
242 |
+
|
243 |
+
p.wait()
|
244 |
+
exit_code = p.returncode
|
245 |
+
|
246 |
+
return exit_code, output
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
def execute(message, command_line, return_='status', until_successes=False, terminate_on_failure=True, silent=False, silence_output=False, silence_output_on_errors=False, add_message_and_command_line_to_output=False):
|
251 |
+
if not silent: print(message); print(command_line); sys.stdout.flush();
|
252 |
+
while True:
|
253 |
+
|
254 |
+
#exit_code, output = execute_through_subprocess(command_line)
|
255 |
+
#exit_code, output = execute_through_pexpect(command_line)
|
256 |
+
exit_code, output = execute_through_pty(command_line)
|
257 |
+
|
258 |
+
if (exit_code and not silence_output_on_errors) or not (silent or silence_output): print(output); sys.stdout.flush();
|
259 |
+
|
260 |
+
if exit_code and until_successes: pass # Thats right - redability COUNT!
|
261 |
+
else: break
|
262 |
+
|
263 |
+
print( "Error while executing {}: {}\n".format(message, output) )
|
264 |
+
print("Sleeping 60s... then I will retry...")
|
265 |
+
sys.stdout.flush();
|
266 |
+
time.sleep(60)
|
267 |
+
|
268 |
+
if add_message_and_command_line_to_output: output = message + '\nCommand line: ' + command_line + '\n' + output
|
269 |
+
|
270 |
+
if return_ == 'tuple' or return_ == tuple: return(exit_code, output)
|
271 |
+
|
272 |
+
if exit_code and terminate_on_failure:
|
273 |
+
print("\nEncounter error while executing: " + command_line)
|
274 |
+
if return_==True: return True
|
275 |
+
else:
|
276 |
+
print('\nEncounter error while executing: ' + command_line + '\n' + output);
|
277 |
+
raise BenchmarkError('\nEncounter error while executing: ' + command_line + '\n' + output)
|
278 |
+
|
279 |
+
if return_ == 'output': return output
|
280 |
+
else: return exit_code
|
281 |
+
|
282 |
+
|
283 |
+
def parallel_execute(name, jobs, rosetta_dir, working_dir, cpu_count, time=16):
|
284 |
+
''' Execute command line in parallel on local host
|
285 |
+
time specifies the upper limit for cpu-usage runtime (in minutes) for any one process in the parallel execution.
|
286 |
+
|
287 |
+
jobs should be dict with following structure:
|
288 |
+
{
|
289 |
+
'job-string-id-1’: command_line-1,
|
290 |
+
'job-string-id-2’: command_line-2,
|
291 |
+
...
|
292 |
+
}
|
293 |
+
|
294 |
+
return: dict with jobs-id's as keys and value as dict with 'output' and 'result' keys:
|
295 |
+
{
|
296 |
+
"job-string-id-1": {
|
297 |
+
"output": "stdout + stdderr output of command_line-1",
|
298 |
+
"result": <integer exit code for command_line-1>
|
299 |
+
},
|
300 |
+
"c2": {
|
301 |
+
"output": "stdout + stdderr output of command_line-2",
|
302 |
+
"result": <integer exit code for command_line-2>
|
303 |
+
},
|
304 |
+
...
|
305 |
+
}
|
306 |
+
'''
|
307 |
+
job_file_name = working_dir + '/' + name
|
308 |
+
with open(job_file_name + '.json', 'w') as f: json.dump(jobs, f, sort_keys=True, indent=2) # JSON handles unicode internally
|
309 |
+
if time is not None:
|
310 |
+
allowed_time = int(time*60)
|
311 |
+
ulimit_command = f'ulimit -t {allowed_time} && '
|
312 |
+
else:
|
313 |
+
ulimit_command = ''
|
314 |
+
command = f'cd {working_dir} && ' + ulimit_command + f'{rosetta_dir}/tests/benchmark/util/parallel.py -j{cpu_count} {job_file_name}.json'
|
315 |
+
execute("Running {} in parallel with {} CPU's...".format(name, cpu_count), command )
|
316 |
+
|
317 |
+
with open(job_file_name+'.results.json') as f: return json.load(f)
|
318 |
+
|
319 |
+
|
320 |
+
def calculate_unique_prefix_path(platform, config):
|
321 |
+
''' calculate path for prefix location that is unique for this machine and OS
|
322 |
+
'''
|
323 |
+
hostname = os.uname()[1]
|
324 |
+
return config['prefix'] + '/' + hostname + '/' + platform['os']
|
325 |
+
|
326 |
+
|
327 |
+
def get_python_include_and_lib(python):
|
328 |
+
''' calculate python include dir and lib dir from given python executable path
|
329 |
+
'''
|
330 |
+
#python = os.path.realpath(python)
|
331 |
+
python_bin_dir = python.rpartition('/')[0]
|
332 |
+
python_config = f'{python} {python}-config' if python.endswith('2.7') else f'{python}-config'
|
333 |
+
|
334 |
+
#if not os.path.isfile(python_config): python_config = python_bin_dir + '/python-config'
|
335 |
+
|
336 |
+
info = execute('Getting python configuration info...', f'unset __PYVENV_LAUNCHER__ && cd {python_bin_dir} && PATH=.:$PATH && {python_config} --prefix --includes', return_='output').replace('\r', '').split('\n') # Python-3 only: --abiflags
|
337 |
+
python_prefix = info[0]
|
338 |
+
python_include_dir = info[1].split()[0][len('-I'):]
|
339 |
+
python_lib_dir = python_prefix + '/lib'
|
340 |
+
#python_abi_suffix = info[2]
|
341 |
+
#print(python_include_dir, python_lib_dir)
|
342 |
+
|
343 |
+
return NT(python_include_dir=python_include_dir, python_lib_dir=python_lib_dir)
|
344 |
+
|
345 |
+
|
346 |
+
def local_open_ssl_install(prefix, build_prefix, jobs):
|
347 |
+
''' install OpenSSL at given prefix, return url of source archive
|
348 |
+
'''
|
349 |
+
#with tempfile.TemporaryDirectory('open_ssl_build', dir=prefix) as build_prefix:
|
350 |
+
|
351 |
+
url = 'https://www.openssl.org/source/openssl-1.1.1b.tar.gz'
|
352 |
+
#url = 'https://www.openssl.org/source/openssl-3.0.0.tar.gz'
|
353 |
+
|
354 |
+
|
355 |
+
archive = build_prefix + '/' + url.split('/')[-1]
|
356 |
+
build_dir = archive.rpartition('.tar.gz')[0]
|
357 |
+
if os.path.isdir(build_dir): shutil.rmtree(build_dir)
|
358 |
+
|
359 |
+
with open(archive, 'wb') as f:
|
360 |
+
response = urllib.request.urlopen(url)
|
361 |
+
f.write( response.read() )
|
362 |
+
|
363 |
+
execute('Unpacking {}'.format(archive), 'cd {build_prefix} && tar -xvzf {archive}'.format(**vars()) )
|
364 |
+
|
365 |
+
execute('Configuring...', f'cd {build_dir} && ./config --prefix={prefix}')
|
366 |
+
execute('Building...', f'cd {build_dir} && make -j{jobs}')
|
367 |
+
execute('Installing...', f'cd {build_dir} && make -j{jobs} install')
|
368 |
+
|
369 |
+
return url
|
370 |
+
|
371 |
+
|
372 |
+
def remove_pip_and_easy_install(prefix_root_path):
|
373 |
+
''' remove `pip` and `easy_install` executable from given Python / virtual-environments install
|
374 |
+
'''
|
375 |
+
for f in os.listdir(prefix_root_path + '/bin'): # removing all pip's and easy_install's to make sure that environment is immutable
|
376 |
+
for p in ['pip', 'easy_install']:
|
377 |
+
if f.startswith(p): os.remove(prefix_root_path + '/bin/' + f)
|
378 |
+
|
379 |
+
|
380 |
+
|
381 |
+
def local_python_install(platform, config):
|
382 |
+
''' Perform local install of given Python version and return path-to-python-interpreter, python_include_dir, python_lib_dir
|
383 |
+
If previous install is detected skip installiation.
|
384 |
+
Provided Python install will _persistent_ and _immutable_
|
385 |
+
'''
|
386 |
+
jobs = config['cpu_count']
|
387 |
+
compiler, cpp_compiler = ('clang', 'clang++') if platform['os'] == 'mac' else ('gcc', 'g++') # disregarding platform compiler setting and instead use default compiler for platform
|
388 |
+
|
389 |
+
python_version = platform.get('python', DEFAULT_PYTHON_VERSION)
|
390 |
+
|
391 |
+
if python_version.endswith('.s'):
|
392 |
+
assert python_version == f'{sys.version_info.major}.{sys.version_info.minor}.s'
|
393 |
+
#root = executable.rpartition('/bin/python')[0]
|
394 |
+
h = hashlib.md5(); h.update( (sys.executable + sys.version).encode('utf-8', errors='backslashreplace') ); hash = h.hexdigest()
|
395 |
+
return NT(
|
396 |
+
python = sys.executable,
|
397 |
+
root = None,
|
398 |
+
python_include_dir = None,
|
399 |
+
python_lib_dir = None,
|
400 |
+
version = python_version,
|
401 |
+
url = None,
|
402 |
+
platform = platform,
|
403 |
+
config = config,
|
404 |
+
hash = hash,
|
405 |
+
)
|
406 |
+
|
407 |
+
# deprecated, no longer needed
|
408 |
+
# python_version = {'python2' : '2.7',
|
409 |
+
# 'python2.7' : '2.7',
|
410 |
+
# 'python3' : '3.5',
|
411 |
+
# }.get(python_version, python_version)
|
412 |
+
|
413 |
+
# for security reasons we only allow installs for version listed here with hand-coded URL's
|
414 |
+
python_sources = {
|
415 |
+
'2.7' : 'https://www.python.org/ftp/python/2.7.18/Python-2.7.18.tgz',
|
416 |
+
|
417 |
+
'3.5' : 'https://www.python.org/ftp/python/3.5.9/Python-3.5.9.tgz',
|
418 |
+
'3.6' : 'https://www.python.org/ftp/python/3.6.15/Python-3.6.15.tgz',
|
419 |
+
'3.7' : 'https://www.python.org/ftp/python/3.7.14/Python-3.7.14.tgz',
|
420 |
+
'3.8' : 'https://www.python.org/ftp/python/3.8.14/Python-3.8.14.tgz',
|
421 |
+
'3.9' : 'https://www.python.org/ftp/python/3.9.14/Python-3.9.14.tgz',
|
422 |
+
'3.10' : 'https://www.python.org/ftp/python/3.10.10/Python-3.10.10.tgz',
|
423 |
+
'3.11' : 'https://www.python.org/ftp/python/3.11.2/Python-3.11.2.tgz',
|
424 |
+
}
|
425 |
+
|
426 |
+
# map of env -> ('shell-code-before ./configure', 'extra-arguments-for-configure')
|
427 |
+
extras = {
|
428 |
+
#('mac',) : ('__PYVENV_LAUNCHER__="" MACOSX_DEPLOYMENT_TARGET={}'.format(platform_module.mac_ver()[0]), ''), # __PYVENV_LAUNCHER__ now used by-default for all platform installs
|
429 |
+
('mac',) : ('MACOSX_DEPLOYMENT_TARGET={}'.format(platform_module.mac_ver()[0]), ''),
|
430 |
+
('linux', '2.7') : ('', '--enable-unicode=ucs4'),
|
431 |
+
('ubuntu', '2.7') : ('', '--enable-unicode=ucs4'),
|
432 |
+
}
|
433 |
+
|
434 |
+
#packages = '' if (python_version[0] == '2' or python_version == '3.5' ) and platform['os'] == 'mac' else 'pip setuptools wheel' # 2.7 is now deprecated on Mac so some packages could not be installed
|
435 |
+
packages = 'setuptools'
|
436 |
+
|
437 |
+
url = python_sources[python_version]
|
438 |
+
|
439 |
+
extra = extras.get( (platform['os'],) , ('', '') )
|
440 |
+
extra = extras.get( (platform['os'], python_version) , extra)
|
441 |
+
|
442 |
+
extra = ('unset __PYVENV_LAUNCHER__ && ' + extra[0], extra[1])
|
443 |
+
|
444 |
+
options = '--with-ensurepip' #'--without-ensurepip'
|
445 |
+
signature = f'v1.5.1 url: {url}\noptions: {options}\ncompiler: {compiler}\nextra: {extra}\npackages: {packages}\n'
|
446 |
+
|
447 |
+
h = hashlib.md5(); h.update( signature.encode('utf-8', errors='backslashreplace') ); hash = h.hexdigest()
|
448 |
+
|
449 |
+
root = calculate_unique_prefix_path(platform, config) + '/python-' + python_version + '.' + compiler + '/' + hash
|
450 |
+
|
451 |
+
signature_file_name = root + '/.signature'
|
452 |
+
|
453 |
+
#activate = root + '/bin/activate'
|
454 |
+
executable = root + '/bin/python' + python_version
|
455 |
+
|
456 |
+
# if os.path.isfile(executable) and (not execute('Getting python configuration info...', '{executable}-config --prefix --includes'.format(**vars()), terminate_on_failure=False) ):
|
457 |
+
# print('found executable!')
|
458 |
+
# _, executable_version = execute('Checking Python interpreter version...', '{executable} --version'.format(**vars()), return_='tuple')
|
459 |
+
# executable_version = executable_version.split()[-1]
|
460 |
+
# else: executable_version = ''
|
461 |
+
# print('executable_version: {}'.format(executable_version))
|
462 |
+
#if executable_version != url.rpartition('Python-')[2][:-len('.tgz')]:
|
463 |
+
|
464 |
+
if os.path.isfile(signature_file_name) and open(signature_file_name).read() == signature:
|
465 |
+
#print('Install for Python-{} is detected, skipping installation procedure...'.format(python_version))
|
466 |
+
pass
|
467 |
+
|
468 |
+
else:
|
469 |
+
print( 'Installing Python-{python_version}, using {url} with extra:{extra}...'.format( **vars() ) )
|
470 |
+
|
471 |
+
if os.path.isdir(root): shutil.rmtree(root)
|
472 |
+
|
473 |
+
build_prefix = os.path.abspath(root + '/../build-python-{}'.format(python_version) )
|
474 |
+
|
475 |
+
if not os.path.isdir(root): os.makedirs(root)
|
476 |
+
if not os.path.isdir(build_prefix): os.makedirs(build_prefix)
|
477 |
+
|
478 |
+
platform_is_mac = True if platform['os'] in ['mac', 'm1'] else False
|
479 |
+
platform_is_linux = not platform_is_mac
|
480 |
+
|
481 |
+
#if False and platform['os'] == 'mac' and platform_module.machine() == 'arm64' and tuple( map(int, python_version.split('.') ) ) >= (3, 9):
|
482 |
+
if ( platform['os'] == 'mac' and python_version == '3.6' ) \
|
483 |
+
or ( platform_is_linux and python_version in ['3.10', '3.11'] ):
|
484 |
+
open_ssl_url = local_open_ssl_install(root, build_prefix, jobs)
|
485 |
+
options += f' --with-openssl={root} --with-openssl-rpath=auto'
|
486 |
+
#signature += 'OpenSSL install: ' + open_ssl_url + '\n'
|
487 |
+
|
488 |
+
archive = build_prefix + '/' + url.split('/')[-1]
|
489 |
+
build_dir = archive.rpartition('.tgz')[0]
|
490 |
+
if os.path.isdir(build_dir): shutil.rmtree(build_dir)
|
491 |
+
|
492 |
+
with open(archive, 'wb') as f:
|
493 |
+
#response = urllib2.urlopen(url)
|
494 |
+
response = urllib.request.urlopen(url)
|
495 |
+
f.write( response.read() )
|
496 |
+
|
497 |
+
#execute('Execution environment:', 'env'.format(**vars()) )
|
498 |
+
|
499 |
+
execute('Unpacking {}'.format(archive), 'cd {build_prefix} && tar -xvzf {archive}'.format(**vars()) )
|
500 |
+
|
501 |
+
#execute('Building and installing...', 'cd {} && CC={compiler} CXX={cpp_compiler} {extra[0]} ./configure {extra[1]} --prefix={root} && {extra[0]} make -j{jobs} && {extra[0]} make install'.format(build_dir, **locals()) )
|
502 |
+
execute('Configuring...', 'cd {} && CC={compiler} CXX={cpp_compiler} {extra[0]} ./configure {options} {extra[1]} --prefix={root}'.format(build_dir, **locals()) )
|
503 |
+
execute('Building...', 'cd {} && {extra[0]} make -j{jobs}'.format(build_dir, **locals()) )
|
504 |
+
execute('Installing...', 'cd {} && {extra[0]} make -j{jobs} install'.format(build_dir, **locals()) )
|
505 |
+
|
506 |
+
shutil.rmtree(build_prefix)
|
507 |
+
|
508 |
+
#execute('Updating setuptools...', f'cd {root} && {root}/bin/pip{python_version} install --upgrade setuptools wheel' )
|
509 |
+
|
510 |
+
# if 'certifi' not in packages:
|
511 |
+
# packages += ' certifi'
|
512 |
+
|
513 |
+
if packages: execute( f'Installing packages {packages}...', f'cd {root} && unset __PYVENV_LAUNCHER__ && {root}/bin/pip{python_version} install --upgrade {packages}' )
|
514 |
+
#if packages: execute( f'Installing packages {packages}...', f'cd {root} && unset __PYVENV_LAUNCHER__ && {executable} -m pip install --upgrade {packages}' )
|
515 |
+
|
516 |
+
remove_pip_and_easy_install(root) # removing all pip's and easy_install's to make sure that environment is immutable
|
517 |
+
|
518 |
+
with open(signature_file_name, 'w') as f: f.write(signature)
|
519 |
+
|
520 |
+
print( 'Installing Python-{python_version}, using {url} with extra:{extra}... Done.'.format( **vars() ) )
|
521 |
+
|
522 |
+
il = get_python_include_and_lib(executable)
|
523 |
+
|
524 |
+
return NT(
|
525 |
+
python = executable,
|
526 |
+
root = root,
|
527 |
+
python_include_dir = il.python_include_dir,
|
528 |
+
python_lib_dir = il.python_lib_dir,
|
529 |
+
version = python_version,
|
530 |
+
url = url,
|
531 |
+
platform = platform,
|
532 |
+
config = config,
|
533 |
+
hash = hash,
|
534 |
+
)
|
535 |
+
|
536 |
+
|
537 |
+
|
538 |
+
def setup_python_virtual_environment(working_dir, python_environment, packages=''):
|
539 |
+
''' Deploy Python virtual environment at working_dir
|
540 |
+
'''
|
541 |
+
|
542 |
+
python = python_environment.python
|
543 |
+
|
544 |
+
execute('Setting up Python virtual environment...', 'unset __PYVENV_LAUNCHER__ && {python} -m venv --clear {working_dir}'.format(**vars()) )
|
545 |
+
|
546 |
+
activate = f'unset __PYVENV_LAUNCHER__ && . {working_dir}/bin/activate'
|
547 |
+
|
548 |
+
bin=working_dir+'/bin'
|
549 |
+
|
550 |
+
if packages: execute('Installing packages: {}...'.format(packages), 'unset __PYVENV_LAUNCHER__ && {bin}/python {bin}/pip install --upgrade pip setuptools && {bin}/python {bin}/pip install --progress-bar off {packages}'.format(**vars()) )
|
551 |
+
#if packages: execute('Installing packages: {}...'.format(packages), '{bin}/pip{python_environment.version} install {packages}'.format(**vars()) )
|
552 |
+
|
553 |
+
return NT(activate = activate, python = bin + '/python', root = working_dir, bin = bin)
|
554 |
+
|
555 |
+
|
556 |
+
|
557 |
+
def setup_persistent_python_virtual_environment(python_environment, packages):
|
558 |
+
''' Setup _persistent_ and _immutable_ Python virtual environment which will be saved between test runs
|
559 |
+
'''
|
560 |
+
|
561 |
+
if python_environment.version.startswith('2.'):
|
562 |
+
assert not packages, f'ERROR: setup_persistent_python_virtual_environment does not support Python-2.* with non-empty package list!'
|
563 |
+
return NT(activate = ':', python = python_environment.python, root = python_environment.root, bin = python_environment.root + '/bin')
|
564 |
+
|
565 |
+
else:
|
566 |
+
#if 'certifi' not in packages: packages += ' certifi'
|
567 |
+
|
568 |
+
h = hashlib.md5()
|
569 |
+
h.update(f'v1.0.0 platform: {python_environment.platform} python_source_url: {python_environment.url} python-hash: {python_environment.hash} packages: {packages}'.encode('utf-8', errors='backslashreplace') )
|
570 |
+
hash = h.hexdigest()
|
571 |
+
|
572 |
+
prefix = calculate_unique_prefix_path(python_environment.platform, python_environment.config)
|
573 |
+
|
574 |
+
root = os.path.abspath( prefix + '/python_virtual_environments/' + '/python-' + python_environment.version + '/' + hash )
|
575 |
+
signature_file_name = root + '/.signature'
|
576 |
+
signature = f'setup_persistent_python_virtual_environment v1.0.0\npython: {python_environment.hash}\npackages: {packages}\n'
|
577 |
+
|
578 |
+
activate = f'unset __PYVENV_LAUNCHER__ && . {root}/bin/activate'
|
579 |
+
bin = f'{root}/bin'
|
580 |
+
|
581 |
+
if os.path.isfile(signature_file_name) and open(signature_file_name).read() == signature: pass
|
582 |
+
else:
|
583 |
+
if os.path.isdir(root): shutil.rmtree(root)
|
584 |
+
setup_python_virtual_environment(root, python_environment, packages=packages)
|
585 |
+
remove_pip_and_easy_install(root) # removing all pip's and easy_install's to make sure that environment is immutable
|
586 |
+
with open(signature_file_name, 'w') as f: f.write(signature)
|
587 |
+
|
588 |
+
return NT(activate = activate, python = bin + '/python', root = root, bin = bin, hash = hash)
|
589 |
+
|
590 |
+
|
591 |
+
|
592 |
+
def _get_path_to_conda_root(platform, config):
|
593 |
+
''' Perform local (prefix) install of miniconda and return NT(activate, conda_root_dir, conda)
|
594 |
+
this function is for inner use only, - to setup custom conda environment inside your test use `setup_conda_virtual_environment` defined below
|
595 |
+
'''
|
596 |
+
miniconda_sources = {
|
597 |
+
'mac' : 'https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh',
|
598 |
+
'linux' : 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh',
|
599 |
+
'aarch64': 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh',
|
600 |
+
'ubuntu' : 'https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh',
|
601 |
+
'm1' : 'https://repo.anaconda.com/miniconda/Miniconda3-py38_4.10.1-MacOSX-arm64.sh',
|
602 |
+
}
|
603 |
+
|
604 |
+
conda_sources = {
|
605 |
+
'mac' : 'https://repo.continuum.io/archive/Anaconda3-2018.12-MacOSX-x86_64.sh',
|
606 |
+
'linux' : 'https://repo.continuum.io/archive/Anaconda3-2018.12-Linux-x86_64.sh',
|
607 |
+
'ubuntu' : 'https://repo.continuum.io/archive/Anaconda3-2018.12-Linux-x86_64.sh',
|
608 |
+
}
|
609 |
+
|
610 |
+
#platform_os = 'm1' if platform_module.machine() == 'arm64' else platform['os']
|
611 |
+
#url = miniconda_sources[ platform_os ]
|
612 |
+
|
613 |
+
platform_os = platform['os']
|
614 |
+
for o in 'alpine centos ubuntu'.split():
|
615 |
+
if platform_os.startswith(o): platform_os = 'linux'
|
616 |
+
|
617 |
+
url = miniconda_sources[platform_os]
|
618 |
+
|
619 |
+
version = '1'
|
620 |
+
channels = '' # conda-forge
|
621 |
+
|
622 |
+
#packages = ['conda-build gcc libgcc', 'libgcc=5.2.0'] # libgcc installs is workaround for "Anaconda libstdc++.so.6: version `GLIBCXX_3.4.20' not found", see: https://stackoverflow.com/questions/48453497/anaconda-libstdc-so-6-version-glibcxx-3-4-20-not-found
|
623 |
+
#packages = ['conda-build gcc'] # libgcc installs is workaround for "Anaconda libstdc++.so.6: version `GLIBCXX_3.4.20' not found", see: https://stackoverflow.com/questions/48453497/anaconda-libstdc-so-6-version-glibcxx-3-4-20-not-found
|
624 |
+
packages = ['conda-build anaconda-client conda-verify',]
|
625 |
+
|
626 |
+
signature = f'url: {url}\nversion: {version}\channels: {channels}\npackages: {packages}\n'
|
627 |
+
|
628 |
+
root = calculate_unique_prefix_path(platform, config) + '/conda'
|
629 |
+
|
630 |
+
signature_file_name = root + '/.signature'
|
631 |
+
|
632 |
+
# presense of __PYVENV_LAUNCHER__,PYTHONHOME, PYTHONPATH sometimes confuse Python so we have to unset them
|
633 |
+
unset = 'unset __PYVENV_LAUNCHER__ && unset PYTHONHOME && unset PYTHONPATH'
|
634 |
+
activate = unset + ' && . ' + root + '/bin/activate'
|
635 |
+
|
636 |
+
executable = root + '/bin/conda'
|
637 |
+
|
638 |
+
|
639 |
+
if os.path.isfile(signature_file_name) and open(signature_file_name).read() == signature:
|
640 |
+
print( f'Install for MiniConda is detected, skipping installation procedure...' )
|
641 |
+
|
642 |
+
else:
|
643 |
+
print( f'Installing MiniConda, using {url}...' )
|
644 |
+
|
645 |
+
if os.path.isdir(root): shutil.rmtree(root)
|
646 |
+
|
647 |
+
build_prefix = os.path.abspath(root + f'/../build-conda' )
|
648 |
+
|
649 |
+
#if not os.path.isdir(root): os.makedirs(root)
|
650 |
+
if not os.path.isdir(build_prefix): os.makedirs(build_prefix)
|
651 |
+
|
652 |
+
archive = build_prefix + '/' + url.split('/')[-1]
|
653 |
+
|
654 |
+
with open(archive, 'wb') as f:
|
655 |
+
response = urllib.request.urlopen(url)
|
656 |
+
f.write( response.read() )
|
657 |
+
|
658 |
+
execute('Installing conda...', f'cd {build_prefix} && {unset} && bash {archive} -b -p {root}' )
|
659 |
+
|
660 |
+
# conda update --yes --quiet -n base -c defaults conda
|
661 |
+
|
662 |
+
if channels: execute(f'Adding extra channles {channels}...', f'cd {build_prefix} && {activate} && conda config --add channels {channels}' )
|
663 |
+
|
664 |
+
for p in packages: execute(f'Installing conda packages: {p}...', f'cd {build_prefix} && {activate} && conda install --quiet --yes {p}' )
|
665 |
+
|
666 |
+
shutil.rmtree(build_prefix)
|
667 |
+
|
668 |
+
with open(signature_file_name, 'w') as f: f.write(signature)
|
669 |
+
|
670 |
+
print( f'Installing MiniConda, using {url}... Done.' )
|
671 |
+
|
672 |
+
execute(f'Updating conda base...', f'{activate} && conda update --all --yes' )
|
673 |
+
return NT(conda=executable, root=root, activate=activate, url=url)
|
674 |
+
|
675 |
+
|
676 |
+
|
677 |
+
def setup_conda_virtual_environment(working_dir, platform, config, packages=''):
|
678 |
+
''' Deploy Conda virtual environment at working_dir
|
679 |
+
'''
|
680 |
+
conda_root_env = _get_path_to_conda_root(platform, config)
|
681 |
+
activate = conda_root_env.activate
|
682 |
+
|
683 |
+
python_version = platform.get('python', DEFAULT_PYTHON_VERSION)
|
684 |
+
|
685 |
+
prefix = os.path.abspath( working_dir + '/.conda-python-' + python_version )
|
686 |
+
|
687 |
+
command_line = f'conda create --quiet --yes --prefix {prefix} python={python_version}'
|
688 |
+
|
689 |
+
execute( f'Setting up Conda for Python-{python_version} virtual environment...', f'cd {working_dir} && {activate} && ( {command_line} || ( conda clean --yes && {command_line} ) )' )
|
690 |
+
|
691 |
+
activate = f'{activate} && conda activate {prefix}'
|
692 |
+
|
693 |
+
if packages: execute( f'Setting up extra packages {packages}...', f'cd {working_dir} && {activate} && conda install --quiet --yes {packages}' )
|
694 |
+
|
695 |
+
python = prefix + '/bin/python' + python_version
|
696 |
+
|
697 |
+
il = get_python_include_and_lib(python)
|
698 |
+
|
699 |
+
return NT(
|
700 |
+
activate = activate,
|
701 |
+
root = prefix,
|
702 |
+
python = python,
|
703 |
+
python_include_dir = il.python_include_dir,
|
704 |
+
python_lib_dir = il.python_lib_dir,
|
705 |
+
version = python_version,
|
706 |
+
activate_base = conda_root_env.activate,
|
707 |
+
url = prefix, # conda_root_env.url,
|
708 |
+
platform=platform,
|
709 |
+
config=config,
|
710 |
+
)
|
711 |
+
|
712 |
+
|
713 |
+
|
714 |
+
class FileLock():
|
715 |
+
''' Implementation of file-lock object that could be use with Python `with` statement
|
716 |
+
'''
|
717 |
+
|
718 |
+
def __init__(self, file_name):
|
719 |
+
self.locked = False
|
720 |
+
self.file_name = file_name
|
721 |
+
|
722 |
+
|
723 |
+
def __enter__(self):
|
724 |
+
if not self.locked: self.acquire()
|
725 |
+
return self
|
726 |
+
|
727 |
+
|
728 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
729 |
+
if self.locked: self.release()
|
730 |
+
|
731 |
+
|
732 |
+
def __del__(self):
|
733 |
+
self.release()
|
734 |
+
|
735 |
+
|
736 |
+
def acquire(self):
|
737 |
+
while True:
|
738 |
+
try:
|
739 |
+
os.close( os.open(self.file_name, os.O_CREAT | os.O_EXCL, mode=0o600) )
|
740 |
+
self.locked = True
|
741 |
+
break
|
742 |
+
|
743 |
+
except FileExistsError as e:
|
744 |
+
time.sleep(60)
|
745 |
+
|
746 |
+
|
747 |
+
def release(self):
|
748 |
+
if self.locked:
|
749 |
+
os.remove(self.file_name)
|
750 |
+
self.locked = False
|
751 |
+
|
752 |
+
|
753 |
+
|
754 |
+
def convert_submodule_urls_from_ssh_to_https(repository_root):
|
755 |
+
''' switching submodules URL to HTTPS so we can clone without SSH key
|
756 |
+
'''
|
757 |
+
with open(f'{repository_root}/.gitmodules') as f: m = f.read()
|
758 |
+
with open(f'{repository_root}/.gitmodules', 'w') as f:
|
759 |
+
f.write(
|
760 |
+
m
|
761 |
+
.replace('url = git@github.com:', 'url = https://github.com/')
|
762 |
+
.replace('url = ../../../', 'url = https://github.com/RosettaCommons/')
|
763 |
+
.replace('url = ../../', 'url = https://github.com/RosettaCommons/')
|
764 |
+
.replace('url = ../', 'url = https://github.com/RosettaCommons/')
|
765 |
+
)
|
RFdiffusion/.rosetta-ci/tests/rfd.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# :noTabs=true:
|
4 |
+
|
5 |
+
# (c) Copyright Rosetta Commons Member Institutions.
|
6 |
+
# (c) This file is part of the Rosetta software suite and is made available under license.
|
7 |
+
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
|
8 |
+
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
|
9 |
+
# (c) addressed to University of Washington CoMotion, email: license@uw.edu.
|
10 |
+
|
11 |
+
## @file rfd.py
|
12 |
+
## @brief main test files for RFdiffusion
|
13 |
+
## @author Sergey Lyskov
|
14 |
+
|
15 |
+
|
16 |
+
import imp
|
17 |
+
imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/__init__.py') # A bit of Python magic here, what we trying to say is this: from __init__ import *, but init is calculated from file location
|
18 |
+
|
19 |
+
_api_version_ = '1.0'
|
20 |
+
|
21 |
+
import os, tempfile, shutil
|
22 |
+
import urllib.request
|
23 |
+
|
24 |
+
|
25 |
+
_models_urls_ = '''
|
26 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/6f5902ac237024bdd0c176cb93063dc4/Base_ckpt.pt
|
27 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/e29311f6f1bf1af907f9ef9f44b8328b/Complex_base_ckpt.pt
|
28 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt
|
29 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/74f51cfb8b440f50d70878e05361d8f0/InpaintSeq_ckpt.pt
|
30 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/76d00716416567174cdb7ca96e208296/InpaintSeq_Fold_ckpt.pt
|
31 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/5532d2e1f3a4738decd58b19d633b3c3/ActiveSite_ckpt.pt
|
32 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/12fc204edeae5b57713c5ad7dcb97d39/Base_epoch8_ckpt.pt
|
33 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/f572d396fae9206628714fb2ce00f72e/Complex_beta_ckpt.pt
|
34 |
+
http://files.ipd.uw.edu/pub/RFdiffusion/1befcb9b28e2f778f53d47f18b7597fa/RF_structure_prediction_weights.pt
|
35 |
+
'''.split()
|
36 |
+
|
37 |
+
|
38 |
+
def run_main_test_suite(repository_root, working_dir, platform, config, debug):
|
39 |
+
full_log = ''
|
40 |
+
|
41 |
+
python_environment = local_python_install(platform, config)
|
42 |
+
|
43 |
+
models_dir = repository_root + '/models'
|
44 |
+
if not os.path.isdir(models_dir): os.makedirs(models_dir)
|
45 |
+
|
46 |
+
for url in _models_urls_:
|
47 |
+
file_name = models_dir + '/' + url.split('/')[-1]
|
48 |
+
tmp_file_name = file_name + '.tmp'
|
49 |
+
if not os.path.isfile(file_name):
|
50 |
+
print(f'downloading {url}...')
|
51 |
+
full_log += f'downloading {url}...\n'
|
52 |
+
urllib.request.urlretrieve(url, tmp_file_name)
|
53 |
+
os.rename(tmp_file_name, file_name)
|
54 |
+
|
55 |
+
execute('unpacking ppi scaffolds...', f'cd {repository_root} && tar -xvf examples/ppi_scaffolds_subset.tar.gz -C examples')
|
56 |
+
|
57 |
+
with tempfile.TemporaryDirectory(dir=working_dir) as tmpdirname:
|
58 |
+
# tmpdirname = working_dir+'/.ve'
|
59 |
+
# if True:
|
60 |
+
|
61 |
+
#ve = setup_persistent_python_virtual_environment(python_environment, packages='numpy torch omegaconf scipy opt_einsum dgl')
|
62 |
+
#ve = setup_python_virtual_environment(working_dir+'/.ve', python_environment, packages='numpy torch omegaconf scipy opt_einsum dgl e3nn icecream pyrsistent wandb pynvml decorator jedi hydra-core')
|
63 |
+
ve = setup_python_virtual_environment(tmpdirname, python_environment, packages='numpy torch omegaconf scipy opt_einsum dgl e3nn icecream pyrsistent wandb pynvml decorator jedi hydra-core')
|
64 |
+
|
65 |
+
execute('Installing local se3-transformer package...', f'cd {repository_root}/env/SE3Transformer && {ve.bin}/pip3 install --editable .')
|
66 |
+
execute('Installing RFdiffusion package...', f'cd {repository_root} && {ve.bin}/pip3 install --editable .')
|
67 |
+
|
68 |
+
#res, output = execute('running unit tests...', f'{ve.activate} && cd {repository_root} && python -m unittest', return_='tuple', add_message_and_command_line_to_output=True)
|
69 |
+
#res, output = execute('running unit tests...', f'cd {repository_root} && {ve.bin}/pytest', return_='tuple')
|
70 |
+
|
71 |
+
|
72 |
+
results_file = f'{repository_root}/tests/.results.json'
|
73 |
+
if os.path.isfile(results_file): os.remove(results_file)
|
74 |
+
|
75 |
+
res, output = execute('running RFdiffusion tests...', f'{ve.activate} && cd {repository_root}/tests && python test_diffusion.py', return_='tuple', add_message_and_command_line_to_output=True)
|
76 |
+
|
77 |
+
if os.path.isfile(results_file):
|
78 |
+
with open(results_file) as f: sub_tests_reults = json.load(f)
|
79 |
+
|
80 |
+
state = _S_passed_
|
81 |
+
for r in sub_tests_reults.values():
|
82 |
+
if r[_StateKey_] == _S_failed_:
|
83 |
+
state = _S_failed_
|
84 |
+
break
|
85 |
+
|
86 |
+
else:
|
87 |
+
sub_tests_reults = {}
|
88 |
+
output += '\n\nEmpty sub-test results, marking test as `failed`...'
|
89 |
+
state = _S_failed_
|
90 |
+
|
91 |
+
shutil.move(f'{repository_root}/tests/outputs', f'{working_dir}/outputs')
|
92 |
+
|
93 |
+
for d in os.listdir(f'{repository_root}/tests'):
|
94 |
+
p = f'{repository_root}/tests/{d}'
|
95 |
+
if d.startswith('tests_') and os.path.isdir(p): shutil.rmtree(p)
|
96 |
+
|
97 |
+
results = {
|
98 |
+
_StateKey_ : state,
|
99 |
+
_LogKey_ : full_log + '\n' + output,
|
100 |
+
_ResultsKey_ : {
|
101 |
+
_TestsKey_ : sub_tests_reults,
|
102 |
+
},
|
103 |
+
}
|
104 |
+
|
105 |
+
return results
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def run(test, repository_root, working_dir, platform, config, hpc_driver=None, verbose=False, debug=False):
|
110 |
+
if test == '': return run_main_test_suite(repository_root=repository_root, working_dir=working_dir, platform=platform, config=config, debug=debug)
|
111 |
+
else: raise BenchmarkError('Unknow scripts test: {}!'.format(test))
|
RFdiffusion/.rosetta-ci/tests/self.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# self test suite
|
2 |
+
These tests are design to help debug interface between testing server and Rosetta testing scripts
|
3 |
+
|
4 |
+
-----
|
5 |
+
### python
|
6 |
+
Test Python platform support and functionality of local and persistent Python virtual environments
|
RFdiffusion/.rosetta-ci/tests/self.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# :noTabs=true:
|
4 |
+
|
5 |
+
# (c) Copyright Rosetta Commons Member Institutions.
|
6 |
+
# (c) This file is part of the Rosetta software suite and is made available under license.
|
7 |
+
# (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
|
8 |
+
# (c) For more information, see http://www.rosettacommons.org. Questions about this can be
|
9 |
+
# (c) addressed to University of Washington CoMotion, email: license@uw.edu.
|
10 |
+
|
11 |
+
## @file dummy.py
|
12 |
+
## @brief self-test and debug-aids tests
|
13 |
+
## @author Sergey Lyskov
|
14 |
+
|
15 |
+
import os, os.path, shutil, re, string
|
16 |
+
import json
|
17 |
+
|
18 |
+
import random
|
19 |
+
|
20 |
+
import imp
|
21 |
+
imp.load_source(__name__, '/'.join(__file__.split('/')[:-1]) + '/__init__.py') # A bit of Python magic here, what we trying to say is this: from __init__ import *, but init is calculated from file location
|
22 |
+
|
23 |
+
_api_version_ = '1.0'
|
24 |
+
|
25 |
+
|
26 |
+
def run_state_test(repository_root, working_dir, platform, config):
|
27 |
+
revision_id = config['revision']
|
28 |
+
states = (_S_passed_, _S_failed_, _S_build_failed_, _S_script_failed_)
|
29 |
+
state = states[revision_id % len(states)]
|
30 |
+
|
31 |
+
return {_StateKey_ : state, _ResultsKey_ : {}, _LogKey_ : f'run_state_test: setting test state to {state!r}...' }
|
32 |
+
|
33 |
+
|
34 |
+
sub_test_description_template = '''\
|
35 |
+
# subtests_test test suite
|
36 |
+
These sub-test description is generated for 3/4 of sub-tests
|
37 |
+
|
38 |
+
-----
|
39 |
+
### {name}
|
40 |
+
The warm time, had already disappeared like dust. Broken rain, fragment of light shadow, bring more pain to my heart...
|
41 |
+
-----
|
42 |
+
'''
|
43 |
+
|
44 |
+
def run_subtests_test(repository_root, working_dir, platform, config):
|
45 |
+
tests = {}
|
46 |
+
for i in range(16):
|
47 |
+
name = f's-{i:02}'
|
48 |
+
log = ('x'*63 + '\n') * 16 * 256 * i
|
49 |
+
s = i % 3
|
50 |
+
if s == 0: state = _S_passed_
|
51 |
+
elif s == 1: state = _S_failed_
|
52 |
+
else: state = _S_script_failed_
|
53 |
+
|
54 |
+
if i % 4:
|
55 |
+
os.mkdir( f'{working_dir}/{name}' )
|
56 |
+
with open(f'{working_dir}/{name}/description.md', 'w') as f: f.write( sub_test_description_template.format(**vars()) )
|
57 |
+
|
58 |
+
with open( f'{working_dir}/{name}/fantome.txt', 'w') as f: f.write('No one wants to hear the sequel to a fairytale\n')
|
59 |
+
|
60 |
+
tests[name] = { _StateKey_ : state, _LogKey_ : log, }
|
61 |
+
|
62 |
+
test_log = ('*'*63 + '\n') * 16 * 1024 * 16
|
63 |
+
return {_StateKey_ : _S_failed_, _ResultsKey_ : {_TestsKey_: tests}, _LogKey_ : test_log }
|
64 |
+
|
65 |
+
|
66 |
+
def run_regression_test(repository_root, working_dir, platform, config):
|
67 |
+
const = 'const'
|
68 |
+
volatile = 'volatile'
|
69 |
+
new = ''.join( random.sample( string.ascii_letters + string.digits, 8) )
|
70 |
+
oversized = 'oversized'
|
71 |
+
|
72 |
+
sub_tests = [const, volatile, new]
|
73 |
+
|
74 |
+
const_dir = working_dir + '/' + const
|
75 |
+
os.mkdir(const_dir)
|
76 |
+
with open(const_dir + '/const_data', 'w') as f: f.write( '\n'.join( (str(i) for i in range(32) ) ) )
|
77 |
+
|
78 |
+
volatile_dir = working_dir + '/' + volatile
|
79 |
+
os.mkdir(volatile_dir)
|
80 |
+
with open(volatile_dir + '/const_data', 'w') as f: f.write( '\n'.join( (str(i) for i in range(32, 64) ) ) )
|
81 |
+
with open(volatile_dir + '/volatile_data', 'w') as f: f.write( '\n'.join( ( ''.join(random.sample( string.ascii_letters + string.digits, 8) ) for i in range(32) ) ) )
|
82 |
+
|
83 |
+
new_dir = working_dir + '/' + new
|
84 |
+
os.mkdir(new_dir)
|
85 |
+
with open(new_dir + '/data', 'w') as f: f.write( '\n'.join( (str(i) for i in range(64)) ) )
|
86 |
+
|
87 |
+
|
88 |
+
new_dir = working_dir + '/' + oversized
|
89 |
+
os.mkdir(new_dir)
|
90 |
+
with open(new_dir + '/large', 'w') as f: f.write( ('x'*63 + '\n')*16*1024*256 +'extra')
|
91 |
+
|
92 |
+
return {_StateKey_ : _S_queued_for_comparison_, _ResultsKey_ : {}, _LogKey_ : f'sub-tests: {sub_tests!r}' }
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
def run_release_test(repository_root, working_dir, platform, config):
|
97 |
+
release_root = config['mounts'].get('release_root')
|
98 |
+
|
99 |
+
branch = config['branch']
|
100 |
+
revision = config['revision']
|
101 |
+
|
102 |
+
assert release_root, "config['release_root'] must be set!"
|
103 |
+
|
104 |
+
release_path = f'{release_root}/dummy'
|
105 |
+
|
106 |
+
if not os.path.isdir(release_path): os.makedirs(release_path)
|
107 |
+
|
108 |
+
with open(f'{release_path}/{branch}-{revision}.txt', 'w') as f: f.write('dummy release file\n')
|
109 |
+
|
110 |
+
return {_StateKey_ : _S_passed_, _ResultsKey_ : {}, _LogKey_ : f'Config release root set to: {release_root}'}
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
def run_python_test(repository_root, working_dir, platform, config):
|
115 |
+
|
116 |
+
import zlib, ssl
|
117 |
+
|
118 |
+
python_environment = local_python_install(platform, config)
|
119 |
+
|
120 |
+
if platform['python'][0] == '2': pass
|
121 |
+
else:
|
122 |
+
|
123 |
+
if platform['os'] == 'mac' and int( platform['python'].split('.')[1] ) > 6 :
|
124 |
+
# SSL certificate test
|
125 |
+
import urllib.request; urllib.request.urlopen('https://benchmark.graylab.jhu.edu')
|
126 |
+
|
127 |
+
ves = [
|
128 |
+
setup_persistent_python_virtual_environment(python_environment, packages='colr dice xdice pdp11games'),
|
129 |
+
setup_python_virtual_environment(working_dir, python_environment, packages='colr dice xdice pdp11games'),
|
130 |
+
]
|
131 |
+
|
132 |
+
for ve in ves:
|
133 |
+
commands = [
|
134 |
+
'import colr, dice, xdice, pdp11games',
|
135 |
+
]
|
136 |
+
|
137 |
+
if platform['os'] == 'mac' and int( platform['python'].split('.')[1] ) > 6 :
|
138 |
+
# SSL certificate test
|
139 |
+
commands.append('import urllib.request; urllib.request.urlopen("https://benchmark.graylab.jhu.edu/queue")')
|
140 |
+
|
141 |
+
for command in commands:
|
142 |
+
execute('Testing local Python virtual enviroment...', f"{ve.activate} && {ve.python} -c '{command}'")
|
143 |
+
execute('Testing local Python virtual enviroment...', f"{ve.activate} && python -c '{command}'")
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
return {_StateKey_ : _S_passed_, _ResultsKey_ : {}, _LogKey_ : f'Done!'}
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
def compare(test, results, files_path, previous_results, previous_files_path):
|
152 |
+
"""
|
153 |
+
Compare the results of two tests run (new vs. previous) for regression test
|
154 |
+
Take two dict and two paths
|
155 |
+
Must return standard dict with results
|
156 |
+
|
157 |
+
:param test: str
|
158 |
+
:param results: dict
|
159 |
+
:param files_path: str
|
160 |
+
:param previous_results: dict
|
161 |
+
:param previous_files_path: str
|
162 |
+
:rtype: dict
|
163 |
+
"""
|
164 |
+
ignore_files = []
|
165 |
+
|
166 |
+
results = dict(tests={}, summary=dict(total=0, failed=0, failed_tests=[])) # , config={}
|
167 |
+
|
168 |
+
if previous_files_path:
|
169 |
+
for test in os.listdir(files_path):
|
170 |
+
if os.path.isdir(files_path + '/' + test):
|
171 |
+
exclude = ''.join([' --exclude="{}"'.format(f) for f in ignore_files] ) + ' --exclude="*.ignore"'
|
172 |
+
res, brief_diff = execute('Comparing {}...'.format(test), 'diff -rq {exclude} {0}/{test} {1}/{test}'.format(previous_files_path, files_path, test=test, exclude=exclude), return_='tuple')
|
173 |
+
res, full_diff = execute('Comparing {}...'.format(test), 'diff -r {exclude} {0}/{test} {1}/{test}'.format(previous_files_path, files_path, test=test, exclude=exclude), return_='tuple')
|
174 |
+
diff = 'Brief Diff:\n' + brief_diff + ( ('\n\nFull Diff:\n' + full_diff[:1024*1024*1]) if full_diff != brief_diff else '' )
|
175 |
+
|
176 |
+
state = _S_failed_ if res else _S_passed_
|
177 |
+
results['tests'][test] = {_StateKey_: state, _LogKey_: diff if state != _S_passed_ else ''}
|
178 |
+
|
179 |
+
results['summary']['total'] += 1
|
180 |
+
if res: results['summary']['failed'] += 1; results['summary']['failed_tests'].append(test)
|
181 |
+
|
182 |
+
else: # no previous tests case, returning 'passed' for all sub_tests
|
183 |
+
for test in os.listdir(files_path):
|
184 |
+
if os.path.isdir(files_path + '/' + test):
|
185 |
+
results['tests'][test] = {_StateKey_: _S_passed_, _LogKey_: 'First run, no previous results available. Skipping comparison...\n'}
|
186 |
+
results['summary']['total'] += 1
|
187 |
+
|
188 |
+
for test in os.listdir(files_path):
|
189 |
+
if os.path.isdir(files_path + '/' + test):
|
190 |
+
if os.path.isfile(files_path+'/'+test+'/.test_did_not_run.log') or os.path.isfile(files_path+'/'+test+'/.test_got_timeout_kill.log'):
|
191 |
+
results['tests'][test][_StateKey_] = _S_script_failed_
|
192 |
+
results['tests'][test][_LogKey_] += '\nCompare(...): Marking as "Script failed" due to presense of .test_did_not_run.log or .test_got_timeout_kill.log file!\n'
|
193 |
+
if test not in results['summary']['failed_tests']:
|
194 |
+
results['summary']['failed'] += 1
|
195 |
+
results['summary']['failed_tests'].append(test)
|
196 |
+
|
197 |
+
state = _S_failed_ if results['summary']['failed'] else _S_passed_
|
198 |
+
|
199 |
+
return {_StateKey_: state, _LogKey_: 'Comparison dummy log...', _ResultsKey_: results}
|
200 |
+
|
201 |
+
|
202 |
+
def run(test, repository_root, working_dir, platform, config, hpc_driver=None, verbose=False, debug=False):
|
203 |
+
if test == 'state': return run_state_test (repository_root, working_dir, platform, config)
|
204 |
+
elif test == 'regression': return run_regression_test (repository_root, working_dir, platform, config)
|
205 |
+
elif test == 'subtests': return run_subtests_test (repository_root, working_dir, platform, config)
|
206 |
+
elif test == 'release': return run_release_test (repository_root, working_dir, platform, config)
|
207 |
+
elif test == 'python': return run_python_test (repository_root, working_dir, platform, config)
|
208 |
+
|
209 |
+
else: raise BenchmarkError(f'Dummy test script does not support run with test={test!r}!')
|
RFdiffusion/END
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"retCode":"100",
|
3 |
+
"retData":null,
|
4 |
+
"retMsg":"操作成功",
|
5 |
+
"retTime":"2022-11-05 22:20:09",
|
6 |
+
"success":true
|
7 |
+
}
|
RFdiffusion/LICENSE
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD License
|
2 |
+
|
3 |
+
Copyright (c) 2023 University of Washington. Developed at the Institute for
|
4 |
+
Protein Design by Joseph Watson, David Juergens, Nathaniel Bennett, Brian Trippe
|
5 |
+
and Jason Yim
|
6 |
+
|
7 |
+
Redistribution and use in source and binary forms, with or without
|
8 |
+
modification, are permitted provided that the following conditions are met:
|
9 |
+
|
10 |
+
Redistributions of source code must retain the above copyright notice, this
|
11 |
+
list of conditions and the following disclaimer.
|
12 |
+
|
13 |
+
Redistributions in binary form must reproduce the above copyright notice, this
|
14 |
+
list of conditions and the following disclaimer in the documentation and/or
|
15 |
+
other materials provided with the distribution.
|
16 |
+
|
17 |
+
Neither the name of the University of Washington nor the names of its
|
18 |
+
contributors may be used to endorse or promote products derived from this
|
19 |
+
software without specific prior written permission.
|
20 |
+
|
21 |
+
THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON AND CONTRIBUTORS “AS
|
22 |
+
IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
23 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
24 |
+
DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF WASHINGTON OR CONTRIBUTORS BE
|
25 |
+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
26 |
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
|
27 |
+
GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
|
28 |
+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
29 |
+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
|
30 |
+
OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
RFdiffusion/appverifUI.dll
ADDED
Binary file (112 kB). View file
|
|
RFdiffusion/config/inference/base.yaml
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Base inference Configuration.
|
2 |
+
|
3 |
+
inference:
|
4 |
+
input_pdb: null
|
5 |
+
num_designs: 10
|
6 |
+
design_startnum: 0
|
7 |
+
ckpt_override_path: null
|
8 |
+
symmetry: null
|
9 |
+
recenter: True
|
10 |
+
radius: 10.0
|
11 |
+
model_only_neighbors: False
|
12 |
+
output_prefix: samples/design
|
13 |
+
write_trajectory: True
|
14 |
+
scaffold_guided: False
|
15 |
+
model_runner: SelfConditioning
|
16 |
+
cautious: True
|
17 |
+
align_motif: True
|
18 |
+
symmetric_self_cond: True
|
19 |
+
final_step: 1
|
20 |
+
deterministic: False
|
21 |
+
trb_save_ckpt_path: null
|
22 |
+
schedule_directory_path: null
|
23 |
+
model_directory_path: null
|
24 |
+
|
25 |
+
contigmap:
|
26 |
+
contigs: null
|
27 |
+
inpaint_seq: null
|
28 |
+
provide_seq: null
|
29 |
+
length: null
|
30 |
+
|
31 |
+
model:
|
32 |
+
n_extra_block: 4
|
33 |
+
n_main_block: 32
|
34 |
+
n_ref_block: 4
|
35 |
+
d_msa: 256
|
36 |
+
d_msa_full: 64
|
37 |
+
d_pair: 128
|
38 |
+
d_templ: 64
|
39 |
+
n_head_msa: 8
|
40 |
+
n_head_pair: 4
|
41 |
+
n_head_templ: 4
|
42 |
+
d_hidden: 32
|
43 |
+
d_hidden_templ: 32
|
44 |
+
p_drop: 0.15
|
45 |
+
SE3_param_full:
|
46 |
+
num_layers: 1
|
47 |
+
num_channels: 32
|
48 |
+
num_degrees: 2
|
49 |
+
n_heads: 4
|
50 |
+
div: 4
|
51 |
+
l0_in_features: 8
|
52 |
+
l0_out_features: 8
|
53 |
+
l1_in_features: 3
|
54 |
+
l1_out_features: 2
|
55 |
+
num_edge_features: 32
|
56 |
+
SE3_param_topk:
|
57 |
+
num_layers: 1
|
58 |
+
num_channels: 32
|
59 |
+
num_degrees: 2
|
60 |
+
n_heads: 4
|
61 |
+
div: 4
|
62 |
+
l0_in_features: 64
|
63 |
+
l0_out_features: 64
|
64 |
+
l1_in_features: 3
|
65 |
+
l1_out_features: 2
|
66 |
+
num_edge_features: 64
|
67 |
+
freeze_track_motif: False
|
68 |
+
use_motif_timestep: False
|
69 |
+
|
70 |
+
diffuser:
|
71 |
+
T: 50
|
72 |
+
b_0: 1e-2
|
73 |
+
b_T: 7e-2
|
74 |
+
schedule_type: linear
|
75 |
+
so3_type: igso3
|
76 |
+
crd_scale: 0.25
|
77 |
+
partial_T: null
|
78 |
+
so3_schedule_type: linear
|
79 |
+
min_b: 1.5
|
80 |
+
max_b: 2.5
|
81 |
+
min_sigma: 0.02
|
82 |
+
max_sigma: 1.5
|
83 |
+
|
84 |
+
denoiser:
|
85 |
+
noise_scale_ca: 1
|
86 |
+
final_noise_scale_ca: 1
|
87 |
+
ca_noise_schedule_type: constant
|
88 |
+
noise_scale_frame: 1
|
89 |
+
final_noise_scale_frame: 1
|
90 |
+
frame_noise_schedule_type: constant
|
91 |
+
|
92 |
+
ppi:
|
93 |
+
hotspot_res: null
|
94 |
+
|
95 |
+
potentials:
|
96 |
+
guiding_potentials: null
|
97 |
+
guide_scale: 10
|
98 |
+
guide_decay: constant
|
99 |
+
olig_inter_all : null
|
100 |
+
olig_intra_all : null
|
101 |
+
olig_custom_contact : null
|
102 |
+
substrate: null
|
103 |
+
|
104 |
+
contig_settings:
|
105 |
+
ref_idx: null
|
106 |
+
hal_idx: null
|
107 |
+
idx_rf: null
|
108 |
+
inpaint_seq_tensor: null
|
109 |
+
|
110 |
+
preprocess:
|
111 |
+
sidechain_input: False
|
112 |
+
motif_sidechain_input: True
|
113 |
+
d_t1d: 22
|
114 |
+
d_t2d: 44
|
115 |
+
prob_self_cond: 0.0
|
116 |
+
str_self_cond: False
|
117 |
+
predict_previous: False
|
118 |
+
|
119 |
+
logging:
|
120 |
+
inputs: False
|
121 |
+
|
122 |
+
scaffoldguided:
|
123 |
+
scaffoldguided: False
|
124 |
+
target_pdb: False
|
125 |
+
target_path: null
|
126 |
+
scaffold_list: null
|
127 |
+
scaffold_dir: null
|
128 |
+
sampled_insertion: 0
|
129 |
+
sampled_N: 0
|
130 |
+
sampled_C: 0
|
131 |
+
ss_mask: 0
|
132 |
+
systematic: False
|
133 |
+
target_ss: null
|
134 |
+
target_adj: null
|
135 |
+
mask_loops: True
|
136 |
+
contig_crop: null
|
RFdiffusion/config/inference/symmetry.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Config for sampling symmetric assemblies.
|
2 |
+
|
3 |
+
defaults:
|
4 |
+
- base
|
5 |
+
|
6 |
+
inference:
|
7 |
+
# Symmetry to sample
|
8 |
+
# Available symmetries:
|
9 |
+
# - Cyclic symmetry (C_n) # call as c5
|
10 |
+
# - Dihedral symmetry (D_n) # call as d5
|
11 |
+
# - Tetrahedral symmetry # call as tetrahedral
|
12 |
+
# - Octahedral symmetry # call as octahedral
|
13 |
+
# - Icosahedral symmetry # call as icosahedral
|
14 |
+
symmetry: c2
|
15 |
+
|
16 |
+
# Set to true for computational efficiency
|
17 |
+
# to avoid memory overhead of modeling all subunits.
|
18 |
+
model_only_neighbors: False
|
19 |
+
|
20 |
+
# Output directory of samples.
|
21 |
+
output_prefix: samples/c2
|
22 |
+
|
23 |
+
contigmap:
|
24 |
+
# Specify a single integer value to sample unconditionally.
|
25 |
+
# Must be evenly divisible by the number of chains in the symmetry.
|
26 |
+
contigs: ['100']
|
RFdiffusion/docker/Dockerfile
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Usage:
|
2 |
+
# git clone https://github.com/RosettaCommons/RFdiffusion.git
|
3 |
+
# cd RFdiffusion
|
4 |
+
# docker build -f docker/Dockerfile -t rfdiffusion .
|
5 |
+
# mkdir $HOME/inputs $HOME/outputs $HOME/models
|
6 |
+
# bash scripts/download_models.sh $HOME/models
|
7 |
+
# wget -P $HOME/inputs https://files.rcsb.org/view/5TPN.pdb
|
8 |
+
|
9 |
+
# docker run -it --rm --gpus all \
|
10 |
+
# -v $HOME/models:$HOME/models \
|
11 |
+
# -v $HOME/inputs:$HOME/inputs \
|
12 |
+
# -v $HOME/outputs:$HOME/outputs \
|
13 |
+
# rfdiffusion \
|
14 |
+
# inference.output_prefix=$HOME/outputs/motifscaffolding \
|
15 |
+
# inference.model_directory_path=$HOME/models \
|
16 |
+
# inference.input_pdb=$HOME/inputs/5TPN.pdb \
|
17 |
+
# inference.num_designs=3 \
|
18 |
+
# 'contigmap.contigs=[10-40/A163-181/10-40]'
|
19 |
+
|
20 |
+
FROM nvcr.io/nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04
|
21 |
+
|
22 |
+
COPY . /app/RFdiffusion/
|
23 |
+
|
24 |
+
RUN apt-get -q update \
|
25 |
+
&& DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
|
26 |
+
git \
|
27 |
+
python3.9 \
|
28 |
+
python3-pip \
|
29 |
+
&& python3.9 -m pip install -q -U --no-cache-dir pip \
|
30 |
+
&& rm -rf /var/lib/apt/lists/* \
|
31 |
+
&& apt-get autoremove -y \
|
32 |
+
&& apt-get clean \
|
33 |
+
&& pip install -q --no-cache-dir \
|
34 |
+
dgl==1.0.2+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html \
|
35 |
+
torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 \
|
36 |
+
e3nn==0.3.3 \
|
37 |
+
wandb==0.12.0 \
|
38 |
+
pynvml==11.0.0 \
|
39 |
+
git+https://github.com/NVIDIA/dllogger#egg=dllogger \
|
40 |
+
decorator==5.1.0 \
|
41 |
+
hydra-core==1.3.2 \
|
42 |
+
pyrsistent==0.19.3 \
|
43 |
+
/app/RFdiffusion/env/SE3Transformer \
|
44 |
+
&& pip install --no-cache-dir /app/RFdiffusion --no-deps
|
45 |
+
|
46 |
+
WORKDIR /app/RFdiffusion
|
47 |
+
|
48 |
+
ENV DGLBACKEND="pytorch"
|
49 |
+
|
50 |
+
ENTRYPOINT ["python3.9", "scripts/run_inference.py"]
|
RFdiffusion/env/SE3Transformer/.dockerignore
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.Trash-0
|
2 |
+
.git
|
3 |
+
data/
|
4 |
+
.DS_Store
|
5 |
+
*wandb/
|
6 |
+
*.pt
|
7 |
+
*.swp
|
8 |
+
|
9 |
+
# added by FAFU
|
10 |
+
.idea/
|
11 |
+
cache/
|
12 |
+
downloaded/
|
13 |
+
*.lprof
|
14 |
+
|
15 |
+
# Byte-compiled / optimized / DLL files
|
16 |
+
__pycache__/
|
17 |
+
*.py[cod]
|
18 |
+
*$py.class
|
19 |
+
|
20 |
+
# C extensions
|
21 |
+
*.so
|
22 |
+
|
23 |
+
# Distribution / packaging
|
24 |
+
.Python
|
25 |
+
build/
|
26 |
+
develop-eggs/
|
27 |
+
dist/
|
28 |
+
downloads/
|
29 |
+
eggs/
|
30 |
+
.eggs/
|
31 |
+
lib/
|
32 |
+
lib64/
|
33 |
+
parts/
|
34 |
+
sdist/
|
35 |
+
var/
|
36 |
+
wheels/
|
37 |
+
*.egg-info/
|
38 |
+
.installed.cfg
|
39 |
+
*.egg
|
40 |
+
MANIFEST
|
41 |
+
|
42 |
+
# PyInstaller
|
43 |
+
# Usually these files are written by a python script from a template
|
44 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
45 |
+
*.manifest
|
46 |
+
*.spec
|
47 |
+
|
48 |
+
# Installer logs
|
49 |
+
pip-log.txt
|
50 |
+
pip-delete-this-directory.txt
|
51 |
+
|
52 |
+
# Unit test / coverage reports
|
53 |
+
htmlcov/
|
54 |
+
.tox/
|
55 |
+
.coverage
|
56 |
+
.coverage.*
|
57 |
+
.cache
|
58 |
+
nosetests.xml
|
59 |
+
coverage.xml
|
60 |
+
*.cover
|
61 |
+
.hypothesis/
|
62 |
+
.pytest_cache/
|
63 |
+
|
64 |
+
# Translations
|
65 |
+
*.mo
|
66 |
+
*.pot
|
67 |
+
|
68 |
+
# Django stuff:
|
69 |
+
*.log
|
70 |
+
local_settings.py
|
71 |
+
db.sqlite3
|
72 |
+
|
73 |
+
# Flask stuff:
|
74 |
+
instance/
|
75 |
+
.webassets-cache
|
76 |
+
|
77 |
+
# Scrapy stuff:
|
78 |
+
.scrapy
|
79 |
+
|
80 |
+
# Sphinx documentation
|
81 |
+
docs/_build/
|
82 |
+
|
83 |
+
# PyBuilder
|
84 |
+
target/
|
85 |
+
|
86 |
+
# Jupyter Notebook
|
87 |
+
.ipynb_checkpoints
|
88 |
+
|
89 |
+
# pyenv
|
90 |
+
.python-version
|
91 |
+
|
92 |
+
# celery beat schedule file
|
93 |
+
celerybeat-schedule
|
94 |
+
|
95 |
+
# SageMath parsed files
|
96 |
+
*.sage.py
|
97 |
+
|
98 |
+
# Environments
|
99 |
+
.env
|
100 |
+
.venv
|
101 |
+
env/
|
102 |
+
venv/
|
103 |
+
ENV/
|
104 |
+
env.bak/
|
105 |
+
venv.bak/
|
106 |
+
|
107 |
+
# Spyder project settings
|
108 |
+
.spyderproject
|
109 |
+
.spyproject
|
110 |
+
|
111 |
+
# Rope project settings
|
112 |
+
.ropeproject
|
113 |
+
|
114 |
+
# mkdocs documentation
|
115 |
+
/site
|
116 |
+
|
117 |
+
# mypy
|
118 |
+
.mypy_cache/
|
119 |
+
|
120 |
+
**/benchmark
|
121 |
+
**/results
|
122 |
+
*.pkl
|
123 |
+
*.log
|
RFdiffusion/env/SE3Transformer/.gitignore
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
data/
|
2 |
+
.DS_Store
|
3 |
+
*wandb/
|
4 |
+
*.pt
|
5 |
+
*.swp
|
6 |
+
|
7 |
+
# added by FAFU
|
8 |
+
.idea/
|
9 |
+
cache/
|
10 |
+
downloaded/
|
11 |
+
*.lprof
|
12 |
+
|
13 |
+
# Byte-compiled / optimized / DLL files
|
14 |
+
__pycache__/
|
15 |
+
*.py[cod]
|
16 |
+
*$py.class
|
17 |
+
|
18 |
+
# C extensions
|
19 |
+
*.so
|
20 |
+
|
21 |
+
# Distribution / packaging
|
22 |
+
.Python
|
23 |
+
build/
|
24 |
+
develop-eggs/
|
25 |
+
dist/
|
26 |
+
downloads/
|
27 |
+
eggs/
|
28 |
+
.eggs/
|
29 |
+
lib/
|
30 |
+
lib64/
|
31 |
+
parts/
|
32 |
+
sdist/
|
33 |
+
var/
|
34 |
+
wheels/
|
35 |
+
*.egg-info/
|
36 |
+
.installed.cfg
|
37 |
+
*.egg
|
38 |
+
MANIFEST
|
39 |
+
|
40 |
+
# PyInstaller
|
41 |
+
# Usually these files are written by a python script from a template
|
42 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
43 |
+
*.manifest
|
44 |
+
*.spec
|
45 |
+
|
46 |
+
# Installer logs
|
47 |
+
pip-log.txt
|
48 |
+
pip-delete-this-directory.txt
|
49 |
+
|
50 |
+
# Unit test / coverage reports
|
51 |
+
htmlcov/
|
52 |
+
.tox/
|
53 |
+
.coverage
|
54 |
+
.coverage.*
|
55 |
+
.cache
|
56 |
+
nosetests.xml
|
57 |
+
coverage.xml
|
58 |
+
*.cover
|
59 |
+
.hypothesis/
|
60 |
+
.pytest_cache/
|
61 |
+
|
62 |
+
# Translations
|
63 |
+
*.mo
|
64 |
+
*.pot
|
65 |
+
|
66 |
+
# Django stuff:
|
67 |
+
*.log
|
68 |
+
local_settings.py
|
69 |
+
db.sqlite3
|
70 |
+
|
71 |
+
# Flask stuff:
|
72 |
+
instance/
|
73 |
+
.webassets-cache
|
74 |
+
|
75 |
+
# Scrapy stuff:
|
76 |
+
.scrapy
|
77 |
+
|
78 |
+
# Sphinx documentation
|
79 |
+
docs/_build/
|
80 |
+
|
81 |
+
# PyBuilder
|
82 |
+
target/
|
83 |
+
|
84 |
+
# Jupyter Notebook
|
85 |
+
.ipynb_checkpoints
|
86 |
+
|
87 |
+
# pyenv
|
88 |
+
.python-version
|
89 |
+
|
90 |
+
# celery beat schedule file
|
91 |
+
celerybeat-schedule
|
92 |
+
|
93 |
+
# SageMath parsed files
|
94 |
+
*.sage.py
|
95 |
+
|
96 |
+
# Environments
|
97 |
+
.env
|
98 |
+
.venv
|
99 |
+
env/
|
100 |
+
venv/
|
101 |
+
ENV/
|
102 |
+
env.bak/
|
103 |
+
venv.bak/
|
104 |
+
|
105 |
+
# Spyder project settings
|
106 |
+
.spyderproject
|
107 |
+
.spyproject
|
108 |
+
|
109 |
+
# Rope project settings
|
110 |
+
.ropeproject
|
111 |
+
|
112 |
+
# mkdocs documentation
|
113 |
+
/site
|
114 |
+
|
115 |
+
# mypy
|
116 |
+
.mypy_cache/
|
117 |
+
|
118 |
+
**/benchmark
|
119 |
+
**/results
|
120 |
+
*.pkl
|
121 |
+
*.log
|
RFdiffusion/env/SE3Transformer/Dockerfile
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
# run docker daemon with --default-runtime=nvidia for GPU detection during build
|
25 |
+
# multistage build for DGL with CUDA and FP16
|
26 |
+
|
27 |
+
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.07-py3
|
28 |
+
|
29 |
+
FROM ${FROM_IMAGE_NAME} AS dgl_builder
|
30 |
+
|
31 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
32 |
+
RUN apt-get update \
|
33 |
+
&& apt-get install -y git build-essential python3-dev make cmake \
|
34 |
+
&& rm -rf /var/lib/apt/lists/*
|
35 |
+
WORKDIR /dgl
|
36 |
+
RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
|
37 |
+
RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake
|
38 |
+
WORKDIR build
|
39 |
+
RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
|
40 |
+
RUN make -j8
|
41 |
+
|
42 |
+
|
43 |
+
FROM ${FROM_IMAGE_NAME}
|
44 |
+
|
45 |
+
RUN rm -rf /workspace/*
|
46 |
+
WORKDIR /workspace/se3-transformer
|
47 |
+
|
48 |
+
# copy built DGL and install it
|
49 |
+
COPY --from=dgl_builder /dgl ./dgl
|
50 |
+
RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl
|
51 |
+
|
52 |
+
ADD requirements.txt .
|
53 |
+
RUN pip install --no-cache-dir --upgrade --pre pip
|
54 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
55 |
+
ADD . .
|
56 |
+
|
57 |
+
ENV DGLBACKEND=pytorch
|
58 |
+
ENV OMP_NUM_THREADS=1
|
RFdiffusion/env/SE3Transformer/LICENSE
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2021 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4 |
+
|
5 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
6 |
+
|
7 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
RFdiffusion/env/SE3Transformer/NOTICE
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SE(3)-Transformer PyTorch
|
2 |
+
|
3 |
+
This repository includes software from https://github.com/FabianFuchsML/se3-transformer-public
|
4 |
+
licensed under the MIT License.
|
5 |
+
|
6 |
+
This repository includes software from https://github.com/lucidrains/se3-transformer-pytorch
|
7 |
+
licensed under the MIT License.
|
RFdiffusion/env/SE3Transformer/README.md
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SE(3)-Transformers For PyTorch
|
2 |
+
|
3 |
+
This repository provides a script and recipe to train the SE(3)-Transformer model to achieve state-of-the-art accuracy. The content of this repository is tested and maintained by NVIDIA.
|
4 |
+
|
5 |
+
## Table Of Contents
|
6 |
+
- [Model overview](#model-overview)
|
7 |
+
* [Model architecture](#model-architecture)
|
8 |
+
* [Default configuration](#default-configuration)
|
9 |
+
* [Feature support matrix](#feature-support-matrix)
|
10 |
+
* [Features](#features)
|
11 |
+
* [Mixed precision training](#mixed-precision-training)
|
12 |
+
* [Enabling mixed precision](#enabling-mixed-precision)
|
13 |
+
* [Enabling TF32](#enabling-tf32)
|
14 |
+
* [Glossary](#glossary)
|
15 |
+
- [Setup](#setup)
|
16 |
+
* [Requirements](#requirements)
|
17 |
+
- [Quick Start Guide](#quick-start-guide)
|
18 |
+
- [Advanced](#advanced)
|
19 |
+
* [Scripts and sample code](#scripts-and-sample-code)
|
20 |
+
* [Parameters](#parameters)
|
21 |
+
* [Command-line options](#command-line-options)
|
22 |
+
* [Getting the data](#getting-the-data)
|
23 |
+
* [Dataset guidelines](#dataset-guidelines)
|
24 |
+
* [Multi-dataset](#multi-dataset)
|
25 |
+
* [Training process](#training-process)
|
26 |
+
* [Inference process](#inference-process)
|
27 |
+
- [Performance](#performance)
|
28 |
+
* [Benchmarking](#benchmarking)
|
29 |
+
* [Training performance benchmark](#training-performance-benchmark)
|
30 |
+
* [Inference performance benchmark](#inference-performance-benchmark)
|
31 |
+
* [Results](#results)
|
32 |
+
* [Training accuracy results](#training-accuracy-results)
|
33 |
+
* [Training accuracy: NVIDIA DGX A100 (8x A100 80GB)](#training-accuracy-nvidia-dgx-a100-8x-a100-80gb)
|
34 |
+
* [Training accuracy: NVIDIA DGX-1 (8x V100 16GB)](#training-accuracy-nvidia-dgx-1-8x-v100-16gb)
|
35 |
+
* [Training stability test](#training-stability-test)
|
36 |
+
* [Training performance results](#training-performance-results)
|
37 |
+
* [Training performance: NVIDIA DGX A100 (8x A100 80GB)](#training-performance-nvidia-dgx-a100-8x-a100-80gb)
|
38 |
+
* [Training performance: NVIDIA DGX-1 (8x V100 16GB)](#training-performance-nvidia-dgx-1-8x-v100-16gb)
|
39 |
+
* [Inference performance results](#inference-performance-results)
|
40 |
+
* [Inference performance: NVIDIA DGX A100 (1x A100 80GB)](#inference-performance-nvidia-dgx-a100-1x-a100-80gb)
|
41 |
+
* [Inference performance: NVIDIA DGX-1 (1x V100 16GB)](#inference-performance-nvidia-dgx-1-1x-v100-16gb)
|
42 |
+
- [Release notes](#release-notes)
|
43 |
+
* [Changelog](#changelog)
|
44 |
+
* [Known issues](#known-issues)
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
## Model overview
|
49 |
+
|
50 |
+
|
51 |
+
The **SE(3)-Transformer** is a Graph Neural Network using a variant of [self-attention](https://arxiv.org/abs/1706.03762v5) for 3D points and graphs processing.
|
52 |
+
This model is [equivariant](https://en.wikipedia.org/wiki/Equivariant_map) under [continuous 3D roto-translations](https://en.wikipedia.org/wiki/Euclidean_group), meaning that when the inputs (graphs or sets of points) rotate in 3D space (or more generally experience a [proper rigid transformation](https://en.wikipedia.org/wiki/Rigid_transformation)), the model outputs either stay invariant or transform with the input.
|
53 |
+
A mathematical guarantee of equivariance is important to ensure stable and predictable performance in the presence of nuisance transformations of the data input and when the problem has some inherent symmetries we want to exploit.
|
54 |
+
|
55 |
+
|
56 |
+
The model is based on the following publications:
|
57 |
+
- [SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks](https://arxiv.org/abs/2006.10503) (NeurIPS 2020) by Fabian B. Fuchs, Daniel E. Worrall, et al.
|
58 |
+
- [Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds](https://arxiv.org/abs/1802.08219) by Nathaniel Thomas, Tess Smidt, et al.
|
59 |
+
|
60 |
+
A follow-up paper explains how this model can be used iteratively, for example, to predict or refine protein structures:
|
61 |
+
|
62 |
+
- [Iterative SE(3)-Transformers](https://arxiv.org/abs/2102.13419) by Fabian B. Fuchs, Daniel E. Worrall, et al.
|
63 |
+
|
64 |
+
Just like [the official implementation](https://github.com/FabianFuchsML/se3-transformer-public), this implementation uses [PyTorch](https://pytorch.org/) and the [Deep Graph Library (DGL)](https://www.dgl.ai/).
|
65 |
+
|
66 |
+
The main differences between this implementation of SE(3)-Transformers and the official one are the following:
|
67 |
+
|
68 |
+
- Training and inference support for multiple GPUs
|
69 |
+
- Training and inference support for [Mixed Precision](https://arxiv.org/abs/1710.03740)
|
70 |
+
- The [QM9 dataset from DGL](https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset) is used and automatically downloaded
|
71 |
+
- Significantly increased throughput
|
72 |
+
- Significantly reduced memory consumption
|
73 |
+
- The use of layer normalization in the fully connected radial profile layers is an option (`--use_layer_norm`), off by default
|
74 |
+
- The use of equivariant normalization between attention layers is an option (`--norm`), off by default
|
75 |
+
- The [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonic) and [Clebsch–Gordan coefficients](https://en.wikipedia.org/wiki/Clebsch%E2%80%93Gordan_coefficients), used to compute bases matrices, are computed with the [e3nn library](https://e3nn.org/)
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
This model enables you to predict quantum chemical properties of small organic molecules in the [QM9 dataset](https://www.nature.com/articles/sdata201422).
|
80 |
+
In this case, the exploited symmetry is that these properties do not depend on the orientation or position of the molecules in space.
|
81 |
+
|
82 |
+
|
83 |
+
This model is trained with mixed precision using Tensor Cores on NVIDIA Volta, NVIDIA Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results up to 1.5x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
|
84 |
+
|
85 |
+
### Model architecture
|
86 |
+
|
87 |
+
The model consists of stacked layers of equivariant graph self-attention and equivariant normalization.
|
88 |
+
Lastly, a Tensor Field Network convolution is applied to obtain invariant features. Graph pooling (mean or max over the nodes) is applied to these features, and the result is fed to a final MLP to get scalar predictions.
|
89 |
+
|
90 |
+
In this setup, the model is a graph-to-scalar network. The pooling can be removed to obtain a graph-to-graph network, and the final TFN can be modified to output features of any type (invariant scalars, 3D vectors, ...).
|
91 |
+
|
92 |
+
|
93 |
+
![Model high-level architecture](./images/se3-transformer.png)
|
94 |
+
|
95 |
+
|
96 |
+
### Default configuration
|
97 |
+
|
98 |
+
|
99 |
+
SE(3)-Transformers introduce a self-attention layer for graphs that is equivariant to 3D roto-translations. It achieves this by leveraging Tensor Field Networks to build attention weights that are invariant and attention values that are equivariant.
|
100 |
+
Combining the equivariant values with the invariant weights gives rise to an equivariant output. This output is normalized while preserving equivariance thanks to equivariant normalization layers operating on feature norms.
|
101 |
+
|
102 |
+
|
103 |
+
The following features were implemented in this model:
|
104 |
+
|
105 |
+
- Support for edge features of any degree (1D, 3D, 5D, ...), whereas the official implementation only supports scalar invariant edge features (degree 0). Edge features with a degree greater than one are
|
106 |
+
concatenated to node features of the same degree. This is required in order to reproduce published results on point cloud processing.
|
107 |
+
- Data-parallel multi-GPU training (DDP)
|
108 |
+
- Mixed precision training (autocast, gradient scaling)
|
109 |
+
- Gradient accumulation
|
110 |
+
- Model checkpointing
|
111 |
+
|
112 |
+
|
113 |
+
The following performance optimizations were implemented in this model:
|
114 |
+
|
115 |
+
|
116 |
+
**General optimizations**
|
117 |
+
|
118 |
+
- The option is provided to precompute bases at the beginning of the training instead of computing them at the beginning of each forward pass (`--precompute_bases`)
|
119 |
+
- The bases computation is just-in-time (JIT) compiled with `torch.jit.script`
|
120 |
+
- The Clebsch-Gordon coefficients are cached in RAM
|
121 |
+
|
122 |
+
|
123 |
+
**Tensor Field Network optimizations**
|
124 |
+
|
125 |
+
- The last layer of each radial profile network does not add any bias in order to avoid large broadcasting operations
|
126 |
+
- The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
|
127 |
+
- When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
|
128 |
+
- Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
|
129 |
+
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`)
|
130 |
+
|
131 |
+
**Self-attention optimizations**
|
132 |
+
|
133 |
+
- Attention keys and values are computed by a single partial TFN graph convolution in each attention layer instead of two
|
134 |
+
- Graph operations for different output degrees may be fused together if conditions are met
|
135 |
+
|
136 |
+
|
137 |
+
**Normalization optimizations**
|
138 |
+
|
139 |
+
- The equivariant normalization layer is optimized from multiple layer normalizations to a group normalization on fused norms when certain conditions are met
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
Competitive training results and analysis are provided for the following hyperparameters (identical to the ones in the original publication):
|
144 |
+
- Number of layers: 7
|
145 |
+
- Number of degrees: 4
|
146 |
+
- Number of channels: 32
|
147 |
+
- Number of attention heads: 8
|
148 |
+
- Channels division: 2
|
149 |
+
- Use of equivariant normalization: true
|
150 |
+
- Use of layer normalization: true
|
151 |
+
- Pooling: max
|
152 |
+
|
153 |
+
|
154 |
+
### Feature support matrix
|
155 |
+
|
156 |
+
This model supports the following features::
|
157 |
+
|
158 |
+
| Feature | SE(3)-Transformer
|
159 |
+
|-----------------------|--------------------------
|
160 |
+
|Automatic mixed precision (AMP) | Yes
|
161 |
+
|Distributed data parallel (DDP) | Yes
|
162 |
+
|
163 |
+
#### Features
|
164 |
+
|
165 |
+
|
166 |
+
**Distributed data parallel (DDP)**
|
167 |
+
|
168 |
+
[DistributedDataParallel (DDP)](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implements data parallelism at the module level that can run across multiple GPUs or machines.
|
169 |
+
|
170 |
+
**Automatic Mixed Precision (AMP)**
|
171 |
+
|
172 |
+
This implementation uses the native PyTorch AMP implementation of mixed precision training. It allows us to use FP16 training with FP32 master weights by modifying just a few lines of code. A detailed explanation of mixed precision can be found in the next section.
|
173 |
+
|
174 |
+
### Mixed precision training
|
175 |
+
|
176 |
+
Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in NVIDIA Volta, and following with both the NVIDIA Turing and NVIDIA Ampere Architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using [mixed precision training](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) previously required two steps:
|
177 |
+
1. Porting the model to use the FP16 data type where appropriate.
|
178 |
+
2. Adding loss scaling to preserve small gradient values.
|
179 |
+
|
180 |
+
AMP enables mixed precision training on NVIDIA Volta, NVIDIA Turing, and NVIDIA Ampere GPU architectures automatically. The PyTorch framework code makes all necessary model changes internally.
|
181 |
+
|
182 |
+
For information about:
|
183 |
+
- How to train using mixed precision, refer to the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) documentation.
|
184 |
+
- Techniques used for mixed precision training, refer to the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
|
185 |
+
- APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
|
186 |
+
|
187 |
+
#### Enabling mixed precision
|
188 |
+
|
189 |
+
Mixed precision is enabled in PyTorch by using the native [Automatic Mixed Precision package](https://pytorch.org/docs/stable/amp.html), which casts variables to half-precision upon retrieval while storing variables in single-precision format. Furthermore, to preserve small gradient magnitudes in backpropagation, a [loss scaling](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#lossscaling) step must be included when applying gradients. In PyTorch, loss scaling can be applied automatically using a `GradScaler`.
|
190 |
+
Automatic Mixed Precision makes all the adjustments internally in PyTorch, providing two benefits over manual operations. First, programmers need not modify network model code, reducing development and maintenance effort. Second, using AMP maintains forward and backward compatibility with all the APIs for defining and running PyTorch models.
|
191 |
+
|
192 |
+
To enable mixed precision, you can simply use the `--amp` flag when running the training or inference scripts.
|
193 |
+
|
194 |
+
#### Enabling TF32
|
195 |
+
|
196 |
+
TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math, also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on NVIDIA Volta GPUs.
|
197 |
+
|
198 |
+
TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models that require a high dynamic range for weights or activations.
|
199 |
+
|
200 |
+
For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
|
201 |
+
|
202 |
+
TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
|
203 |
+
|
204 |
+
|
205 |
+
|
206 |
+
### Glossary
|
207 |
+
|
208 |
+
**Degree (type)**
|
209 |
+
|
210 |
+
In the model, every feature (input, output and hidden) transforms in an equivariant way in relation to the input graph. When we define a feature, we need to choose, in addition to the number of channels, which transformation rule it obeys.
|
211 |
+
|
212 |
+
The degree or type of a feature is a positive integer that describes how this feature transforms when the input rotates in 3D.
|
213 |
+
|
214 |
+
This is related to [irreducible representations](https://en.wikipedia.org/wiki/Irreducible_representation) of different rotation orders.
|
215 |
+
|
216 |
+
The degree of a feature determines its dimensionality. A type-d feature has a dimensionality of 2d+1.
|
217 |
+
|
218 |
+
Some common examples include:
|
219 |
+
- Degree 0: 1D scalars invariant to rotation
|
220 |
+
- Degree 1: 3D vectors that rotate according to 3D rotation matrices
|
221 |
+
- Degree 2: 5D vectors that rotate according to 5D [Wigner-D matrices](https://en.wikipedia.org/wiki/Wigner_D-matrix). These can represent symmetric traceless 3x3 matrices.
|
222 |
+
|
223 |
+
**Fiber**
|
224 |
+
|
225 |
+
A fiber can be viewed as a representation of a set of features of different types or degrees (positive integers), where each feature type transforms according to its rule.
|
226 |
+
|
227 |
+
In this repository, a fiber can be seen as a dictionary with degrees as keys and numbers of channels as values.
|
228 |
+
|
229 |
+
**Multiplicity**
|
230 |
+
|
231 |
+
The multiplicity of a feature of a given type is the number of channels of this feature.
|
232 |
+
|
233 |
+
**Tensor Field Network**
|
234 |
+
|
235 |
+
A [Tensor Field Network](https://arxiv.org/abs/1802.08219) is a kind of equivariant graph convolution that can combine features of different degrees and produce new ones while preserving equivariance thanks to [tensor products](https://en.wikipedia.org/wiki/Tensor_product).
|
236 |
+
|
237 |
+
**Equivariance**
|
238 |
+
|
239 |
+
[Equivariance](https://en.wikipedia.org/wiki/Equivariant_map) is a property of a function of model stating that applying a symmetry transformation to the input and then computing the function produces the same result as computing the function and then applying the transformation to the output.
|
240 |
+
|
241 |
+
In the case of SE(3)-Transformer, the symmetry group is the group of continuous roto-translations (SE(3)).
|
242 |
+
|
243 |
+
## Setup
|
244 |
+
|
245 |
+
The following section lists the requirements that you need to meet in order to start training the SE(3)-Transformer model.
|
246 |
+
|
247 |
+
### Requirements
|
248 |
+
|
249 |
+
This repository contains a Dockerfile which extends the PyTorch 21.07 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
|
250 |
+
- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
|
251 |
+
- PyTorch 21.07+ NGC container
|
252 |
+
- Supported GPUs:
|
253 |
+
- [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
|
254 |
+
- [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
|
255 |
+
- [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
|
256 |
+
|
257 |
+
For more information about how to get started with NGC containers, refer to the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
|
258 |
+
- [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
|
259 |
+
- [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
|
260 |
+
- [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/running.html#running)
|
261 |
+
|
262 |
+
For those unable to use the PyTorch NGC container to set up the required environment or create your own container, refer to the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
|
263 |
+
|
264 |
+
## Quick Start Guide
|
265 |
+
|
266 |
+
To train your model using mixed or TF32 precision with Tensor Cores or FP32, perform the following steps using the default parameters of the SE(3)-Transformer model on the QM9 dataset. For the specifics concerning training and inference, refer to the [Advanced](#advanced) section.
|
267 |
+
|
268 |
+
1. Clone the repository.
|
269 |
+
```
|
270 |
+
git clone https://github.com/NVIDIA/DeepLearningExamples
|
271 |
+
cd DeepLearningExamples/PyTorch/DrugDiscovery/SE3Transformer
|
272 |
+
```
|
273 |
+
|
274 |
+
2. Build the `se3-transformer` PyTorch NGC container.
|
275 |
+
```
|
276 |
+
docker build -t se3-transformer .
|
277 |
+
```
|
278 |
+
|
279 |
+
3. Start an interactive session in the NGC container to run training/inference.
|
280 |
+
```
|
281 |
+
mkdir -p results
|
282 |
+
docker run -it --runtime=nvidia --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 --rm -v ${PWD}/results:/results se3-transformer:latest
|
283 |
+
```
|
284 |
+
|
285 |
+
4. Start training.
|
286 |
+
```
|
287 |
+
bash scripts/train.sh
|
288 |
+
```
|
289 |
+
|
290 |
+
5. Start inference/predictions.
|
291 |
+
```
|
292 |
+
bash scripts/predict.sh
|
293 |
+
```
|
294 |
+
|
295 |
+
|
296 |
+
Now that you have your model trained and evaluated, you can choose to compare your training results with our [Training accuracy results](#training-accuracy-results). You can also choose to benchmark your performance to [Training performance benchmark](#training-performance-results) or [Inference performance benchmark](#inference-performance-results). Following the steps in these sections will ensure that you achieve the same accuracy and performance results as stated in the [Results](#results) section.
|
297 |
+
|
298 |
+
## Advanced
|
299 |
+
|
300 |
+
The following sections provide greater details of the dataset, running training and inference, and the training results.
|
301 |
+
|
302 |
+
### Scripts and sample code
|
303 |
+
|
304 |
+
In the root directory, the most important files are:
|
305 |
+
- `Dockerfile`: container with the basic set of dependencies to run SE(3)-Transformers
|
306 |
+
- `requirements.txt`: set of extra requirements to run SE(3)-Transformers
|
307 |
+
- `se3_transformer/data_loading/qm9.py`: QM9 data loading and preprocessing, as well as bases precomputation
|
308 |
+
- `se3_transformer/model/layers/`: directory containing model architecture layers
|
309 |
+
- `se3_transformer/model/transformer.py`: main Transformer module
|
310 |
+
- `se3_transformer/model/basis.py`: logic for computing bases matrices
|
311 |
+
- `se3_transformer/runtime/training.py`: training script, to be run as a python module
|
312 |
+
- `se3_transformer/runtime/inference.py`: inference script, to be run as a python module
|
313 |
+
- `se3_transformer/runtime/metrics.py`: MAE metric with support for multi-GPU synchronization
|
314 |
+
- `se3_transformer/runtime/loggers.py`: [DLLogger](https://github.com/NVIDIA/dllogger) and [W&B](wandb.ai/) loggers
|
315 |
+
|
316 |
+
|
317 |
+
### Parameters
|
318 |
+
|
319 |
+
The complete list of the available parameters for the `training.py` script contains:
|
320 |
+
|
321 |
+
**General**
|
322 |
+
|
323 |
+
- `--epochs`: Number of training epochs (default: `100` for single-GPU)
|
324 |
+
- `--batch_size`: Batch size (default: `240`)
|
325 |
+
- `--seed`: Set a seed globally (default: `None`)
|
326 |
+
- `--num_workers`: Number of dataloading workers (default: `8`)
|
327 |
+
- `--amp`: Use Automatic Mixed Precision (default `false`)
|
328 |
+
- `--gradient_clip`: Clipping of the gradient norms (default: `None`)
|
329 |
+
- `--accumulate_grad_batches`: Gradient accumulation (default: `1`)
|
330 |
+
- `--ckpt_interval`: Save a checkpoint every N epochs (default: `-1`)
|
331 |
+
- `--eval_interval`: Do an evaluation round every N epochs (default: `1`)
|
332 |
+
- `--silent`: Minimize stdout output (default: `false`)
|
333 |
+
|
334 |
+
**Paths**
|
335 |
+
|
336 |
+
- `--data_dir`: Directory where the data is located or should be downloaded (default: `./data`)
|
337 |
+
- `--log_dir`: Directory where the results logs should be saved (default: `/results`)
|
338 |
+
- `--save_ckpt_path`: File where the checkpoint should be saved (default: `None`)
|
339 |
+
- `--load_ckpt_path`: File of the checkpoint to be loaded (default: `None`)
|
340 |
+
|
341 |
+
**Optimizer**
|
342 |
+
|
343 |
+
- `--optimizer`: Optimizer to use (default: `adam`)
|
344 |
+
- `--learning_rate`: Learning rate to use (default: `0.002` for single-GPU)
|
345 |
+
- `--momentum`: Momentum to use (default: `0.9`)
|
346 |
+
- `--weight_decay`: Weight decay to use (default: `0.1`)
|
347 |
+
|
348 |
+
**QM9 dataset**
|
349 |
+
|
350 |
+
- `--task`: Regression task to train on (default: `homo`)
|
351 |
+
- `--precompute_bases`: Precompute bases at the beginning of the script during dataset initialization, instead of computing them at the beginning of each forward pass (default: `false`)
|
352 |
+
|
353 |
+
**Model architecture**
|
354 |
+
|
355 |
+
- `--num_layers`: Number of stacked Transformer layers (default: `7`)
|
356 |
+
- `--num_heads`: Number of heads in self-attention (default: `8`)
|
357 |
+
- `--channels_div`: Channels division before feeding to attention layer (default: `2`)
|
358 |
+
- `--pooling`: Type of graph pooling (default: `max`)
|
359 |
+
- `--norm`: Apply a normalization layer after each attention block (default: `false`)
|
360 |
+
- `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
|
361 |
+
- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`)
|
362 |
+
- `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
|
363 |
+
- `--num_channels`: Number of channels for the hidden features (default: `32`)
|
364 |
+
|
365 |
+
|
366 |
+
### Command-line options
|
367 |
+
|
368 |
+
To show the full list of available options and their descriptions, use the `-h` or `--help` command-line option, for example: `python -m se3_transformer.runtime.training --help`.
|
369 |
+
|
370 |
+
|
371 |
+
### Dataset guidelines
|
372 |
+
|
373 |
+
#### Demo dataset
|
374 |
+
|
375 |
+
The SE(3)-Transformer was trained on the QM9 dataset.
|
376 |
+
|
377 |
+
The QM9 dataset is hosted on DGL servers and downloaded (38MB) automatically when needed. By default, it is stored in the `./data` directory, but this location can be changed with the `--data_dir` argument.
|
378 |
+
|
379 |
+
The dataset is saved as a `qm9_edge.npz` file and converted to DGL graphs at runtime.
|
380 |
+
|
381 |
+
As input features, we use:
|
382 |
+
- Node features (6D):
|
383 |
+
- One-hot-encoded atom type (5D) (atom types: H, C, N, O, F)
|
384 |
+
- Number of protons of each atom (1D)
|
385 |
+
- Edge features: one-hot-encoded bond type (4D) (bond types: single, double, triple, aromatic)
|
386 |
+
- The relative positions between adjacent nodes (atoms)
|
387 |
+
|
388 |
+
#### Custom datasets
|
389 |
+
|
390 |
+
To use this network on a new dataset, you can extend the `DataModule` class present in `se3_transformer/data_loading/data_module.py`.
|
391 |
+
|
392 |
+
Your custom collate function should return a tuple with:
|
393 |
+
|
394 |
+
- A (batched) DGLGraph object
|
395 |
+
- A dictionary of node features ({‘{degree}’: tensor})
|
396 |
+
- A dictionary of edge features ({‘{degree}’: tensor})
|
397 |
+
- (Optional) Precomputed bases as a dictionary
|
398 |
+
- Labels as a tensor
|
399 |
+
|
400 |
+
You can then modify the `training.py` and `inference.py` scripts to use your new data module.
|
401 |
+
|
402 |
+
### Training process
|
403 |
+
|
404 |
+
The training script is `se3_transformer/runtime/training.py`, to be run as a module: `python -m se3_transformer.runtime.training`.
|
405 |
+
|
406 |
+
**Logs**
|
407 |
+
|
408 |
+
By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
|
409 |
+
|
410 |
+
You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable.
|
411 |
+
|
412 |
+
**Checkpoints**
|
413 |
+
|
414 |
+
The argument `--save_ckpt_path` can be set to the path of the file where the checkpoints should be saved.
|
415 |
+
`--ckpt_interval` can also be set to the interval (in the number of epochs) between checkpoints.
|
416 |
+
|
417 |
+
**Evaluation**
|
418 |
+
|
419 |
+
The evaluation metric is the Mean Absolute Error (MAE).
|
420 |
+
|
421 |
+
`--eval_interval` can be set to the interval (in the number of epochs) between evaluation rounds. By default, an evaluation round is performed after each epoch.
|
422 |
+
|
423 |
+
**Automatic Mixed Precision**
|
424 |
+
|
425 |
+
To enable Mixed Precision training, add the `--amp` flag.
|
426 |
+
|
427 |
+
**Multi-GPU and multi-node**
|
428 |
+
|
429 |
+
The training script supports the PyTorch elastic launcher to run on multiple GPUs or nodes. Refer to the [official documentation](https://pytorch.org/docs/1.9.0/elastic/run.html).
|
430 |
+
|
431 |
+
For example, to train on all available GPUs with AMP:
|
432 |
+
|
433 |
+
```
|
434 |
+
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --module se3_transformer.runtime.training --amp
|
435 |
+
```
|
436 |
+
|
437 |
+
|
438 |
+
### Inference process
|
439 |
+
|
440 |
+
Inference can be run by using the `se3_transformer.runtime.inference` python module.
|
441 |
+
|
442 |
+
The inference script is `se3_transformer/runtime/inference.py`, to be run as a module: `python -m se3_transformer.runtime.inference`. It requires a pre-trained model checkpoint (to be passed as `--load_ckpt_path`).
|
443 |
+
|
444 |
+
|
445 |
+
## Performance
|
446 |
+
|
447 |
+
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
|
448 |
+
|
449 |
+
### Benchmarking
|
450 |
+
|
451 |
+
The following section shows how to run benchmarks measuring the model performance in training and inference modes.
|
452 |
+
|
453 |
+
#### Training performance benchmark
|
454 |
+
|
455 |
+
To benchmark the training performance on a specific batch size, run `bash scripts/benchmarck_train.sh {BATCH_SIZE}` for single GPU, and `bash scripts/benchmarck_train_multi_gpu.sh {BATCH_SIZE}` for multi-GPU.
|
456 |
+
|
457 |
+
#### Inference performance benchmark
|
458 |
+
|
459 |
+
To benchmark the inference performance on a specific batch size, run `bash scripts/benchmarck_inference.sh {BATCH_SIZE}`.
|
460 |
+
|
461 |
+
### Results
|
462 |
+
|
463 |
+
|
464 |
+
The following sections provide details on how we achieved our performance and accuracy in training and inference.
|
465 |
+
|
466 |
+
#### Training accuracy results
|
467 |
+
|
468 |
+
##### Training accuracy: NVIDIA DGX A100 (8x A100 80GB)
|
469 |
+
|
470 |
+
Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 (8x A100 80GB) GPUs.
|
471 |
+
|
472 |
+
| GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
|
473 |
+
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
474 |
+
| 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
|
475 |
+
| 8 | 240 | 0.03417 | 0.03424 | 15min | 12min | 1.25x |
|
476 |
+
|
477 |
+
|
478 |
+
##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
|
479 |
+
|
480 |
+
Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
|
481 |
+
|
482 |
+
| GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
|
483 |
+
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
484 |
+
| 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
|
485 |
+
| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
|
486 |
+
|
487 |
+
|
488 |
+
#### Training performance results
|
489 |
+
|
490 |
+
##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
|
491 |
+
|
492 |
+
Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 8x A100 80GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
|
493 |
+
|
494 |
+
| GPUs | Batch size / GPU | Throughput - TF32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (mixed precision - TF32) | Weak scaling - TF32 | Weak scaling - mixed precision |
|
495 |
+
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
496 |
+
| 1 | 240 | 2.21 | 2.92 | 1.32x | | |
|
497 |
+
| 1 | 120 | 1.81 | 2.04 | 1.13x | | |
|
498 |
+
| 8 | 240 | 17.15 | 22.95 | 1.34x | 7.76 | 7.86 |
|
499 |
+
| 8 | 120 | 13.89 | 15.62 | 1.12x | 7.67 | 7.66 |
|
500 |
+
|
501 |
+
|
502 |
+
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
503 |
+
|
504 |
+
|
505 |
+
##### Training performance: NVIDIA DGX-1 (8x V100 16GB)
|
506 |
+
|
507 |
+
Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 8x V100 16GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
|
508 |
+
|
509 |
+
| GPUs | Batch size / GPU | Throughput - FP32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (FP32 - mixed precision) | Weak scaling - FP32 | Weak scaling - mixed precision |
|
510 |
+
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
511 |
+
| 1 | 240 | 1.25 | 1.88 | 1.50x | | |
|
512 |
+
| 1 | 120 | 1.03 | 1.41 | 1.37x | | |
|
513 |
+
| 8 | 240 | 9.33 | 14.02 | 1.50x | 7.46 | 7.46 |
|
514 |
+
| 8 | 120 | 7.39 | 9.41 | 1.27x | 7.17 | 6.67 |
|
515 |
+
|
516 |
+
|
517 |
+
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
518 |
+
|
519 |
+
|
520 |
+
#### Inference performance results
|
521 |
+
|
522 |
+
|
523 |
+
##### Inference performance: NVIDIA DGX A100 (1x A100 80GB)
|
524 |
+
|
525 |
+
Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 1x A100 80GB GPU.
|
526 |
+
|
527 |
+
FP16
|
528 |
+
|
529 |
+
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
530 |
+
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
531 |
+
| 1600 | 11.60 | 140.94 | 138.29 | 140.12 | 386.40 |
|
532 |
+
| 800 | 10.74 | 75.69 | 75.74 | 76.50 | 79.77 |
|
533 |
+
| 400 | 8.86 | 45.57 | 46.11 | 46.60 | 49.97 |
|
534 |
+
|
535 |
+
TF32
|
536 |
+
|
537 |
+
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
538 |
+
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
539 |
+
| 1600 | 8.58 | 189.20 | 186.39 | 187.71 | 420.28 |
|
540 |
+
| 800 | 8.28 | 97.56 | 97.20 | 97.73 | 101.13 |
|
541 |
+
| 400 | 7.55 | 53.38 | 53.72 | 54.48 | 56.62 |
|
542 |
+
|
543 |
+
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
544 |
+
|
545 |
+
|
546 |
+
|
547 |
+
##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)
|
548 |
+
|
549 |
+
Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 1x V100 16GB GPU.
|
550 |
+
|
551 |
+
FP16
|
552 |
+
|
553 |
+
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
554 |
+
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
555 |
+
| 1600 | 6.42 | 254.54 | 247.97 | 249.29 | 721.15 |
|
556 |
+
| 800 | 6.13 | 132.07 | 131.90 | 132.70 | 140.15 |
|
557 |
+
| 400 | 5.37 | 75.12 | 76.01 | 76.66 | 79.90 |
|
558 |
+
|
559 |
+
FP32
|
560 |
+
|
561 |
+
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
562 |
+
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
563 |
+
| 1600 | 3.39 | 475.86 | 473.82 | 475.64 | 891.18 |
|
564 |
+
| 800 | 3.36 | 239.17 | 240.64 | 241.65 | 243.70 |
|
565 |
+
| 400 | 3.17 | 126.67 | 128.19 | 128.82 | 130.54 |
|
566 |
+
|
567 |
+
|
568 |
+
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
569 |
+
|
570 |
+
|
571 |
+
## Release notes
|
572 |
+
|
573 |
+
### Changelog
|
574 |
+
|
575 |
+
August 2021
|
576 |
+
- Initial release
|
577 |
+
|
578 |
+
### Known issues
|
579 |
+
|
580 |
+
If you encounter `OSError: [Errno 12] Cannot allocate memory` during the Dataloader iterator creation (more precisely during the `fork()`, this is most likely due to the use of the `--precompute_bases` flag. If you cannot add more RAM or Swap to your machine, it is recommended to turn off bases precomputation by removing the `--precompute_bases` flag or using `--precompute_bases false`.
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/__init__.py
ADDED
File without changes
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/data_loading/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .qm9 import QM9DataModule
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/data_loading/data_module.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import torch.distributed as dist
|
25 |
+
from abc import ABC
|
26 |
+
from torch.utils.data import DataLoader, DistributedSampler, Dataset
|
27 |
+
|
28 |
+
from se3_transformer.runtime.utils import get_local_rank
|
29 |
+
|
30 |
+
|
31 |
+
def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:
|
32 |
+
# Classic or distributed dataloader depending on the context
|
33 |
+
sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None
|
34 |
+
return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs)
|
35 |
+
|
36 |
+
|
37 |
+
class DataModule(ABC):
|
38 |
+
""" Abstract DataModule. Children must define self.ds_{train | val | test}. """
|
39 |
+
|
40 |
+
def __init__(self, **dataloader_kwargs):
|
41 |
+
super().__init__()
|
42 |
+
if get_local_rank() == 0:
|
43 |
+
self.prepare_data()
|
44 |
+
|
45 |
+
# Wait until rank zero has prepared the data (download, preprocessing, ...)
|
46 |
+
if dist.is_initialized():
|
47 |
+
dist.barrier(device_ids=[get_local_rank()])
|
48 |
+
|
49 |
+
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
|
50 |
+
self.ds_train, self.ds_val, self.ds_test = None, None, None
|
51 |
+
|
52 |
+
def prepare_data(self):
|
53 |
+
""" Method called only once per node. Put here any downloading or preprocessing """
|
54 |
+
pass
|
55 |
+
|
56 |
+
def train_dataloader(self) -> DataLoader:
|
57 |
+
return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs)
|
58 |
+
|
59 |
+
def val_dataloader(self) -> DataLoader:
|
60 |
+
return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs)
|
61 |
+
|
62 |
+
def test_dataloader(self) -> DataLoader:
|
63 |
+
return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs)
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/data_loading/qm9.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
from typing import Tuple
|
24 |
+
|
25 |
+
import dgl
|
26 |
+
import pathlib
|
27 |
+
import torch
|
28 |
+
from dgl.data import QM9EdgeDataset
|
29 |
+
from dgl import DGLGraph
|
30 |
+
from torch import Tensor
|
31 |
+
from torch.utils.data import random_split, DataLoader, Dataset
|
32 |
+
from tqdm import tqdm
|
33 |
+
|
34 |
+
from se3_transformer.data_loading.data_module import DataModule
|
35 |
+
from se3_transformer.model.basis import get_basis
|
36 |
+
from se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores
|
37 |
+
|
38 |
+
|
39 |
+
def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
|
40 |
+
x = qm9_graph.ndata['pos']
|
41 |
+
src, dst = qm9_graph.edges()
|
42 |
+
rel_pos = x[dst] - x[src]
|
43 |
+
return rel_pos
|
44 |
+
|
45 |
+
|
46 |
+
def _get_split_sizes(full_dataset: Dataset) -> Tuple[int, int, int]:
|
47 |
+
len_full = len(full_dataset)
|
48 |
+
len_train = 100_000
|
49 |
+
len_test = int(0.1 * len_full)
|
50 |
+
len_val = len_full - len_train - len_test
|
51 |
+
return len_train, len_val, len_test
|
52 |
+
|
53 |
+
|
54 |
+
class QM9DataModule(DataModule):
|
55 |
+
"""
|
56 |
+
Datamodule wrapping https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset
|
57 |
+
Training set is 100k molecules. Test set is 10% of the dataset. Validation set is the rest.
|
58 |
+
This includes all the molecules from QM9 except the ones that are uncharacterized.
|
59 |
+
"""
|
60 |
+
|
61 |
+
NODE_FEATURE_DIM = 6
|
62 |
+
EDGE_FEATURE_DIM = 4
|
63 |
+
|
64 |
+
def __init__(self,
|
65 |
+
data_dir: pathlib.Path,
|
66 |
+
task: str = 'homo',
|
67 |
+
batch_size: int = 240,
|
68 |
+
num_workers: int = 8,
|
69 |
+
num_degrees: int = 4,
|
70 |
+
amp: bool = False,
|
71 |
+
precompute_bases: bool = False,
|
72 |
+
**kwargs):
|
73 |
+
self.data_dir = data_dir # This needs to be before __init__ so that prepare_data has access to it
|
74 |
+
super().__init__(batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
|
75 |
+
self.amp = amp
|
76 |
+
self.task = task
|
77 |
+
self.batch_size = batch_size
|
78 |
+
self.num_degrees = num_degrees
|
79 |
+
|
80 |
+
qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
|
81 |
+
if precompute_bases:
|
82 |
+
bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
|
83 |
+
full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
|
84 |
+
num_workers=num_workers, **qm9_kwargs)
|
85 |
+
else:
|
86 |
+
full_dataset = QM9EdgeDataset(**qm9_kwargs)
|
87 |
+
|
88 |
+
self.ds_train, self.ds_val, self.ds_test = random_split(full_dataset, _get_split_sizes(full_dataset),
|
89 |
+
generator=torch.Generator().manual_seed(0))
|
90 |
+
|
91 |
+
train_targets = full_dataset.targets[self.ds_train.indices, full_dataset.label_keys[0]]
|
92 |
+
self.targets_mean = train_targets.mean()
|
93 |
+
self.targets_std = train_targets.std()
|
94 |
+
|
95 |
+
def prepare_data(self):
|
96 |
+
# Download the QM9 preprocessed data
|
97 |
+
QM9EdgeDataset(verbose=True, raw_dir=str(self.data_dir))
|
98 |
+
|
99 |
+
def _collate(self, samples):
|
100 |
+
graphs, y, *bases = map(list, zip(*samples))
|
101 |
+
batched_graph = dgl.batch(graphs)
|
102 |
+
edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
|
103 |
+
batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
|
104 |
+
# get node features
|
105 |
+
node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
|
106 |
+
targets = (torch.cat(y) - self.targets_mean) / self.targets_std
|
107 |
+
|
108 |
+
if bases:
|
109 |
+
# collate bases
|
110 |
+
all_bases = {
|
111 |
+
key: torch.cat([b[key] for b in bases[0]], dim=0)
|
112 |
+
for key in bases[0][0].keys()
|
113 |
+
}
|
114 |
+
|
115 |
+
return batched_graph, node_feats, edge_feats, all_bases, targets
|
116 |
+
else:
|
117 |
+
return batched_graph, node_feats, edge_feats, targets
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def add_argparse_args(parent_parser):
|
121 |
+
parser = parent_parser.add_argument_group("QM9 dataset")
|
122 |
+
parser.add_argument('--task', type=str, default='homo', const='homo', nargs='?',
|
123 |
+
choices=['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
|
124 |
+
'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'],
|
125 |
+
help='Regression task to train on')
|
126 |
+
parser.add_argument('--precompute_bases', type=str2bool, nargs='?', const=True, default=False,
|
127 |
+
help='Precompute bases at the beginning of the script during dataset initialization,'
|
128 |
+
' instead of computing them at the beginning of each forward pass.')
|
129 |
+
return parent_parser
|
130 |
+
|
131 |
+
def __repr__(self):
|
132 |
+
return f'QM9({self.task})'
|
133 |
+
|
134 |
+
|
135 |
+
class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
|
136 |
+
""" Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """
|
137 |
+
|
138 |
+
def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
|
139 |
+
"""
|
140 |
+
:param bases_kwargs: Arguments to feed the bases computation function
|
141 |
+
:param batch_size: Batch size to use when iterating over the dataset for computing bases
|
142 |
+
"""
|
143 |
+
self.bases_kwargs = bases_kwargs
|
144 |
+
self.batch_size = batch_size
|
145 |
+
self.bases = None
|
146 |
+
self.num_workers = num_workers
|
147 |
+
super().__init__(*args, **kwargs)
|
148 |
+
|
149 |
+
def load(self):
|
150 |
+
super().load()
|
151 |
+
# Iterate through the dataset and compute bases (pairwise only)
|
152 |
+
# Potential improvement: use multi-GPU and gather
|
153 |
+
dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
|
154 |
+
collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
|
155 |
+
bases = []
|
156 |
+
for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
|
157 |
+
disable=get_local_rank() != 0):
|
158 |
+
rel_pos = _get_relative_pos(graph)
|
159 |
+
# Compute the bases with the GPU but convert the result to CPU to store in RAM
|
160 |
+
bases.append({k: v.cpu() for k, v in get_basis(rel_pos.cuda(), **self.bases_kwargs).items()})
|
161 |
+
self.bases = bases # Assign at the end so that __getitem__ isn't confused
|
162 |
+
|
163 |
+
def __getitem__(self, idx: int):
|
164 |
+
graph, label = super().__getitem__(idx)
|
165 |
+
|
166 |
+
if self.bases:
|
167 |
+
bases_idx = idx // self.batch_size
|
168 |
+
bases_cumsum_idx = self.ne_cumsum[idx] - self.ne_cumsum[bases_idx * self.batch_size]
|
169 |
+
bases_cumsum_next_idx = self.ne_cumsum[idx + 1] - self.ne_cumsum[bases_idx * self.batch_size]
|
170 |
+
return graph, label, {key: basis[bases_cumsum_idx:bases_cumsum_next_idx] for key, basis in
|
171 |
+
self.bases[bases_idx].items()}
|
172 |
+
else:
|
173 |
+
return graph, label
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .transformer import SE3Transformer, SE3TransformerPooled
|
2 |
+
from .fiber import Fiber
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/basis.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
|
25 |
+
from functools import lru_cache
|
26 |
+
from typing import Dict, List
|
27 |
+
|
28 |
+
import e3nn.o3 as o3
|
29 |
+
import torch
|
30 |
+
import torch.nn.functional as F
|
31 |
+
from torch import Tensor
|
32 |
+
from torch.cuda.nvtx import range as nvtx_range
|
33 |
+
|
34 |
+
from se3_transformer.runtime.utils import degree_to_dim
|
35 |
+
|
36 |
+
|
37 |
+
@lru_cache(maxsize=None)
|
38 |
+
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
|
39 |
+
""" Get the (cached) Q^{d_out,d_in}_J matrices from equation (8) """
|
40 |
+
return o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device).permute(2, 1, 0)
|
41 |
+
|
42 |
+
|
43 |
+
@lru_cache(maxsize=None)
|
44 |
+
def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
|
45 |
+
all_cb = []
|
46 |
+
for d_in in range(max_degree + 1):
|
47 |
+
for d_out in range(max_degree + 1):
|
48 |
+
K_Js = []
|
49 |
+
for J in range(abs(d_in - d_out), d_in + d_out + 1):
|
50 |
+
K_Js.append(get_clebsch_gordon(J, d_in, d_out, device))
|
51 |
+
all_cb.append(K_Js)
|
52 |
+
return all_cb
|
53 |
+
|
54 |
+
|
55 |
+
def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
|
56 |
+
all_degrees = list(range(2 * max_degree + 1))
|
57 |
+
with nvtx_range('spherical harmonics'):
|
58 |
+
sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
|
59 |
+
return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
|
60 |
+
|
61 |
+
|
62 |
+
@torch.jit.script
|
63 |
+
def get_basis_script(max_degree: int,
|
64 |
+
use_pad_trick: bool,
|
65 |
+
spherical_harmonics: List[Tensor],
|
66 |
+
clebsch_gordon: List[List[Tensor]],
|
67 |
+
amp: bool) -> Dict[str, Tensor]:
|
68 |
+
"""
|
69 |
+
Compute pairwise bases matrices for degrees up to max_degree
|
70 |
+
:param max_degree: Maximum input or output degree
|
71 |
+
:param use_pad_trick: Pad some of the odd dimensions for a better use of Tensor Cores
|
72 |
+
:param spherical_harmonics: List of computed spherical harmonics
|
73 |
+
:param clebsch_gordon: List of computed CB-coefficients
|
74 |
+
:param amp: When true, return bases in FP16 precision
|
75 |
+
"""
|
76 |
+
basis = {}
|
77 |
+
idx = 0
|
78 |
+
# Double for loop instead of product() because of JIT script
|
79 |
+
for d_in in range(max_degree + 1):
|
80 |
+
for d_out in range(max_degree + 1):
|
81 |
+
key = f'{d_in},{d_out}'
|
82 |
+
K_Js = []
|
83 |
+
for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
|
84 |
+
Q_J = clebsch_gordon[idx][freq_idx]
|
85 |
+
K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float()))
|
86 |
+
|
87 |
+
basis[key] = torch.stack(K_Js, 2) # Stack on second dim so order is n l f k
|
88 |
+
if amp:
|
89 |
+
basis[key] = basis[key].half()
|
90 |
+
if use_pad_trick:
|
91 |
+
basis[key] = F.pad(basis[key], (0, 1)) # Pad the k dimension, that can be sliced later
|
92 |
+
|
93 |
+
idx += 1
|
94 |
+
|
95 |
+
return basis
|
96 |
+
|
97 |
+
|
98 |
+
@torch.jit.script
|
99 |
+
def update_basis_with_fused(basis: Dict[str, Tensor],
|
100 |
+
max_degree: int,
|
101 |
+
use_pad_trick: bool,
|
102 |
+
fully_fused: bool) -> Dict[str, Tensor]:
|
103 |
+
""" Update the basis dict with partially and optionally fully fused bases """
|
104 |
+
num_edges = basis['0,0'].shape[0]
|
105 |
+
device = basis['0,0'].device
|
106 |
+
dtype = basis['0,0'].dtype
|
107 |
+
sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])
|
108 |
+
|
109 |
+
# Fused per output degree
|
110 |
+
for d_out in range(max_degree + 1):
|
111 |
+
sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
|
112 |
+
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
|
113 |
+
device=device, dtype=dtype)
|
114 |
+
acc_d, acc_f = 0, 0
|
115 |
+
for d_in in range(max_degree + 1):
|
116 |
+
basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
|
117 |
+
:degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
|
118 |
+
|
119 |
+
acc_d += degree_to_dim(d_in)
|
120 |
+
acc_f += degree_to_dim(min(d_out, d_in))
|
121 |
+
|
122 |
+
basis[f'out{d_out}_fused'] = basis_fused
|
123 |
+
|
124 |
+
# Fused per input degree
|
125 |
+
for d_in in range(max_degree + 1):
|
126 |
+
sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
|
127 |
+
basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
|
128 |
+
device=device, dtype=dtype)
|
129 |
+
acc_d, acc_f = 0, 0
|
130 |
+
for d_out in range(max_degree + 1):
|
131 |
+
basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \
|
132 |
+
= basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
|
133 |
+
|
134 |
+
acc_d += degree_to_dim(d_out)
|
135 |
+
acc_f += degree_to_dim(min(d_out, d_in))
|
136 |
+
|
137 |
+
basis[f'in{d_in}_fused'] = basis_fused
|
138 |
+
|
139 |
+
if fully_fused:
|
140 |
+
# Fully fused
|
141 |
+
# Double sum this way because of JIT script
|
142 |
+
sum_freq = sum([
|
143 |
+
sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
|
144 |
+
])
|
145 |
+
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)
|
146 |
+
|
147 |
+
acc_d, acc_f = 0, 0
|
148 |
+
for d_out in range(max_degree + 1):
|
149 |
+
b = basis[f'out{d_out}_fused']
|
150 |
+
basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = b[:, :, :,
|
151 |
+
:degree_to_dim(d_out)]
|
152 |
+
acc_f += b.shape[2]
|
153 |
+
acc_d += degree_to_dim(d_out)
|
154 |
+
|
155 |
+
basis['fully_fused'] = basis_fused
|
156 |
+
|
157 |
+
del basis['0,0'] # We know that the basis for l = k = 0 is filled with a constant
|
158 |
+
return basis
|
159 |
+
|
160 |
+
|
161 |
+
def get_basis(relative_pos: Tensor,
|
162 |
+
max_degree: int = 4,
|
163 |
+
compute_gradients: bool = False,
|
164 |
+
use_pad_trick: bool = False,
|
165 |
+
amp: bool = False) -> Dict[str, Tensor]:
|
166 |
+
with nvtx_range('spherical harmonics'):
|
167 |
+
spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
|
168 |
+
with nvtx_range('CB coefficients'):
|
169 |
+
clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)
|
170 |
+
|
171 |
+
with torch.autograd.set_grad_enabled(compute_gradients):
|
172 |
+
with nvtx_range('bases'):
|
173 |
+
basis = get_basis_script(max_degree=max_degree,
|
174 |
+
use_pad_trick=use_pad_trick,
|
175 |
+
spherical_harmonics=spherical_harmonics,
|
176 |
+
clebsch_gordon=clebsch_gordon,
|
177 |
+
amp=amp)
|
178 |
+
return basis
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/fiber.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
|
25 |
+
from collections import namedtuple
|
26 |
+
from itertools import product
|
27 |
+
from typing import Dict
|
28 |
+
|
29 |
+
import torch
|
30 |
+
from torch import Tensor
|
31 |
+
|
32 |
+
from se3_transformer.runtime.utils import degree_to_dim
|
33 |
+
|
34 |
+
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
|
35 |
+
|
36 |
+
|
37 |
+
class Fiber(dict):
|
38 |
+
"""
|
39 |
+
Describes the structure of some set of features.
|
40 |
+
Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
|
41 |
+
Type-0 features: invariant scalars
|
42 |
+
Type-1 features: equivariant 3D vectors
|
43 |
+
Type-2 features: equivariant symmetric traceless matrices
|
44 |
+
...
|
45 |
+
|
46 |
+
As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
|
47 |
+
The 'multiplicity' or 'number of channels' is the number of features of a given type.
|
48 |
+
This class puts together all the degrees and their multiplicities in order to describe
|
49 |
+
the inputs, outputs or hidden features of SE3 layers.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, structure):
|
53 |
+
if isinstance(structure, dict):
|
54 |
+
structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
|
55 |
+
elif not isinstance(structure[0], FiberEl):
|
56 |
+
structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
|
57 |
+
self.structure = structure
|
58 |
+
super().__init__({d: m for d, m in self.structure})
|
59 |
+
|
60 |
+
@property
|
61 |
+
def degrees(self):
|
62 |
+
return sorted([t.degree for t in self.structure])
|
63 |
+
|
64 |
+
@property
|
65 |
+
def channels(self):
|
66 |
+
return [self[d] for d in self.degrees]
|
67 |
+
|
68 |
+
@property
|
69 |
+
def num_features(self):
|
70 |
+
""" Size of the resulting tensor if all features were concatenated together """
|
71 |
+
return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
|
72 |
+
|
73 |
+
@staticmethod
|
74 |
+
def create(num_degrees: int, num_channels: int):
|
75 |
+
""" Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
|
76 |
+
return Fiber([(degree, num_channels) for degree in range(num_degrees)])
|
77 |
+
|
78 |
+
@staticmethod
|
79 |
+
def from_features(feats: Dict[str, Tensor]):
|
80 |
+
""" Infer the Fiber structure from a feature dict """
|
81 |
+
structure = {}
|
82 |
+
for k, v in feats.items():
|
83 |
+
degree = int(k)
|
84 |
+
assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
|
85 |
+
assert v.shape[-1] == degree_to_dim(degree)
|
86 |
+
structure[degree] = v.shape[-2]
|
87 |
+
return Fiber(structure)
|
88 |
+
|
89 |
+
def __getitem__(self, degree: int):
|
90 |
+
""" fiber[degree] returns the multiplicity for this degree """
|
91 |
+
return dict(self.structure).get(degree, 0)
|
92 |
+
|
93 |
+
def __iter__(self):
|
94 |
+
""" Iterate over namedtuples (degree, channels) """
|
95 |
+
return iter(self.structure)
|
96 |
+
|
97 |
+
def __mul__(self, other):
|
98 |
+
"""
|
99 |
+
If other in an int, multiplies all the multiplicities by other.
|
100 |
+
If other is a fiber, returns the cartesian product.
|
101 |
+
"""
|
102 |
+
if isinstance(other, Fiber):
|
103 |
+
return product(self.structure, other.structure)
|
104 |
+
elif isinstance(other, int):
|
105 |
+
return Fiber({t.degree: t.channels * other for t in self.structure})
|
106 |
+
|
107 |
+
def __add__(self, other):
|
108 |
+
"""
|
109 |
+
If other in an int, add other to all the multiplicities.
|
110 |
+
If other is a fiber, add the multiplicities of the fibers together.
|
111 |
+
"""
|
112 |
+
if isinstance(other, Fiber):
|
113 |
+
return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
|
114 |
+
elif isinstance(other, int):
|
115 |
+
return Fiber({t.degree: t.channels + other for t in self.structure})
|
116 |
+
|
117 |
+
def __repr__(self):
|
118 |
+
return str(self.structure)
|
119 |
+
|
120 |
+
@staticmethod
|
121 |
+
def combine_max(f1, f2):
|
122 |
+
""" Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
|
123 |
+
new_dict = dict(f1.structure)
|
124 |
+
for k, m in f2.structure:
|
125 |
+
new_dict[k] = max(new_dict.get(k, 0), m)
|
126 |
+
|
127 |
+
return Fiber(list(new_dict.items()))
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def combine_selectively(f1, f2):
|
131 |
+
""" Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
|
132 |
+
# only use orders which occur in fiber f1
|
133 |
+
new_dict = dict(f1.structure)
|
134 |
+
for k in f1.degrees:
|
135 |
+
if k in f2.degrees:
|
136 |
+
new_dict[k] += f2[k]
|
137 |
+
return Fiber(list(new_dict.items()))
|
138 |
+
|
139 |
+
def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
|
140 |
+
# dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
|
141 |
+
fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
|
142 |
+
self.degrees]
|
143 |
+
fibers = torch.cat(fibers, -1)
|
144 |
+
return fibers
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .linear import LinearSE3
|
2 |
+
from .norm import NormSE3
|
3 |
+
from .pooling import GPooling
|
4 |
+
from .convolution import ConvSE3
|
5 |
+
from .attention import AttentionBlockSE3
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/attention.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import dgl
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
from dgl import DGLGraph
|
29 |
+
from dgl.ops import edge_softmax
|
30 |
+
from torch import Tensor
|
31 |
+
from typing import Dict, Optional, Union
|
32 |
+
|
33 |
+
from se3_transformer.model.fiber import Fiber
|
34 |
+
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
35 |
+
from se3_transformer.model.layers.linear import LinearSE3
|
36 |
+
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
|
37 |
+
from torch.cuda.nvtx import range as nvtx_range
|
38 |
+
|
39 |
+
|
40 |
+
class AttentionSE3(nn.Module):
|
41 |
+
""" Multi-headed sparse graph self-attention (SE(3)-equivariant) """
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
num_heads: int,
|
46 |
+
key_fiber: Fiber,
|
47 |
+
value_fiber: Fiber
|
48 |
+
):
|
49 |
+
"""
|
50 |
+
:param num_heads: Number of attention heads
|
51 |
+
:param key_fiber: Fiber for the keys (and also for the queries)
|
52 |
+
:param value_fiber: Fiber for the values
|
53 |
+
"""
|
54 |
+
super().__init__()
|
55 |
+
self.num_heads = num_heads
|
56 |
+
self.key_fiber = key_fiber
|
57 |
+
self.value_fiber = value_fiber
|
58 |
+
|
59 |
+
def forward(
|
60 |
+
self,
|
61 |
+
value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
|
62 |
+
key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
|
63 |
+
query: Dict[str, Tensor], # node features
|
64 |
+
graph: DGLGraph
|
65 |
+
):
|
66 |
+
with nvtx_range('AttentionSE3'):
|
67 |
+
with nvtx_range('reshape keys and queries'):
|
68 |
+
if isinstance(key, Tensor):
|
69 |
+
# case where features of all types are fused
|
70 |
+
key = key.reshape(key.shape[0], self.num_heads, -1)
|
71 |
+
# need to reshape queries that way to keep the same layout as keys
|
72 |
+
out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
|
73 |
+
query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
|
74 |
+
else:
|
75 |
+
# features are not fused, need to fuse and reshape them
|
76 |
+
key = self.key_fiber.to_attention_heads(key, self.num_heads)
|
77 |
+
query = self.key_fiber.to_attention_heads(query, self.num_heads)
|
78 |
+
|
79 |
+
with nvtx_range('attention dot product + softmax'):
|
80 |
+
# Compute attention weights (softmax of inner product between key and query)
|
81 |
+
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
|
82 |
+
edge_weights /= np.sqrt(self.key_fiber.num_features)
|
83 |
+
edge_weights = edge_softmax(graph, edge_weights)
|
84 |
+
edge_weights = edge_weights[..., None, None]
|
85 |
+
|
86 |
+
with nvtx_range('weighted sum'):
|
87 |
+
if isinstance(value, Tensor):
|
88 |
+
# features of all types are fused
|
89 |
+
v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
|
90 |
+
weights = edge_weights * v
|
91 |
+
feat_out = dgl.ops.copy_e_sum(graph, weights)
|
92 |
+
feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
|
93 |
+
out = unfuse_features(feat_out, self.value_fiber.degrees)
|
94 |
+
else:
|
95 |
+
out = {}
|
96 |
+
for degree, channels in self.value_fiber:
|
97 |
+
v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
|
98 |
+
degree_to_dim(degree))
|
99 |
+
weights = edge_weights * v
|
100 |
+
res = dgl.ops.copy_e_sum(graph, weights)
|
101 |
+
out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
|
102 |
+
|
103 |
+
return out
|
104 |
+
|
105 |
+
|
106 |
+
class AttentionBlockSE3(nn.Module):
|
107 |
+
""" Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
fiber_in: Fiber,
|
112 |
+
fiber_out: Fiber,
|
113 |
+
fiber_edge: Optional[Fiber] = None,
|
114 |
+
num_heads: int = 4,
|
115 |
+
channels_div: int = 2,
|
116 |
+
use_layer_norm: bool = False,
|
117 |
+
max_degree: bool = 4,
|
118 |
+
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
119 |
+
**kwargs
|
120 |
+
):
|
121 |
+
"""
|
122 |
+
:param fiber_in: Fiber describing the input features
|
123 |
+
:param fiber_out: Fiber describing the output features
|
124 |
+
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
125 |
+
:param num_heads: Number of attention heads
|
126 |
+
:param channels_div: Divide the channels by this integer for computing values
|
127 |
+
:param use_layer_norm: Apply layer normalization between MLP layers
|
128 |
+
:param max_degree: Maximum degree used in the bases computation
|
129 |
+
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
130 |
+
"""
|
131 |
+
super().__init__()
|
132 |
+
if fiber_edge is None:
|
133 |
+
fiber_edge = Fiber({})
|
134 |
+
self.fiber_in = fiber_in
|
135 |
+
# value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
|
136 |
+
value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
|
137 |
+
# key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
|
138 |
+
# (queries are merely projected, hence degrees have to match input)
|
139 |
+
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
|
140 |
+
|
141 |
+
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
|
142 |
+
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
|
143 |
+
allow_fused_output=True)
|
144 |
+
self.to_query = LinearSE3(fiber_in, key_query_fiber)
|
145 |
+
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
|
146 |
+
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
|
147 |
+
|
148 |
+
def forward(
|
149 |
+
self,
|
150 |
+
node_features: Dict[str, Tensor],
|
151 |
+
edge_features: Dict[str, Tensor],
|
152 |
+
graph: DGLGraph,
|
153 |
+
basis: Dict[str, Tensor]
|
154 |
+
):
|
155 |
+
with nvtx_range('AttentionBlockSE3'):
|
156 |
+
with nvtx_range('keys / values'):
|
157 |
+
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
|
158 |
+
key, value = self._get_key_value_from_fused(fused_key_value)
|
159 |
+
|
160 |
+
with nvtx_range('queries'):
|
161 |
+
query = self.to_query(node_features)
|
162 |
+
|
163 |
+
z = self.attention(value, key, query, graph)
|
164 |
+
z_concat = aggregate_residual(node_features, z, 'cat')
|
165 |
+
return self.project(z_concat)
|
166 |
+
|
167 |
+
def _get_key_value_from_fused(self, fused_key_value):
|
168 |
+
# Extract keys and queries features from fused features
|
169 |
+
if isinstance(fused_key_value, Tensor):
|
170 |
+
# Previous layer was a fully fused convolution
|
171 |
+
value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
|
172 |
+
else:
|
173 |
+
key, value = {}, {}
|
174 |
+
for degree, feat in fused_key_value.items():
|
175 |
+
if int(degree) in self.fiber_in.degrees:
|
176 |
+
value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
|
177 |
+
else:
|
178 |
+
value[degree] = feat
|
179 |
+
|
180 |
+
return key, value
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/convolution.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
from enum import Enum
|
25 |
+
from itertools import product
|
26 |
+
from typing import Dict
|
27 |
+
|
28 |
+
import dgl
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
import torch.nn as nn
|
32 |
+
from dgl import DGLGraph
|
33 |
+
from torch import Tensor
|
34 |
+
from torch.cuda.nvtx import range as nvtx_range
|
35 |
+
|
36 |
+
from se3_transformer.model.fiber import Fiber
|
37 |
+
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
|
38 |
+
|
39 |
+
|
40 |
+
class ConvSE3FuseLevel(Enum):
|
41 |
+
"""
|
42 |
+
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
|
43 |
+
If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered.
|
44 |
+
A higher level means faster training, but also more memory usage.
|
45 |
+
If you are tight on memory and want to feed large inputs to the network, choose a low value.
|
46 |
+
If you want to train fast, choose a high value.
|
47 |
+
Recommended value is FULL with AMP.
|
48 |
+
|
49 |
+
Fully fused TFN convolutions requirements:
|
50 |
+
- all input channels are the same
|
51 |
+
- all output channels are the same
|
52 |
+
- input degrees span the range [0, ..., max_degree]
|
53 |
+
- output degrees span the range [0, ..., max_degree]
|
54 |
+
|
55 |
+
Partially fused TFN convolutions requirements:
|
56 |
+
* For fusing by output degree:
|
57 |
+
- all input channels are the same
|
58 |
+
- input degrees span the range [0, ..., max_degree]
|
59 |
+
* For fusing by input degree:
|
60 |
+
- all output channels are the same
|
61 |
+
- output degrees span the range [0, ..., max_degree]
|
62 |
+
|
63 |
+
Original TFN pairwise convolutions: no requirements
|
64 |
+
"""
|
65 |
+
|
66 |
+
FULL = 2
|
67 |
+
PARTIAL = 1
|
68 |
+
NONE = 0
|
69 |
+
|
70 |
+
|
71 |
+
class RadialProfile(nn.Module):
|
72 |
+
"""
|
73 |
+
Radial profile function.
|
74 |
+
Outputs weights used to weigh basis matrices in order to get convolution kernels.
|
75 |
+
In TFN notation: $R^{l,k}$
|
76 |
+
In SE(3)-Transformer notation: $\phi^{l,k}$
|
77 |
+
|
78 |
+
Note:
|
79 |
+
In the original papers, this function only depends on relative node distances ||x||.
|
80 |
+
Here, we allow this function to also take as input additional invariant edge features.
|
81 |
+
This does not break equivariance and adds expressive power to the model.
|
82 |
+
|
83 |
+
Diagram:
|
84 |
+
invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
num_freq: int,
|
90 |
+
channels_in: int,
|
91 |
+
channels_out: int,
|
92 |
+
edge_dim: int = 1,
|
93 |
+
mid_dim: int = 32,
|
94 |
+
use_layer_norm: bool = False
|
95 |
+
):
|
96 |
+
"""
|
97 |
+
:param num_freq: Number of frequencies
|
98 |
+
:param channels_in: Number of input channels
|
99 |
+
:param channels_out: Number of output channels
|
100 |
+
:param edge_dim: Number of invariant edge features (input to the radial function)
|
101 |
+
:param mid_dim: Size of the hidden MLP layers
|
102 |
+
:param use_layer_norm: Apply layer normalization between MLP layers
|
103 |
+
"""
|
104 |
+
super().__init__()
|
105 |
+
modules = [
|
106 |
+
nn.Linear(edge_dim, mid_dim),
|
107 |
+
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
108 |
+
nn.ReLU(),
|
109 |
+
nn.Linear(mid_dim, mid_dim),
|
110 |
+
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
111 |
+
nn.ReLU(),
|
112 |
+
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
|
113 |
+
]
|
114 |
+
|
115 |
+
self.net = nn.Sequential(*[m for m in modules if m is not None])
|
116 |
+
|
117 |
+
def forward(self, features: Tensor) -> Tensor:
|
118 |
+
return self.net(features)
|
119 |
+
|
120 |
+
|
121 |
+
class VersatileConvSE3(nn.Module):
|
122 |
+
"""
|
123 |
+
Building block for TFN convolutions.
|
124 |
+
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(self,
|
128 |
+
freq_sum: int,
|
129 |
+
channels_in: int,
|
130 |
+
channels_out: int,
|
131 |
+
edge_dim: int,
|
132 |
+
use_layer_norm: bool,
|
133 |
+
fuse_level: ConvSE3FuseLevel):
|
134 |
+
super().__init__()
|
135 |
+
self.freq_sum = freq_sum
|
136 |
+
self.channels_out = channels_out
|
137 |
+
self.channels_in = channels_in
|
138 |
+
self.fuse_level = fuse_level
|
139 |
+
self.radial_func = RadialProfile(num_freq=freq_sum,
|
140 |
+
channels_in=channels_in,
|
141 |
+
channels_out=channels_out,
|
142 |
+
edge_dim=edge_dim,
|
143 |
+
use_layer_norm=use_layer_norm)
|
144 |
+
|
145 |
+
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
|
146 |
+
with nvtx_range(f'VersatileConvSE3'):
|
147 |
+
num_edges = features.shape[0]
|
148 |
+
in_dim = features.shape[2]
|
149 |
+
with nvtx_range(f'RadialProfile'):
|
150 |
+
radial_weights = self.radial_func(invariant_edge_feats) \
|
151 |
+
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
|
152 |
+
|
153 |
+
if basis is not None:
|
154 |
+
# This block performs the einsum n i l, n o i f, n l f k -> n o k
|
155 |
+
out_dim = basis.shape[-1]
|
156 |
+
if self.fuse_level != ConvSE3FuseLevel.FULL:
|
157 |
+
out_dim += out_dim % 2 - 1 # Account for padded basis
|
158 |
+
basis_view = basis.view(num_edges, in_dim, -1)
|
159 |
+
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
|
160 |
+
return (radial_weights @ tmp)[:, :, :out_dim]
|
161 |
+
else:
|
162 |
+
# k = l = 0 non-fused case
|
163 |
+
return radial_weights @ features
|
164 |
+
|
165 |
+
|
166 |
+
class ConvSE3(nn.Module):
|
167 |
+
"""
|
168 |
+
SE(3)-equivariant graph convolution (Tensor Field Network convolution).
|
169 |
+
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
|
170 |
+
Features of different degrees interact together to produce output features.
|
171 |
+
|
172 |
+
Note 1:
|
173 |
+
The option is given to not pool the output. This means that the convolution sum over neighbors will not be
|
174 |
+
done, and the returned features will be edge features instead of node features.
|
175 |
+
|
176 |
+
Note 2:
|
177 |
+
Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
|
178 |
+
Input edge features are concatenated with input source node features before the kernel is applied.
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
fiber_in: Fiber,
|
184 |
+
fiber_out: Fiber,
|
185 |
+
fiber_edge: Fiber,
|
186 |
+
pool: bool = True,
|
187 |
+
use_layer_norm: bool = False,
|
188 |
+
self_interaction: bool = False,
|
189 |
+
max_degree: int = 4,
|
190 |
+
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
191 |
+
allow_fused_output: bool = False
|
192 |
+
):
|
193 |
+
"""
|
194 |
+
:param fiber_in: Fiber describing the input features
|
195 |
+
:param fiber_out: Fiber describing the output features
|
196 |
+
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
197 |
+
:param pool: If True, compute final node features by averaging incoming edge features
|
198 |
+
:param use_layer_norm: Apply layer normalization between MLP layers
|
199 |
+
:param self_interaction: Apply self-interaction of nodes
|
200 |
+
:param max_degree: Maximum degree used in the bases computation
|
201 |
+
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
202 |
+
:param allow_fused_output: Allow the module to output a fused representation of features
|
203 |
+
"""
|
204 |
+
super().__init__()
|
205 |
+
self.pool = pool
|
206 |
+
self.fiber_in = fiber_in
|
207 |
+
self.fiber_out = fiber_out
|
208 |
+
self.self_interaction = self_interaction
|
209 |
+
self.max_degree = max_degree
|
210 |
+
self.allow_fused_output = allow_fused_output
|
211 |
+
|
212 |
+
# channels_in: account for the concatenation of edge features
|
213 |
+
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
|
214 |
+
channels_out_set = set([f.channels for f in self.fiber_out])
|
215 |
+
unique_channels_in = (len(channels_in_set) == 1)
|
216 |
+
unique_channels_out = (len(channels_out_set) == 1)
|
217 |
+
degrees_up_to_max = list(range(max_degree + 1))
|
218 |
+
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
|
219 |
+
|
220 |
+
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
|
221 |
+
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
|
222 |
+
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
223 |
+
# Single fused convolution
|
224 |
+
self.used_fuse_level = ConvSE3FuseLevel.FULL
|
225 |
+
|
226 |
+
sum_freq = sum([
|
227 |
+
degree_to_dim(min(d_in, d_out))
|
228 |
+
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
|
229 |
+
])
|
230 |
+
|
231 |
+
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
|
232 |
+
fuse_level=self.used_fuse_level, **common_args)
|
233 |
+
|
234 |
+
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
235 |
+
unique_channels_in and fiber_in.degrees == degrees_up_to_max:
|
236 |
+
# Convolutions fused per output degree
|
237 |
+
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
238 |
+
self.conv_out = nn.ModuleDict()
|
239 |
+
for d_out, c_out in fiber_out:
|
240 |
+
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
|
241 |
+
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
|
242 |
+
fuse_level=self.used_fuse_level, **common_args)
|
243 |
+
|
244 |
+
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
245 |
+
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
246 |
+
# Convolutions fused per input degree
|
247 |
+
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
248 |
+
self.conv_in = nn.ModuleDict()
|
249 |
+
for d_in, c_in in fiber_in:
|
250 |
+
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
|
251 |
+
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
|
252 |
+
fuse_level=ConvSE3FuseLevel.FULL, **common_args)
|
253 |
+
#fuse_level=self.used_fuse_level, **common_args)
|
254 |
+
else:
|
255 |
+
# Use pairwise TFN convolutions
|
256 |
+
self.used_fuse_level = ConvSE3FuseLevel.NONE
|
257 |
+
self.conv = nn.ModuleDict()
|
258 |
+
for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
|
259 |
+
dict_key = f'{degree_in},{degree_out}'
|
260 |
+
channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
|
261 |
+
sum_freq = degree_to_dim(min(degree_in, degree_out))
|
262 |
+
self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
|
263 |
+
fuse_level=self.used_fuse_level, **common_args)
|
264 |
+
|
265 |
+
if self_interaction:
|
266 |
+
self.to_kernel_self = nn.ParameterDict()
|
267 |
+
for degree_out, channels_out in fiber_out:
|
268 |
+
if fiber_in[degree_out]:
|
269 |
+
self.to_kernel_self[str(degree_out)] = nn.Parameter(
|
270 |
+
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
271 |
+
|
272 |
+
def forward(
|
273 |
+
self,
|
274 |
+
node_feats: Dict[str, Tensor],
|
275 |
+
edge_feats: Dict[str, Tensor],
|
276 |
+
graph: DGLGraph,
|
277 |
+
basis: Dict[str, Tensor]
|
278 |
+
):
|
279 |
+
with nvtx_range(f'ConvSE3'):
|
280 |
+
invariant_edge_feats = edge_feats['0'].squeeze(-1)
|
281 |
+
src, dst = graph.edges()
|
282 |
+
out = {}
|
283 |
+
in_features = []
|
284 |
+
|
285 |
+
# Fetch all input features from edge and node features
|
286 |
+
for degree_in in self.fiber_in.degrees:
|
287 |
+
src_node_features = node_feats[str(degree_in)][src]
|
288 |
+
if degree_in > 0 and str(degree_in) in edge_feats:
|
289 |
+
# Handle edge features of any type by concatenating them to node features
|
290 |
+
src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
|
291 |
+
in_features.append(src_node_features)
|
292 |
+
|
293 |
+
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
|
294 |
+
in_features_fused = torch.cat(in_features, dim=-1)
|
295 |
+
out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
|
296 |
+
|
297 |
+
if not self.allow_fused_output or self.self_interaction or self.pool:
|
298 |
+
out = unfuse_features(out, self.fiber_out.degrees)
|
299 |
+
|
300 |
+
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
|
301 |
+
in_features_fused = torch.cat(in_features, dim=-1)
|
302 |
+
for degree_out in self.fiber_out.degrees:
|
303 |
+
out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats,
|
304 |
+
basis[f'out{degree_out}_fused'])
|
305 |
+
|
306 |
+
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
|
307 |
+
out = 0
|
308 |
+
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
309 |
+
out += self.conv_in[str(degree_in)](feature, invariant_edge_feats,
|
310 |
+
basis[f'in{degree_in}_fused'])
|
311 |
+
if not self.allow_fused_output or self.self_interaction or self.pool:
|
312 |
+
out = unfuse_features(out, self.fiber_out.degrees)
|
313 |
+
else:
|
314 |
+
# Fallback to pairwise TFN convolutions
|
315 |
+
for degree_out in self.fiber_out.degrees:
|
316 |
+
out_feature = 0
|
317 |
+
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
318 |
+
dict_key = f'{degree_in},{degree_out}'
|
319 |
+
out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats,
|
320 |
+
basis.get(dict_key, None))
|
321 |
+
out[str(degree_out)] = out_feature
|
322 |
+
|
323 |
+
for degree_out in self.fiber_out.degrees:
|
324 |
+
if self.self_interaction and str(degree_out) in self.to_kernel_self:
|
325 |
+
with nvtx_range(f'self interaction'):
|
326 |
+
dst_features = node_feats[str(degree_out)][dst]
|
327 |
+
kernel_self = self.to_kernel_self[str(degree_out)]
|
328 |
+
out[str(degree_out)] += kernel_self @ dst_features
|
329 |
+
|
330 |
+
if self.pool:
|
331 |
+
with nvtx_range(f'pooling'):
|
332 |
+
if isinstance(out, dict):
|
333 |
+
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
|
334 |
+
else:
|
335 |
+
out = dgl.ops.copy_e_sum(graph, out)
|
336 |
+
return out
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/linear.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
|
25 |
+
from typing import Dict
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
import torch.nn as nn
|
30 |
+
from torch import Tensor
|
31 |
+
|
32 |
+
from se3_transformer.model.fiber import Fiber
|
33 |
+
|
34 |
+
|
35 |
+
class LinearSE3(nn.Module):
|
36 |
+
"""
|
37 |
+
Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
|
38 |
+
Maps a fiber to a fiber with the same degrees (channels may be different).
|
39 |
+
No interaction between degrees, but interaction between channels.
|
40 |
+
|
41 |
+
type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
|
42 |
+
type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
|
43 |
+
:
|
44 |
+
type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
|
48 |
+
super().__init__()
|
49 |
+
self.weights = nn.ParameterDict({
|
50 |
+
str(degree_out): nn.Parameter(
|
51 |
+
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
52 |
+
for degree_out, channels_out in fiber_out
|
53 |
+
})
|
54 |
+
|
55 |
+
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
|
56 |
+
return {
|
57 |
+
degree: self.weights[degree] @ features[degree]
|
58 |
+
for degree, weight in self.weights.items()
|
59 |
+
}
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/norm.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
|
25 |
+
from typing import Dict
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
from torch import Tensor
|
30 |
+
from torch.cuda.nvtx import range as nvtx_range
|
31 |
+
|
32 |
+
from se3_transformer.model.fiber import Fiber
|
33 |
+
|
34 |
+
|
35 |
+
class NormSE3(nn.Module):
|
36 |
+
"""
|
37 |
+
Norm-based SE(3)-equivariant nonlinearity.
|
38 |
+
|
39 |
+
┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
|
40 |
+
feature_in ──┤ * ──> feature_out
|
41 |
+
└──> feature_phase ────────────────────────────┘
|
42 |
+
"""
|
43 |
+
|
44 |
+
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
|
45 |
+
|
46 |
+
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
|
47 |
+
super().__init__()
|
48 |
+
self.fiber = fiber
|
49 |
+
self.nonlinearity = nonlinearity
|
50 |
+
|
51 |
+
if len(set(fiber.channels)) == 1:
|
52 |
+
# Fuse all the layer normalizations into a group normalization
|
53 |
+
self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
|
54 |
+
else:
|
55 |
+
# Use multiple layer normalizations
|
56 |
+
self.layer_norms = nn.ModuleDict({
|
57 |
+
str(degree): nn.LayerNorm(channels)
|
58 |
+
for degree, channels in fiber
|
59 |
+
})
|
60 |
+
|
61 |
+
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
|
62 |
+
with nvtx_range('NormSE3'):
|
63 |
+
output = {}
|
64 |
+
if hasattr(self, 'group_norm'):
|
65 |
+
# Compute per-degree norms of features
|
66 |
+
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
|
67 |
+
for d in self.fiber.degrees]
|
68 |
+
fused_norms = torch.cat(norms, dim=-2)
|
69 |
+
|
70 |
+
# Transform the norms only
|
71 |
+
new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
|
72 |
+
new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
|
73 |
+
|
74 |
+
# Scale features to the new norms
|
75 |
+
for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
|
76 |
+
output[str(d)] = features[str(d)] / norm * new_norm
|
77 |
+
else:
|
78 |
+
for degree, feat in features.items():
|
79 |
+
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
|
80 |
+
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
|
81 |
+
output[degree] = new_norm * feat / norm
|
82 |
+
|
83 |
+
return output
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/layers/pooling.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
from typing import Dict, Literal
|
25 |
+
|
26 |
+
import torch.nn as nn
|
27 |
+
from dgl import DGLGraph
|
28 |
+
from dgl.nn.pytorch import AvgPooling, MaxPooling
|
29 |
+
from torch import Tensor
|
30 |
+
|
31 |
+
|
32 |
+
class GPooling(nn.Module):
|
33 |
+
"""
|
34 |
+
Graph max/average pooling on a given feature type.
|
35 |
+
The average can be taken for any feature type, and equivariance will be maintained.
|
36 |
+
The maximum can only be taken for invariant features (type 0).
|
37 |
+
If you want max-pooling for type > 0 features, look into Vector Neurons.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
|
41 |
+
"""
|
42 |
+
:param feat_type: Feature type to pool
|
43 |
+
:param pool: Type of pooling: max or avg
|
44 |
+
"""
|
45 |
+
super().__init__()
|
46 |
+
assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
|
47 |
+
assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
|
48 |
+
self.feat_type = feat_type
|
49 |
+
self.pool = MaxPooling() if pool == 'max' else AvgPooling()
|
50 |
+
|
51 |
+
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
|
52 |
+
pooled = self.pool(graph, features[str(self.feat_type)])
|
53 |
+
return pooled.squeeze(dim=-1)
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/model/transformer.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import logging
|
25 |
+
from typing import Optional, Literal, Dict
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.nn as nn
|
29 |
+
from dgl import DGLGraph
|
30 |
+
from torch import Tensor
|
31 |
+
|
32 |
+
from se3_transformer.model.basis import get_basis, update_basis_with_fused
|
33 |
+
from se3_transformer.model.layers.attention import AttentionBlockSE3
|
34 |
+
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
35 |
+
from se3_transformer.model.layers.norm import NormSE3
|
36 |
+
from se3_transformer.model.layers.pooling import GPooling
|
37 |
+
from se3_transformer.runtime.utils import str2bool
|
38 |
+
from se3_transformer.model.fiber import Fiber
|
39 |
+
|
40 |
+
|
41 |
+
class Sequential(nn.Sequential):
|
42 |
+
""" Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
|
43 |
+
|
44 |
+
def forward(self, input, *args, **kwargs):
|
45 |
+
for module in self:
|
46 |
+
input = module(input, *args, **kwargs)
|
47 |
+
return input
|
48 |
+
|
49 |
+
|
50 |
+
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
|
51 |
+
""" Add relative positions to existing edge features """
|
52 |
+
edge_features = edge_features.copy() if edge_features else {}
|
53 |
+
r = relative_pos.norm(dim=-1, keepdim=True)
|
54 |
+
if '0' in edge_features:
|
55 |
+
edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
|
56 |
+
else:
|
57 |
+
edge_features['0'] = r[..., None]
|
58 |
+
|
59 |
+
return edge_features
|
60 |
+
|
61 |
+
|
62 |
+
class SE3Transformer(nn.Module):
|
63 |
+
def __init__(self,
|
64 |
+
num_layers: int,
|
65 |
+
fiber_in: Fiber,
|
66 |
+
fiber_hidden: Fiber,
|
67 |
+
fiber_out: Fiber,
|
68 |
+
num_heads: int,
|
69 |
+
channels_div: int,
|
70 |
+
fiber_edge: Fiber = Fiber({}),
|
71 |
+
return_type: Optional[int] = None,
|
72 |
+
pooling: Optional[Literal['avg', 'max']] = None,
|
73 |
+
norm: bool = True,
|
74 |
+
use_layer_norm: bool = True,
|
75 |
+
tensor_cores: bool = False,
|
76 |
+
low_memory: bool = False,
|
77 |
+
**kwargs):
|
78 |
+
"""
|
79 |
+
:param num_layers: Number of attention layers
|
80 |
+
:param fiber_in: Input fiber description
|
81 |
+
:param fiber_hidden: Hidden fiber description
|
82 |
+
:param fiber_out: Output fiber description
|
83 |
+
:param fiber_edge: Input edge fiber description
|
84 |
+
:param num_heads: Number of attention heads
|
85 |
+
:param channels_div: Channels division before feeding to attention layer
|
86 |
+
:param return_type: Return only features of this type
|
87 |
+
:param pooling: 'avg' or 'max' graph pooling before MLP layers
|
88 |
+
:param norm: Apply a normalization layer after each attention block
|
89 |
+
:param use_layer_norm: Apply layer normalization between MLP layers
|
90 |
+
:param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
|
91 |
+
:param low_memory: If True, will use slower ops that use less memory
|
92 |
+
"""
|
93 |
+
super().__init__()
|
94 |
+
self.num_layers = num_layers
|
95 |
+
self.fiber_edge = fiber_edge
|
96 |
+
self.num_heads = num_heads
|
97 |
+
self.channels_div = channels_div
|
98 |
+
self.return_type = return_type
|
99 |
+
self.pooling = pooling
|
100 |
+
self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
|
101 |
+
self.tensor_cores = tensor_cores
|
102 |
+
self.low_memory = low_memory
|
103 |
+
|
104 |
+
if low_memory and not tensor_cores:
|
105 |
+
logging.warning('Low memory mode will have no effect with no Tensor Cores')
|
106 |
+
|
107 |
+
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
|
108 |
+
fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
|
109 |
+
|
110 |
+
graph_modules = []
|
111 |
+
for i in range(num_layers):
|
112 |
+
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
|
113 |
+
fiber_out=fiber_hidden,
|
114 |
+
fiber_edge=fiber_edge,
|
115 |
+
num_heads=num_heads,
|
116 |
+
channels_div=channels_div,
|
117 |
+
use_layer_norm=use_layer_norm,
|
118 |
+
max_degree=self.max_degree,
|
119 |
+
fuse_level=fuse_level))
|
120 |
+
if norm:
|
121 |
+
graph_modules.append(NormSE3(fiber_hidden))
|
122 |
+
fiber_in = fiber_hidden
|
123 |
+
|
124 |
+
graph_modules.append(ConvSE3(fiber_in=fiber_in,
|
125 |
+
fiber_out=fiber_out,
|
126 |
+
fiber_edge=fiber_edge,
|
127 |
+
self_interaction=True,
|
128 |
+
use_layer_norm=use_layer_norm,
|
129 |
+
max_degree=self.max_degree))
|
130 |
+
self.graph_modules = Sequential(*graph_modules)
|
131 |
+
|
132 |
+
if pooling is not None:
|
133 |
+
assert return_type is not None, 'return_type must be specified when pooling'
|
134 |
+
self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
|
135 |
+
|
136 |
+
def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
|
137 |
+
edge_feats: Optional[Dict[str, Tensor]] = None,
|
138 |
+
basis: Optional[Dict[str, Tensor]] = None):
|
139 |
+
# Compute bases in case they weren't precomputed as part of the data loading
|
140 |
+
basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
|
141 |
+
use_pad_trick=self.tensor_cores and not self.low_memory,
|
142 |
+
amp=torch.is_autocast_enabled())
|
143 |
+
|
144 |
+
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
|
145 |
+
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
|
146 |
+
fully_fused=self.tensor_cores and not self.low_memory)
|
147 |
+
|
148 |
+
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
|
149 |
+
|
150 |
+
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
|
151 |
+
|
152 |
+
if self.pooling is not None:
|
153 |
+
return self.pooling_module(node_feats, graph=graph)
|
154 |
+
|
155 |
+
if self.return_type is not None:
|
156 |
+
return node_feats[str(self.return_type)]
|
157 |
+
|
158 |
+
return node_feats
|
159 |
+
|
160 |
+
@staticmethod
|
161 |
+
def add_argparse_args(parser):
|
162 |
+
parser.add_argument('--num_layers', type=int, default=7,
|
163 |
+
help='Number of stacked Transformer layers')
|
164 |
+
parser.add_argument('--num_heads', type=int, default=8,
|
165 |
+
help='Number of heads in self-attention')
|
166 |
+
parser.add_argument('--channels_div', type=int, default=2,
|
167 |
+
help='Channels division before feeding to attention layer')
|
168 |
+
parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
|
169 |
+
help='Type of graph pooling')
|
170 |
+
parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
|
171 |
+
help='Apply a normalization layer after each attention block')
|
172 |
+
parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
|
173 |
+
help='Apply layer normalization between MLP layers')
|
174 |
+
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
|
175 |
+
help='If true, will use fused ops that are slower but that use less memory '
|
176 |
+
'(expect 25 percent less memory). '
|
177 |
+
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
|
178 |
+
|
179 |
+
return parser
|
180 |
+
|
181 |
+
|
182 |
+
class SE3TransformerPooled(nn.Module):
|
183 |
+
def __init__(self,
|
184 |
+
fiber_in: Fiber,
|
185 |
+
fiber_out: Fiber,
|
186 |
+
fiber_edge: Fiber,
|
187 |
+
num_degrees: int,
|
188 |
+
num_channels: int,
|
189 |
+
output_dim: int,
|
190 |
+
**kwargs):
|
191 |
+
super().__init__()
|
192 |
+
kwargs['pooling'] = kwargs['pooling'] or 'max'
|
193 |
+
self.transformer = SE3Transformer(
|
194 |
+
fiber_in=fiber_in,
|
195 |
+
fiber_hidden=Fiber.create(num_degrees, num_channels),
|
196 |
+
fiber_out=fiber_out,
|
197 |
+
fiber_edge=fiber_edge,
|
198 |
+
return_type=0,
|
199 |
+
**kwargs
|
200 |
+
)
|
201 |
+
|
202 |
+
n_out_features = fiber_out.num_features
|
203 |
+
self.mlp = nn.Sequential(
|
204 |
+
nn.Linear(n_out_features, n_out_features),
|
205 |
+
nn.ReLU(),
|
206 |
+
nn.Linear(n_out_features, output_dim)
|
207 |
+
)
|
208 |
+
|
209 |
+
def forward(self, graph, node_feats, edge_feats, basis=None):
|
210 |
+
feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
|
211 |
+
y = self.mlp(feats).squeeze(-1)
|
212 |
+
return y
|
213 |
+
|
214 |
+
@staticmethod
|
215 |
+
def add_argparse_args(parent_parser):
|
216 |
+
parser = parent_parser.add_argument_group("Model architecture")
|
217 |
+
SE3Transformer.add_argparse_args(parser)
|
218 |
+
parser.add_argument('--num_degrees',
|
219 |
+
help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
|
220 |
+
type=int, default=4)
|
221 |
+
parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
|
222 |
+
return parent_parser
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/__init__.py
ADDED
File without changes
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/arguments.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import argparse
|
25 |
+
import pathlib
|
26 |
+
|
27 |
+
from se3_transformer.data_loading import QM9DataModule
|
28 |
+
from se3_transformer.model import SE3TransformerPooled
|
29 |
+
from se3_transformer.runtime.utils import str2bool
|
30 |
+
|
31 |
+
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
|
32 |
+
|
33 |
+
paths = PARSER.add_argument_group('Paths')
|
34 |
+
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
|
35 |
+
help='Directory where the data is located or should be downloaded')
|
36 |
+
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
|
37 |
+
help='Directory where the results logs should be saved')
|
38 |
+
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
|
39 |
+
help='Name for the resulting DLLogger JSON file')
|
40 |
+
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
|
41 |
+
help='File where the checkpoint should be saved')
|
42 |
+
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
|
43 |
+
help='File of the checkpoint to be loaded')
|
44 |
+
|
45 |
+
optimizer = PARSER.add_argument_group('Optimizer')
|
46 |
+
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
|
47 |
+
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
|
48 |
+
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
|
49 |
+
optimizer.add_argument('--momentum', type=float, default=0.9)
|
50 |
+
optimizer.add_argument('--weight_decay', type=float, default=0.1)
|
51 |
+
|
52 |
+
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
|
53 |
+
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
|
54 |
+
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
|
55 |
+
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
|
56 |
+
|
57 |
+
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
|
58 |
+
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
|
59 |
+
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
|
60 |
+
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
|
61 |
+
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
|
62 |
+
help='Do an evaluation round every N epochs')
|
63 |
+
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
|
64 |
+
help='Minimize stdout output')
|
65 |
+
|
66 |
+
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
|
67 |
+
help='Benchmark mode')
|
68 |
+
|
69 |
+
QM9DataModule.add_argparse_args(PARSER)
|
70 |
+
SE3TransformerPooled.add_argparse_args(PARSER)
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/callbacks.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import logging
|
25 |
+
import time
|
26 |
+
from abc import ABC, abstractmethod
|
27 |
+
from typing import Optional
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
|
32 |
+
from se3_transformer.runtime.loggers import Logger
|
33 |
+
from se3_transformer.runtime.metrics import MeanAbsoluteError
|
34 |
+
|
35 |
+
|
36 |
+
class BaseCallback(ABC):
|
37 |
+
def on_fit_start(self, optimizer, args):
|
38 |
+
pass
|
39 |
+
|
40 |
+
def on_fit_end(self):
|
41 |
+
pass
|
42 |
+
|
43 |
+
def on_epoch_end(self):
|
44 |
+
pass
|
45 |
+
|
46 |
+
def on_batch_start(self):
|
47 |
+
pass
|
48 |
+
|
49 |
+
def on_validation_step(self, input, target, pred):
|
50 |
+
pass
|
51 |
+
|
52 |
+
def on_validation_end(self, epoch=None):
|
53 |
+
pass
|
54 |
+
|
55 |
+
def on_checkpoint_load(self, checkpoint):
|
56 |
+
pass
|
57 |
+
|
58 |
+
def on_checkpoint_save(self, checkpoint):
|
59 |
+
pass
|
60 |
+
|
61 |
+
|
62 |
+
class LRSchedulerCallback(BaseCallback):
|
63 |
+
def __init__(self, logger: Optional[Logger] = None):
|
64 |
+
self.logger = logger
|
65 |
+
self.scheduler = None
|
66 |
+
|
67 |
+
@abstractmethod
|
68 |
+
def get_scheduler(self, optimizer, args):
|
69 |
+
pass
|
70 |
+
|
71 |
+
def on_fit_start(self, optimizer, args):
|
72 |
+
self.scheduler = self.get_scheduler(optimizer, args)
|
73 |
+
|
74 |
+
def on_checkpoint_load(self, checkpoint):
|
75 |
+
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
76 |
+
|
77 |
+
def on_checkpoint_save(self, checkpoint):
|
78 |
+
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
|
79 |
+
|
80 |
+
def on_epoch_end(self):
|
81 |
+
if self.logger is not None:
|
82 |
+
self.logger.log_metrics({'learning rate': self.scheduler.get_last_lr()[0]}, step=self.scheduler.last_epoch)
|
83 |
+
self.scheduler.step()
|
84 |
+
|
85 |
+
|
86 |
+
class QM9MetricCallback(BaseCallback):
|
87 |
+
""" Logs the rescaled mean absolute error for QM9 regression tasks """
|
88 |
+
|
89 |
+
def __init__(self, logger, targets_std, prefix=''):
|
90 |
+
self.mae = MeanAbsoluteError()
|
91 |
+
self.logger = logger
|
92 |
+
self.targets_std = targets_std
|
93 |
+
self.prefix = prefix
|
94 |
+
self.best_mae = float('inf')
|
95 |
+
|
96 |
+
def on_validation_step(self, input, target, pred):
|
97 |
+
self.mae(pred.detach(), target.detach())
|
98 |
+
|
99 |
+
def on_validation_end(self, epoch=None):
|
100 |
+
mae = self.mae.compute() * self.targets_std
|
101 |
+
logging.info(f'{self.prefix} MAE: {mae}')
|
102 |
+
self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
|
103 |
+
self.best_mae = min(self.best_mae, mae)
|
104 |
+
|
105 |
+
def on_fit_end(self):
|
106 |
+
if self.best_mae != float('inf'):
|
107 |
+
self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})
|
108 |
+
|
109 |
+
|
110 |
+
class QM9LRSchedulerCallback(LRSchedulerCallback):
|
111 |
+
def __init__(self, logger, epochs):
|
112 |
+
super().__init__(logger)
|
113 |
+
self.epochs = epochs
|
114 |
+
|
115 |
+
def get_scheduler(self, optimizer, args):
|
116 |
+
min_lr = args.min_learning_rate if args.min_learning_rate else args.learning_rate / 10.0
|
117 |
+
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)
|
118 |
+
|
119 |
+
|
120 |
+
class PerformanceCallback(BaseCallback):
|
121 |
+
def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
|
122 |
+
self.batch_size = batch_size
|
123 |
+
self.warmup_epochs = warmup_epochs
|
124 |
+
self.epoch = 0
|
125 |
+
self.timestamps = []
|
126 |
+
self.mode = mode
|
127 |
+
self.logger = logger
|
128 |
+
|
129 |
+
def on_batch_start(self):
|
130 |
+
if self.epoch >= self.warmup_epochs:
|
131 |
+
self.timestamps.append(time.time() * 1000.0)
|
132 |
+
|
133 |
+
def _log_perf(self):
|
134 |
+
stats = self.process_performance_stats()
|
135 |
+
for k, v in stats.items():
|
136 |
+
logging.info(f'performance {k}: {v}')
|
137 |
+
|
138 |
+
self.logger.log_metrics(stats)
|
139 |
+
|
140 |
+
def on_epoch_end(self):
|
141 |
+
self.epoch += 1
|
142 |
+
|
143 |
+
def on_fit_end(self):
|
144 |
+
if self.epoch > self.warmup_epochs:
|
145 |
+
self._log_perf()
|
146 |
+
self.timestamps = []
|
147 |
+
|
148 |
+
def process_performance_stats(self):
|
149 |
+
timestamps = np.asarray(self.timestamps)
|
150 |
+
deltas = np.diff(timestamps)
|
151 |
+
throughput = (self.batch_size / deltas).mean()
|
152 |
+
stats = {
|
153 |
+
f"throughput_{self.mode}": throughput,
|
154 |
+
f"latency_{self.mode}_mean": deltas.mean(),
|
155 |
+
f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
|
156 |
+
}
|
157 |
+
for level in [90, 95, 99]:
|
158 |
+
stats.update({f"latency_{self.mode}_{level}": np.percentile(deltas, level)})
|
159 |
+
|
160 |
+
return stats
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/gpu_affinity.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import collections
|
25 |
+
import itertools
|
26 |
+
import math
|
27 |
+
import os
|
28 |
+
import pathlib
|
29 |
+
import re
|
30 |
+
|
31 |
+
import pynvml
|
32 |
+
|
33 |
+
|
34 |
+
class Device:
|
35 |
+
# assumes nvml returns list of 64 bit ints
|
36 |
+
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
|
37 |
+
|
38 |
+
def __init__(self, device_idx):
|
39 |
+
super().__init__()
|
40 |
+
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
|
41 |
+
|
42 |
+
def get_name(self):
|
43 |
+
return pynvml.nvmlDeviceGetName(self.handle)
|
44 |
+
|
45 |
+
def get_uuid(self):
|
46 |
+
return pynvml.nvmlDeviceGetUUID(self.handle)
|
47 |
+
|
48 |
+
def get_cpu_affinity(self):
|
49 |
+
affinity_string = ""
|
50 |
+
for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
|
51 |
+
# assume nvml returns list of 64 bit ints
|
52 |
+
affinity_string = "{:064b}".format(j) + affinity_string
|
53 |
+
|
54 |
+
affinity_list = [int(x) for x in affinity_string]
|
55 |
+
affinity_list.reverse() # so core 0 is in 0th element of list
|
56 |
+
|
57 |
+
ret = [i for i, e in enumerate(affinity_list) if e != 0]
|
58 |
+
return ret
|
59 |
+
|
60 |
+
|
61 |
+
def get_thread_siblings_list():
|
62 |
+
"""
|
63 |
+
Returns a list of 2-element integer tuples representing pairs of
|
64 |
+
hyperthreading cores.
|
65 |
+
"""
|
66 |
+
path = "/sys/devices/system/cpu/cpu*/topology/thread_siblings_list"
|
67 |
+
thread_siblings_list = []
|
68 |
+
pattern = re.compile(r"(\d+)\D(\d+)")
|
69 |
+
for fname in pathlib.Path(path[0]).glob(path[1:]):
|
70 |
+
with open(fname) as f:
|
71 |
+
content = f.read().strip()
|
72 |
+
res = pattern.findall(content)
|
73 |
+
if res:
|
74 |
+
pair = tuple(map(int, res[0]))
|
75 |
+
thread_siblings_list.append(pair)
|
76 |
+
return thread_siblings_list
|
77 |
+
|
78 |
+
|
79 |
+
def check_socket_affinities(socket_affinities):
|
80 |
+
# sets of cores should be either identical or disjoint
|
81 |
+
for i, j in itertools.product(socket_affinities, socket_affinities):
|
82 |
+
if not set(i) == set(j) and not set(i).isdisjoint(set(j)):
|
83 |
+
raise RuntimeError(f"Sets of cores should be either identical or disjoint, " f"but got {i} and {j}.")
|
84 |
+
|
85 |
+
|
86 |
+
def get_socket_affinities(nproc_per_node, exclude_unavailable_cores=True):
|
87 |
+
devices = [Device(i) for i in range(nproc_per_node)]
|
88 |
+
socket_affinities = [dev.get_cpu_affinity() for dev in devices]
|
89 |
+
|
90 |
+
if exclude_unavailable_cores:
|
91 |
+
available_cores = os.sched_getaffinity(0)
|
92 |
+
socket_affinities = [list(set(affinity) & available_cores) for affinity in socket_affinities]
|
93 |
+
|
94 |
+
check_socket_affinities(socket_affinities)
|
95 |
+
|
96 |
+
return socket_affinities
|
97 |
+
|
98 |
+
|
99 |
+
def set_socket_affinity(gpu_id):
|
100 |
+
"""
|
101 |
+
The process is assigned with all available logical CPU cores from the CPU
|
102 |
+
socket connected to the GPU with a given id.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
gpu_id: index of a GPU
|
106 |
+
"""
|
107 |
+
dev = Device(gpu_id)
|
108 |
+
affinity = dev.get_cpu_affinity()
|
109 |
+
os.sched_setaffinity(0, affinity)
|
110 |
+
|
111 |
+
|
112 |
+
def set_single_affinity(gpu_id):
|
113 |
+
"""
|
114 |
+
The process is assigned with the first available logical CPU core from the
|
115 |
+
list of all CPU cores from the CPU socket connected to the GPU with a given
|
116 |
+
id.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
gpu_id: index of a GPU
|
120 |
+
"""
|
121 |
+
dev = Device(gpu_id)
|
122 |
+
affinity = dev.get_cpu_affinity()
|
123 |
+
|
124 |
+
# exclude unavailable cores
|
125 |
+
available_cores = os.sched_getaffinity(0)
|
126 |
+
affinity = list(set(affinity) & available_cores)
|
127 |
+
os.sched_setaffinity(0, affinity[:1])
|
128 |
+
|
129 |
+
|
130 |
+
def set_single_unique_affinity(gpu_id, nproc_per_node):
|
131 |
+
"""
|
132 |
+
The process is assigned with a single unique available physical CPU core
|
133 |
+
from the list of all CPU cores from the CPU socket connected to the GPU with
|
134 |
+
a given id.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
gpu_id: index of a GPU
|
138 |
+
"""
|
139 |
+
socket_affinities = get_socket_affinities(nproc_per_node)
|
140 |
+
|
141 |
+
siblings_list = get_thread_siblings_list()
|
142 |
+
siblings_dict = dict(siblings_list)
|
143 |
+
|
144 |
+
# remove siblings
|
145 |
+
for idx, socket_affinity in enumerate(socket_affinities):
|
146 |
+
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
|
147 |
+
|
148 |
+
affinities = []
|
149 |
+
assigned = []
|
150 |
+
|
151 |
+
for socket_affinity in socket_affinities:
|
152 |
+
for core in socket_affinity:
|
153 |
+
if core not in assigned:
|
154 |
+
affinities.append([core])
|
155 |
+
assigned.append(core)
|
156 |
+
break
|
157 |
+
os.sched_setaffinity(0, affinities[gpu_id])
|
158 |
+
|
159 |
+
|
160 |
+
def set_socket_unique_affinity(gpu_id, nproc_per_node, mode, balanced=True):
|
161 |
+
"""
|
162 |
+
The process is assigned with an unique subset of available physical CPU
|
163 |
+
cores from the CPU socket connected to a GPU with a given id.
|
164 |
+
Assignment automatically includes hyperthreading siblings (if siblings are
|
165 |
+
available).
|
166 |
+
|
167 |
+
Args:
|
168 |
+
gpu_id: index of a GPU
|
169 |
+
nproc_per_node: total number of processes per node
|
170 |
+
mode: mode
|
171 |
+
balanced: assign an equal number of physical cores to each process
|
172 |
+
"""
|
173 |
+
socket_affinities = get_socket_affinities(nproc_per_node)
|
174 |
+
|
175 |
+
siblings_list = get_thread_siblings_list()
|
176 |
+
siblings_dict = dict(siblings_list)
|
177 |
+
|
178 |
+
# remove hyperthreading siblings
|
179 |
+
for idx, socket_affinity in enumerate(socket_affinities):
|
180 |
+
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
|
181 |
+
|
182 |
+
socket_affinities_to_device_ids = collections.defaultdict(list)
|
183 |
+
|
184 |
+
for idx, socket_affinity in enumerate(socket_affinities):
|
185 |
+
socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
|
186 |
+
|
187 |
+
# compute minimal number of physical cores per GPU across all GPUs and
|
188 |
+
# sockets, code assigns this number of cores per GPU if balanced == True
|
189 |
+
min_physical_cores_per_gpu = min(
|
190 |
+
[len(cores) // len(gpus) for cores, gpus in socket_affinities_to_device_ids.items()]
|
191 |
+
)
|
192 |
+
|
193 |
+
for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
|
194 |
+
devices_per_group = len(device_ids)
|
195 |
+
if balanced:
|
196 |
+
cores_per_device = min_physical_cores_per_gpu
|
197 |
+
socket_affinity = socket_affinity[: devices_per_group * min_physical_cores_per_gpu]
|
198 |
+
else:
|
199 |
+
cores_per_device = len(socket_affinity) // devices_per_group
|
200 |
+
|
201 |
+
for group_id, device_id in enumerate(device_ids):
|
202 |
+
if device_id == gpu_id:
|
203 |
+
|
204 |
+
# In theory there should be no difference in performance between
|
205 |
+
# 'interleaved' and 'continuous' pattern on Intel-based DGX-1,
|
206 |
+
# but 'continuous' should be better for DGX A100 because on AMD
|
207 |
+
# Rome 4 consecutive cores are sharing L3 cache.
|
208 |
+
# TODO: code doesn't attempt to automatically detect layout of
|
209 |
+
# L3 cache, also external environment may already exclude some
|
210 |
+
# cores, this code makes no attempt to detect it and to align
|
211 |
+
# mapping to multiples of 4.
|
212 |
+
|
213 |
+
if mode == "interleaved":
|
214 |
+
affinity = list(socket_affinity[group_id::devices_per_group])
|
215 |
+
elif mode == "continuous":
|
216 |
+
affinity = list(socket_affinity[group_id * cores_per_device: (group_id + 1) * cores_per_device])
|
217 |
+
else:
|
218 |
+
raise RuntimeError("Unknown set_socket_unique_affinity mode")
|
219 |
+
|
220 |
+
# unconditionally reintroduce hyperthreading siblings, this step
|
221 |
+
# may result in a different numbers of logical cores assigned to
|
222 |
+
# each GPU even if balanced == True (if hyperthreading siblings
|
223 |
+
# aren't available for a subset of cores due to some external
|
224 |
+
# constraints, siblings are re-added unconditionally, in the
|
225 |
+
# worst case unavailable logical core will be ignored by
|
226 |
+
# os.sched_setaffinity().
|
227 |
+
affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
|
228 |
+
os.sched_setaffinity(0, affinity)
|
229 |
+
|
230 |
+
|
231 |
+
def set_affinity(gpu_id, nproc_per_node, mode="socket_unique_continuous", balanced=True):
|
232 |
+
"""
|
233 |
+
The process is assigned with a proper CPU affinity which matches hardware
|
234 |
+
architecture on a given platform. Usually it improves and stabilizes
|
235 |
+
performance of deep learning training workloads.
|
236 |
+
|
237 |
+
This function assumes that the workload is running in multi-process
|
238 |
+
single-device mode (there are multiple training processes and each process
|
239 |
+
is running on a single GPU), which is typical for multi-GPU training
|
240 |
+
workloads using `torch.nn.parallel.DistributedDataParallel`.
|
241 |
+
|
242 |
+
Available affinity modes:
|
243 |
+
* 'socket' - the process is assigned with all available logical CPU cores
|
244 |
+
from the CPU socket connected to the GPU with a given id.
|
245 |
+
* 'single' - the process is assigned with the first available logical CPU
|
246 |
+
core from the list of all CPU cores from the CPU socket connected to the GPU
|
247 |
+
with a given id (multiple GPUs could be assigned with the same CPU core).
|
248 |
+
* 'single_unique' - the process is assigned with a single unique available
|
249 |
+
physical CPU core from the list of all CPU cores from the CPU socket
|
250 |
+
connected to the GPU with a given id.
|
251 |
+
* 'socket_unique_interleaved' - the process is assigned with an unique
|
252 |
+
subset of available physical CPU cores from the CPU socket connected to a
|
253 |
+
GPU with a given id, hyperthreading siblings are included automatically,
|
254 |
+
cores are assigned with interleaved indexing pattern
|
255 |
+
* 'socket_unique_continuous' - (the default) the process is assigned with an
|
256 |
+
unique subset of available physical CPU cores from the CPU socket connected
|
257 |
+
to a GPU with a given id, hyperthreading siblings are included
|
258 |
+
automatically, cores are assigned with continuous indexing pattern
|
259 |
+
|
260 |
+
'socket_unique_continuous' is the recommended mode for deep learning
|
261 |
+
training workloads on NVIDIA DGX machines.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
gpu_id: integer index of a GPU
|
265 |
+
nproc_per_node: number of processes per node
|
266 |
+
mode: affinity mode
|
267 |
+
balanced: assign an equal number of physical cores to each process,
|
268 |
+
affects only 'socket_unique_interleaved' and
|
269 |
+
'socket_unique_continuous' affinity modes
|
270 |
+
|
271 |
+
Returns a set of logical CPU cores on which the process is eligible to run.
|
272 |
+
|
273 |
+
Example:
|
274 |
+
|
275 |
+
import argparse
|
276 |
+
import os
|
277 |
+
|
278 |
+
import gpu_affinity
|
279 |
+
import torch
|
280 |
+
|
281 |
+
|
282 |
+
def main():
|
283 |
+
parser = argparse.ArgumentParser()
|
284 |
+
parser.add_argument(
|
285 |
+
'--local_rank',
|
286 |
+
type=int,
|
287 |
+
default=os.getenv('LOCAL_RANK', 0),
|
288 |
+
)
|
289 |
+
args = parser.parse_args()
|
290 |
+
|
291 |
+
nproc_per_node = torch.cuda.device_count()
|
292 |
+
|
293 |
+
affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node)
|
294 |
+
print(f'{args.local_rank}: core affinity: {affinity}')
|
295 |
+
|
296 |
+
|
297 |
+
if __name__ == "__main__":
|
298 |
+
main()
|
299 |
+
|
300 |
+
Launch the example with:
|
301 |
+
python -m torch.distributed.launch --nproc_per_node <#GPUs> example.py
|
302 |
+
|
303 |
+
|
304 |
+
WARNING: On DGX A100 only a half of CPU cores have direct access to GPUs.
|
305 |
+
This function restricts execution only to the CPU cores directly connected
|
306 |
+
to GPUs, so on DGX A100 it will limit the code to half of CPU cores and half
|
307 |
+
of CPU memory bandwidth (which may be fine for many DL models).
|
308 |
+
"""
|
309 |
+
pynvml.nvmlInit()
|
310 |
+
|
311 |
+
if mode == "socket":
|
312 |
+
set_socket_affinity(gpu_id)
|
313 |
+
elif mode == "single":
|
314 |
+
set_single_affinity(gpu_id)
|
315 |
+
elif mode == "single_unique":
|
316 |
+
set_single_unique_affinity(gpu_id, nproc_per_node)
|
317 |
+
elif mode == "socket_unique_interleaved":
|
318 |
+
set_socket_unique_affinity(gpu_id, nproc_per_node, "interleaved", balanced)
|
319 |
+
elif mode == "socket_unique_continuous":
|
320 |
+
set_socket_unique_affinity(gpu_id, nproc_per_node, "continuous", balanced)
|
321 |
+
else:
|
322 |
+
raise RuntimeError("Unknown affinity mode")
|
323 |
+
|
324 |
+
affinity = os.sched_getaffinity(0)
|
325 |
+
return affinity
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/inference.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
from typing import List
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
from torch.nn.parallel import DistributedDataParallel
|
29 |
+
from torch.utils.data import DataLoader
|
30 |
+
from tqdm import tqdm
|
31 |
+
|
32 |
+
from se3_transformer.runtime import gpu_affinity
|
33 |
+
from se3_transformer.runtime.arguments import PARSER
|
34 |
+
from se3_transformer.runtime.callbacks import BaseCallback
|
35 |
+
from se3_transformer.runtime.loggers import DLLogger
|
36 |
+
from se3_transformer.runtime.utils import to_cuda, get_local_rank
|
37 |
+
|
38 |
+
|
39 |
+
@torch.inference_mode()
|
40 |
+
def evaluate(model: nn.Module,
|
41 |
+
dataloader: DataLoader,
|
42 |
+
callbacks: List[BaseCallback],
|
43 |
+
args):
|
44 |
+
model.eval()
|
45 |
+
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), unit='batch', desc=f'Evaluation',
|
46 |
+
leave=False, disable=(args.silent or get_local_rank() != 0)):
|
47 |
+
*input, target = to_cuda(batch)
|
48 |
+
|
49 |
+
for callback in callbacks:
|
50 |
+
callback.on_batch_start()
|
51 |
+
|
52 |
+
with torch.cuda.amp.autocast(enabled=args.amp):
|
53 |
+
pred = model(*input)
|
54 |
+
|
55 |
+
for callback in callbacks:
|
56 |
+
callback.on_validation_step(input, target, pred)
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
|
61 |
+
from se3_transformer.runtime.utils import init_distributed, seed_everything
|
62 |
+
from se3_transformer.model import SE3TransformerPooled, Fiber
|
63 |
+
from se3_transformer.data_loading import QM9DataModule
|
64 |
+
import torch.distributed as dist
|
65 |
+
import logging
|
66 |
+
import sys
|
67 |
+
|
68 |
+
is_distributed = init_distributed()
|
69 |
+
local_rank = get_local_rank()
|
70 |
+
args = PARSER.parse_args()
|
71 |
+
|
72 |
+
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
|
73 |
+
|
74 |
+
logging.info('====== SE(3)-Transformer ======')
|
75 |
+
logging.info('| Inference on the test set |')
|
76 |
+
logging.info('===============================')
|
77 |
+
|
78 |
+
if not args.benchmark and args.load_ckpt_path is None:
|
79 |
+
logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
|
80 |
+
sys.exit(1)
|
81 |
+
|
82 |
+
if args.benchmark:
|
83 |
+
logging.info('Running benchmark mode with one warmup pass')
|
84 |
+
|
85 |
+
if args.seed is not None:
|
86 |
+
seed_everything(args.seed)
|
87 |
+
|
88 |
+
major_cc, minor_cc = torch.cuda.get_device_capability()
|
89 |
+
|
90 |
+
logger = DLLogger(args.log_dir, filename=args.dllogger_name)
|
91 |
+
datamodule = QM9DataModule(**vars(args))
|
92 |
+
model = SE3TransformerPooled(
|
93 |
+
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
|
94 |
+
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
|
95 |
+
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
|
96 |
+
output_dim=1,
|
97 |
+
tensor_cores=(args.amp and major_cc >= 7) or major_cc >= 8, # use Tensor Cores more effectively
|
98 |
+
**vars(args)
|
99 |
+
)
|
100 |
+
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='test')]
|
101 |
+
|
102 |
+
model.to(device=torch.cuda.current_device())
|
103 |
+
if args.load_ckpt_path is not None:
|
104 |
+
checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
|
105 |
+
model.load_state_dict(checkpoint['state_dict'])
|
106 |
+
|
107 |
+
if is_distributed:
|
108 |
+
nproc_per_node = torch.cuda.device_count()
|
109 |
+
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
|
110 |
+
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
|
111 |
+
|
112 |
+
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
|
113 |
+
evaluate(model,
|
114 |
+
test_dataloader,
|
115 |
+
callbacks,
|
116 |
+
args)
|
117 |
+
|
118 |
+
for callback in callbacks:
|
119 |
+
callback.on_validation_end()
|
120 |
+
|
121 |
+
if args.benchmark:
|
122 |
+
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
123 |
+
callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
|
124 |
+
for _ in range(6):
|
125 |
+
evaluate(model,
|
126 |
+
test_dataloader,
|
127 |
+
callbacks,
|
128 |
+
args)
|
129 |
+
callbacks[0].on_epoch_end()
|
130 |
+
|
131 |
+
callbacks[0].on_fit_end()
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/loggers.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import pathlib
|
25 |
+
from abc import ABC, abstractmethod
|
26 |
+
from enum import Enum
|
27 |
+
from typing import Dict, Any, Callable, Optional
|
28 |
+
|
29 |
+
import dllogger
|
30 |
+
import torch.distributed as dist
|
31 |
+
import wandb
|
32 |
+
from dllogger import Verbosity
|
33 |
+
|
34 |
+
from se3_transformer.runtime.utils import rank_zero_only
|
35 |
+
|
36 |
+
|
37 |
+
class Logger(ABC):
|
38 |
+
@rank_zero_only
|
39 |
+
@abstractmethod
|
40 |
+
def log_hyperparams(self, params):
|
41 |
+
pass
|
42 |
+
|
43 |
+
@rank_zero_only
|
44 |
+
@abstractmethod
|
45 |
+
def log_metrics(self, metrics, step=None):
|
46 |
+
pass
|
47 |
+
|
48 |
+
@staticmethod
|
49 |
+
def _sanitize_params(params):
|
50 |
+
def _sanitize(val):
|
51 |
+
if isinstance(val, Callable):
|
52 |
+
try:
|
53 |
+
_val = val()
|
54 |
+
if isinstance(_val, Callable):
|
55 |
+
return val.__name__
|
56 |
+
return _val
|
57 |
+
except Exception:
|
58 |
+
return getattr(val, "__name__", None)
|
59 |
+
elif isinstance(val, pathlib.Path) or isinstance(val, Enum):
|
60 |
+
return str(val)
|
61 |
+
return val
|
62 |
+
|
63 |
+
return {key: _sanitize(val) for key, val in params.items()}
|
64 |
+
|
65 |
+
|
66 |
+
class LoggerCollection(Logger):
|
67 |
+
def __init__(self, loggers):
|
68 |
+
super().__init__()
|
69 |
+
self.loggers = loggers
|
70 |
+
|
71 |
+
def __getitem__(self, index):
|
72 |
+
return [logger for logger in self.loggers][index]
|
73 |
+
|
74 |
+
@rank_zero_only
|
75 |
+
def log_metrics(self, metrics, step=None):
|
76 |
+
for logger in self.loggers:
|
77 |
+
logger.log_metrics(metrics, step)
|
78 |
+
|
79 |
+
@rank_zero_only
|
80 |
+
def log_hyperparams(self, params):
|
81 |
+
for logger in self.loggers:
|
82 |
+
logger.log_hyperparams(params)
|
83 |
+
|
84 |
+
|
85 |
+
class DLLogger(Logger):
|
86 |
+
def __init__(self, save_dir: pathlib.Path, filename: str):
|
87 |
+
super().__init__()
|
88 |
+
if not dist.is_initialized() or dist.get_rank() == 0:
|
89 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
90 |
+
dllogger.init(
|
91 |
+
backends=[dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))])
|
92 |
+
|
93 |
+
@rank_zero_only
|
94 |
+
def log_hyperparams(self, params):
|
95 |
+
params = self._sanitize_params(params)
|
96 |
+
dllogger.log(step="PARAMETER", data=params)
|
97 |
+
|
98 |
+
@rank_zero_only
|
99 |
+
def log_metrics(self, metrics, step=None):
|
100 |
+
if step is None:
|
101 |
+
step = tuple()
|
102 |
+
|
103 |
+
dllogger.log(step=step, data=metrics)
|
104 |
+
|
105 |
+
|
106 |
+
class WandbLogger(Logger):
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
name: str,
|
110 |
+
save_dir: pathlib.Path,
|
111 |
+
id: Optional[str] = None,
|
112 |
+
project: Optional[str] = None
|
113 |
+
):
|
114 |
+
super().__init__()
|
115 |
+
if not dist.is_initialized() or dist.get_rank() == 0:
|
116 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
117 |
+
self.experiment = wandb.init(name=name,
|
118 |
+
project=project,
|
119 |
+
id=id,
|
120 |
+
dir=str(save_dir),
|
121 |
+
resume='allow',
|
122 |
+
anonymous='must')
|
123 |
+
|
124 |
+
@rank_zero_only
|
125 |
+
def log_hyperparams(self, params: Dict[str, Any]) -> None:
|
126 |
+
params = self._sanitize_params(params)
|
127 |
+
self.experiment.config.update(params, allow_val_change=True)
|
128 |
+
|
129 |
+
@rank_zero_only
|
130 |
+
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
131 |
+
if step is not None:
|
132 |
+
self.experiment.log({**metrics, 'epoch': step})
|
133 |
+
else:
|
134 |
+
self.experiment.log(metrics)
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/metrics.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
from abc import ABC, abstractmethod
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.distributed as dist
|
28 |
+
from torch import Tensor
|
29 |
+
|
30 |
+
|
31 |
+
class Metric(ABC):
|
32 |
+
""" Metric class with synchronization capabilities similar to TorchMetrics """
|
33 |
+
|
34 |
+
def __init__(self):
|
35 |
+
self.states = {}
|
36 |
+
|
37 |
+
def add_state(self, name: str, default: Tensor):
|
38 |
+
assert name not in self.states
|
39 |
+
self.states[name] = default.clone()
|
40 |
+
setattr(self, name, default)
|
41 |
+
|
42 |
+
def synchronize(self):
|
43 |
+
if dist.is_initialized():
|
44 |
+
for state in self.states:
|
45 |
+
dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)
|
46 |
+
|
47 |
+
def __call__(self, *args, **kwargs):
|
48 |
+
self.update(*args, **kwargs)
|
49 |
+
|
50 |
+
def reset(self):
|
51 |
+
for name, default in self.states.items():
|
52 |
+
setattr(self, name, default.clone())
|
53 |
+
|
54 |
+
def compute(self):
|
55 |
+
self.synchronize()
|
56 |
+
value = self._compute().item()
|
57 |
+
self.reset()
|
58 |
+
return value
|
59 |
+
|
60 |
+
@abstractmethod
|
61 |
+
def _compute(self):
|
62 |
+
pass
|
63 |
+
|
64 |
+
@abstractmethod
|
65 |
+
def update(self, preds: Tensor, targets: Tensor):
|
66 |
+
pass
|
67 |
+
|
68 |
+
|
69 |
+
class MeanAbsoluteError(Metric):
|
70 |
+
def __init__(self):
|
71 |
+
super().__init__()
|
72 |
+
self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda'))
|
73 |
+
self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))
|
74 |
+
|
75 |
+
def update(self, preds: Tensor, targets: Tensor):
|
76 |
+
preds = preds.detach()
|
77 |
+
n = preds.shape[0]
|
78 |
+
error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
|
79 |
+
self.total += n
|
80 |
+
self.error += error
|
81 |
+
|
82 |
+
def _compute(self):
|
83 |
+
return self.error / self.total
|
RFdiffusion/env/SE3Transformer/build/lib/se3_transformer/runtime/training.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
import logging
|
25 |
+
import pathlib
|
26 |
+
from typing import List
|
27 |
+
|
28 |
+
import numpy as np
|
29 |
+
import torch
|
30 |
+
import torch.distributed as dist
|
31 |
+
import torch.nn as nn
|
32 |
+
from apex.optimizers import FusedAdam, FusedLAMB
|
33 |
+
from torch.nn.modules.loss import _Loss
|
34 |
+
from torch.nn.parallel import DistributedDataParallel
|
35 |
+
from torch.optim import Optimizer
|
36 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
37 |
+
from tqdm import tqdm
|
38 |
+
|
39 |
+
from se3_transformer.data_loading import QM9DataModule
|
40 |
+
from se3_transformer.model import SE3TransformerPooled
|
41 |
+
from se3_transformer.model.fiber import Fiber
|
42 |
+
from se3_transformer.runtime import gpu_affinity
|
43 |
+
from se3_transformer.runtime.arguments import PARSER
|
44 |
+
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
|
45 |
+
PerformanceCallback
|
46 |
+
from se3_transformer.runtime.inference import evaluate
|
47 |
+
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
|
48 |
+
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
|
49 |
+
using_tensor_cores, increase_l2_fetch_granularity
|
50 |
+
|
51 |
+
|
52 |
+
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
|
53 |
+
""" Saves model, optimizer and epoch states to path (only once per node) """
|
54 |
+
if get_local_rank() == 0:
|
55 |
+
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
56 |
+
checkpoint = {
|
57 |
+
'state_dict': state_dict,
|
58 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
59 |
+
'epoch': epoch
|
60 |
+
}
|
61 |
+
for callback in callbacks:
|
62 |
+
callback.on_checkpoint_save(checkpoint)
|
63 |
+
|
64 |
+
torch.save(checkpoint, str(path))
|
65 |
+
logging.info(f'Saved checkpoint to {str(path)}')
|
66 |
+
|
67 |
+
|
68 |
+
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
|
69 |
+
""" Loads model, optimizer and epoch states from path """
|
70 |
+
checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
|
71 |
+
if isinstance(model, DistributedDataParallel):
|
72 |
+
model.module.load_state_dict(checkpoint['state_dict'])
|
73 |
+
else:
|
74 |
+
model.load_state_dict(checkpoint['state_dict'])
|
75 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
76 |
+
|
77 |
+
for callback in callbacks:
|
78 |
+
callback.on_checkpoint_load(checkpoint)
|
79 |
+
|
80 |
+
logging.info(f'Loaded checkpoint from {str(path)}')
|
81 |
+
return checkpoint['epoch']
|
82 |
+
|
83 |
+
|
84 |
+
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
|
85 |
+
losses = []
|
86 |
+
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
|
87 |
+
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
|
88 |
+
*inputs, target = to_cuda(batch)
|
89 |
+
|
90 |
+
for callback in callbacks:
|
91 |
+
callback.on_batch_start()
|
92 |
+
|
93 |
+
with torch.cuda.amp.autocast(enabled=args.amp):
|
94 |
+
pred = model(*inputs)
|
95 |
+
loss = loss_fn(pred, target) / args.accumulate_grad_batches
|
96 |
+
|
97 |
+
grad_scaler.scale(loss).backward()
|
98 |
+
|
99 |
+
# gradient accumulation
|
100 |
+
if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
|
101 |
+
if args.gradient_clip:
|
102 |
+
grad_scaler.unscale_(optimizer)
|
103 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
|
104 |
+
|
105 |
+
grad_scaler.step(optimizer)
|
106 |
+
grad_scaler.update()
|
107 |
+
optimizer.zero_grad()
|
108 |
+
|
109 |
+
losses.append(loss.item())
|
110 |
+
|
111 |
+
return np.mean(losses)
|
112 |
+
|
113 |
+
|
114 |
+
def train(model: nn.Module,
|
115 |
+
loss_fn: _Loss,
|
116 |
+
train_dataloader: DataLoader,
|
117 |
+
val_dataloader: DataLoader,
|
118 |
+
callbacks: List[BaseCallback],
|
119 |
+
logger: Logger,
|
120 |
+
args):
|
121 |
+
device = torch.cuda.current_device()
|
122 |
+
model.to(device=device)
|
123 |
+
local_rank = get_local_rank()
|
124 |
+
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
125 |
+
|
126 |
+
if dist.is_initialized():
|
127 |
+
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
|
128 |
+
|
129 |
+
model.train()
|
130 |
+
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
131 |
+
if args.optimizer == 'adam':
|
132 |
+
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
|
133 |
+
weight_decay=args.weight_decay)
|
134 |
+
elif args.optimizer == 'lamb':
|
135 |
+
optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
|
136 |
+
weight_decay=args.weight_decay)
|
137 |
+
else:
|
138 |
+
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
|
139 |
+
weight_decay=args.weight_decay)
|
140 |
+
|
141 |
+
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
|
142 |
+
|
143 |
+
for callback in callbacks:
|
144 |
+
callback.on_fit_start(optimizer, args)
|
145 |
+
|
146 |
+
for epoch_idx in range(epoch_start, args.epochs):
|
147 |
+
if isinstance(train_dataloader.sampler, DistributedSampler):
|
148 |
+
train_dataloader.sampler.set_epoch(epoch_idx)
|
149 |
+
|
150 |
+
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
|
151 |
+
if dist.is_initialized():
|
152 |
+
loss = torch.tensor(loss, dtype=torch.float, device=device)
|
153 |
+
torch.distributed.all_reduce(loss)
|
154 |
+
loss = (loss / world_size).item()
|
155 |
+
|
156 |
+
logging.info(f'Train loss: {loss}')
|
157 |
+
logger.log_metrics({'train loss': loss}, epoch_idx)
|
158 |
+
|
159 |
+
for callback in callbacks:
|
160 |
+
callback.on_epoch_end()
|
161 |
+
|
162 |
+
if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
|
163 |
+
and (epoch_idx + 1) % args.ckpt_interval == 0:
|
164 |
+
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
|
165 |
+
|
166 |
+
if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
|
167 |
+
evaluate(model, val_dataloader, callbacks, args)
|
168 |
+
model.train()
|
169 |
+
|
170 |
+
for callback in callbacks:
|
171 |
+
callback.on_validation_end(epoch_idx)
|
172 |
+
|
173 |
+
if args.save_ckpt_path is not None and not args.benchmark:
|
174 |
+
save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
|
175 |
+
|
176 |
+
for callback in callbacks:
|
177 |
+
callback.on_fit_end()
|
178 |
+
|
179 |
+
|
180 |
+
def print_parameters_count(model):
|
181 |
+
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
182 |
+
logging.info(f'Number of trainable parameters: {num_params_trainable}')
|
183 |
+
|
184 |
+
|
185 |
+
if __name__ == '__main__':
|
186 |
+
is_distributed = init_distributed()
|
187 |
+
local_rank = get_local_rank()
|
188 |
+
args = PARSER.parse_args()
|
189 |
+
|
190 |
+
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
|
191 |
+
|
192 |
+
logging.info('====== SE(3)-Transformer ======')
|
193 |
+
logging.info('| Training procedure |')
|
194 |
+
logging.info('===============================')
|
195 |
+
|
196 |
+
if args.seed is not None:
|
197 |
+
logging.info(f'Using seed {args.seed}')
|
198 |
+
seed_everything(args.seed)
|
199 |
+
|
200 |
+
logger = LoggerCollection([
|
201 |
+
DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
|
202 |
+
WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
|
203 |
+
])
|
204 |
+
|
205 |
+
datamodule = QM9DataModule(**vars(args))
|
206 |
+
model = SE3TransformerPooled(
|
207 |
+
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
|
208 |
+
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
|
209 |
+
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
|
210 |
+
output_dim=1,
|
211 |
+
tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
|
212 |
+
**vars(args)
|
213 |
+
)
|
214 |
+
loss_fn = nn.L1Loss()
|
215 |
+
|
216 |
+
if args.benchmark:
|
217 |
+
logging.info('Running benchmark mode')
|
218 |
+
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
219 |
+
callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
|
220 |
+
else:
|
221 |
+
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
|
222 |
+
QM9LRSchedulerCallback(logger, epochs=args.epochs)]
|
223 |
+
|
224 |
+
if is_distributed:
|
225 |
+
gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
|
226 |
+
|
227 |
+
print_parameters_count(model)
|
228 |
+
logger.log_hyperparams(vars(args))
|
229 |
+
increase_l2_fetch_granularity()
|
230 |
+
train(model,
|
231 |
+
loss_fn,
|
232 |
+
datamodule.train_dataloader(),
|
233 |
+
datamodule.val_dataloader(),
|
234 |
+
callbacks,
|
235 |
+
logger,
|
236 |
+
args)
|
237 |
+
|
238 |
+
logging.info('Training finished successfully')
|