GlandVergil committed
Commit a507bdb
1 Parent(s): a93e585
Upload 23 files
Files changed:
- se3_transformer/__init__.py +0 -0
- se3_transformer/data_loading/__init__.py +1 -0
- se3_transformer/data_loading/data_module.py +63 -0
- se3_transformer/data_loading/qm9.py +173 -0
- se3_transformer/model/__init__.py +2 -0
- se3_transformer/model/basis.py +178 -0
- se3_transformer/model/fiber.py +144 -0
- se3_transformer/model/layers/__init__.py +5 -0
- se3_transformer/model/layers/attention.py +180 -0
- se3_transformer/model/layers/convolution.py +336 -0
- se3_transformer/model/layers/linear.py +59 -0
- se3_transformer/model/layers/norm.py +83 -0
- se3_transformer/model/layers/pooling.py +53 -0
- se3_transformer/model/transformer.py +222 -0
- se3_transformer/runtime/__init__.py +0 -0
- se3_transformer/runtime/arguments.py +70 -0
- se3_transformer/runtime/callbacks.py +160 -0
- se3_transformer/runtime/gpu_affinity.py +325 -0
- se3_transformer/runtime/inference.py +131 -0
- se3_transformer/runtime/loggers.py +134 -0
- se3_transformer/runtime/metrics.py +83 -0
- se3_transformer/runtime/training.py +238 -0
- se3_transformer/runtime/utils.py +130 -0
se3_transformer/__init__.py
ADDED
File without changes
se3_transformer/data_loading/__init__.py
ADDED
@@ -0,0 +1 @@
from .qm9 import QM9DataModule
se3_transformer/data_loading/data_module.py
ADDED
@@ -0,0 +1,63 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import torch.distributed as dist
from abc import ABC
from torch.utils.data import DataLoader, DistributedSampler, Dataset

from se3_transformer.runtime.utils import get_local_rank


def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:
    # Classic or distributed dataloader depending on the context
    sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None
    return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs)


class DataModule(ABC):
    """ Abstract DataModule. Children must define self.ds_{train | val | test}. """

    def __init__(self, **dataloader_kwargs):
        super().__init__()
        if get_local_rank() == 0:
            self.prepare_data()

        # Wait until rank zero has prepared the data (download, preprocessing, ...)
        if dist.is_initialized():
            dist.barrier(device_ids=[get_local_rank()])

        self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
        self.ds_train, self.ds_val, self.ds_test = None, None, None

    def prepare_data(self):
        """ Method called only once per node. Put here any downloading or preprocessing """
        pass

    def train_dataloader(self) -> DataLoader:
        return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs)

    def val_dataloader(self) -> DataLoader:
        return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs)

    def test_dataloader(self) -> DataLoader:
        return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs)
se3_transformer/data_loading/qm9.py
ADDED
@@ -0,0 +1,173 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Tuple

import dgl
import pathlib
import torch
from dgl.data import QM9EdgeDataset
from dgl import DGLGraph
from torch import Tensor
from torch.utils.data import random_split, DataLoader, Dataset
from tqdm import tqdm

from se3_transformer.data_loading.data_module import DataModule
from se3_transformer.model.basis import get_basis
from se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores


def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
    x = qm9_graph.ndata['pos']
    src, dst = qm9_graph.edges()
    rel_pos = x[dst] - x[src]
    return rel_pos


def _get_split_sizes(full_dataset: Dataset) -> Tuple[int, int, int]:
    len_full = len(full_dataset)
    len_train = 100_000
    len_test = int(0.1 * len_full)
    len_val = len_full - len_train - len_test
    return len_train, len_val, len_test


class QM9DataModule(DataModule):
    """
    Datamodule wrapping https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset
    Training set is 100k molecules. Test set is 10% of the dataset. Validation set is the rest.
    This includes all the molecules from QM9 except the ones that are uncharacterized.
    """

    NODE_FEATURE_DIM = 6
    EDGE_FEATURE_DIM = 4

    def __init__(self,
                 data_dir: pathlib.Path,
                 task: str = 'homo',
                 batch_size: int = 240,
                 num_workers: int = 8,
                 num_degrees: int = 4,
                 amp: bool = False,
                 precompute_bases: bool = False,
                 **kwargs):
        self.data_dir = data_dir  # This needs to be before __init__ so that prepare_data has access to it
        super().__init__(batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
        self.amp = amp
        self.task = task
        self.batch_size = batch_size
        self.num_degrees = num_degrees

        qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
        if precompute_bases:
            bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
            full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
                                                     num_workers=num_workers, **qm9_kwargs)
        else:
            full_dataset = QM9EdgeDataset(**qm9_kwargs)

        self.ds_train, self.ds_val, self.ds_test = random_split(full_dataset, _get_split_sizes(full_dataset),
                                                                generator=torch.Generator().manual_seed(0))

        train_targets = full_dataset.targets[self.ds_train.indices, full_dataset.label_keys[0]]
        self.targets_mean = train_targets.mean()
        self.targets_std = train_targets.std()

    def prepare_data(self):
        # Download the QM9 preprocessed data
        QM9EdgeDataset(verbose=True, raw_dir=str(self.data_dir))

    def _collate(self, samples):
        graphs, y, *bases = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
        # get node features
        node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
        targets = (torch.cat(y) - self.targets_mean) / self.targets_std

        if bases:
            # collate bases
            all_bases = {
                key: torch.cat([b[key] for b in bases[0]], dim=0)
                for key in bases[0][0].keys()
            }

            return batched_graph, node_feats, edge_feats, all_bases, targets
        else:
            return batched_graph, node_feats, edge_feats, targets

    @staticmethod
    def add_argparse_args(parent_parser):
        parser = parent_parser.add_argument_group("QM9 dataset")
        parser.add_argument('--task', type=str, default='homo', const='homo', nargs='?',
                            choices=['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
                                     'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'],
                            help='Regression task to train on')
        parser.add_argument('--precompute_bases', type=str2bool, nargs='?', const=True, default=False,
                            help='Precompute bases at the beginning of the script during dataset initialization,'
                                 ' instead of computing them at the beginning of each forward pass.')
        return parent_parser

    def __repr__(self):
        return f'QM9({self.task})'


class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
    """ Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """

    def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
        """
        :param bases_kwargs: Arguments to feed the bases computation function
        :param batch_size:   Batch size to use when iterating over the dataset for computing bases
        """
        self.bases_kwargs = bases_kwargs
        self.batch_size = batch_size
        self.bases = None
        self.num_workers = num_workers
        super().__init__(*args, **kwargs)

    def load(self):
        super().load()
        # Iterate through the dataset and compute bases (pairwise only)
        # Potential improvement: use multi-GPU and gather
        dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
                                collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
        bases = []
        for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
                             disable=get_local_rank() != 0):
            rel_pos = _get_relative_pos(graph)
            # Compute the bases with the GPU but convert the result to CPU to store in RAM
            bases.append({k: v.cpu() for k, v in get_basis(rel_pos.cuda(), **self.bases_kwargs).items()})
        self.bases = bases  # Assign at the end so that __getitem__ isn't confused

    def __getitem__(self, idx: int):
        graph, label = super().__getitem__(idx)

        if self.bases:
            bases_idx = idx // self.batch_size
            bases_cumsum_idx = self.ne_cumsum[idx] - self.ne_cumsum[bases_idx * self.batch_size]
            bases_cumsum_next_idx = self.ne_cumsum[idx + 1] - self.ne_cumsum[bases_idx * self.batch_size]
            return graph, label, {key: basis[bases_cumsum_idx:bases_cumsum_next_idx] for key, basis in
                                  self.bases[bases_idx].items()}
        else:
            return graph, label
se3_transformer/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .transformer import SE3Transformer, SE3TransformerPooled
from .fiber import Fiber
se3_transformer/model/basis.py
ADDED
@@ -0,0 +1,178 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT


from functools import lru_cache
from typing import Dict, List

import e3nn.o3 as o3
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range

from se3_transformer.runtime.utils import degree_to_dim


@lru_cache(maxsize=None)
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
    """ Get the (cached) Q^{d_out,d_in}_J matrices from equation (8) """
    return o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device).permute(2, 1, 0)


@lru_cache(maxsize=None)
def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
    all_cb = []
    for d_in in range(max_degree + 1):
        for d_out in range(max_degree + 1):
            K_Js = []
            for J in range(abs(d_in - d_out), d_in + d_out + 1):
                K_Js.append(get_clebsch_gordon(J, d_in, d_out, device))
            all_cb.append(K_Js)
    return all_cb


def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
    all_degrees = list(range(2 * max_degree + 1))
    with nvtx_range('spherical harmonics'):
        sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
        return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)


@torch.jit.script
def get_basis_script(max_degree: int,
                     use_pad_trick: bool,
                     spherical_harmonics: List[Tensor],
                     clebsch_gordon: List[List[Tensor]],
                     amp: bool) -> Dict[str, Tensor]:
    """
    Compute pairwise bases matrices for degrees up to max_degree
    :param max_degree: Maximum input or output degree
    :param use_pad_trick: Pad some of the odd dimensions for a better use of Tensor Cores
    :param spherical_harmonics: List of computed spherical harmonics
    :param clebsch_gordon: List of computed CB-coefficients
    :param amp: When true, return bases in FP16 precision
    """
    basis = {}
    idx = 0
    # Double for loop instead of product() because of JIT script
    for d_in in range(max_degree + 1):
        for d_out in range(max_degree + 1):
            key = f'{d_in},{d_out}'
            K_Js = []
            for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
                Q_J = clebsch_gordon[idx][freq_idx]
                K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float()))

            basis[key] = torch.stack(K_Js, 2)  # Stack on second dim so order is n l f k
            if amp:
                basis[key] = basis[key].half()
            if use_pad_trick:
                basis[key] = F.pad(basis[key], (0, 1))  # Pad the k dimension, that can be sliced later

            idx += 1

    return basis


@torch.jit.script
def update_basis_with_fused(basis: Dict[str, Tensor],
                            max_degree: int,
                            use_pad_trick: bool,
                            fully_fused: bool) -> Dict[str, Tensor]:
    """ Update the basis dict with partially and optionally fully fused bases """
    num_edges = basis['0,0'].shape[0]
    device = basis['0,0'].device
    dtype = basis['0,0'].dtype
    sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])

    # Fused per output degree
    for d_out in range(max_degree + 1):
        sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
        basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
                                  device=device, dtype=dtype)
        acc_d, acc_f = 0, 0
        for d_in in range(max_degree + 1):
            basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
                        :degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]

            acc_d += degree_to_dim(d_in)
            acc_f += degree_to_dim(min(d_out, d_in))

        basis[f'out{d_out}_fused'] = basis_fused

    # Fused per input degree
    for d_in in range(max_degree + 1):
        sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
        basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
                                  device=device, dtype=dtype)
        acc_d, acc_f = 0, 0
        for d_out in range(max_degree + 1):
            basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \
                = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]

            acc_d += degree_to_dim(d_out)
            acc_f += degree_to_dim(min(d_out, d_in))

        basis[f'in{d_in}_fused'] = basis_fused

    if fully_fused:
        # Fully fused
        # Double sum this way because of JIT script
        sum_freq = sum([
            sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
        ])
        basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)

        acc_d, acc_f = 0, 0
        for d_out in range(max_degree + 1):
            b = basis[f'out{d_out}_fused']
            basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = \
                b[:, :, :, :degree_to_dim(d_out)]
            acc_f += b.shape[2]
            acc_d += degree_to_dim(d_out)

        basis['fully_fused'] = basis_fused

    del basis['0,0']  # We know that the basis for l = k = 0 is filled with a constant
    return basis


def get_basis(relative_pos: Tensor,
              max_degree: int = 4,
              compute_gradients: bool = False,
              use_pad_trick: bool = False,
              amp: bool = False) -> Dict[str, Tensor]:
    with nvtx_range('spherical harmonics'):
        spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
    with nvtx_range('CB coefficients'):
        clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)

    with torch.autograd.set_grad_enabled(compute_gradients):
        with nvtx_range('bases'):
            basis = get_basis_script(max_degree=max_degree,
                                     use_pad_trick=use_pad_trick,
                                     spherical_harmonics=spherical_harmonics,
                                     clebsch_gordon=clebsch_gordon,
                                     amp=amp)
            return basis
se3_transformer/model/fiber.py
ADDED
@@ -0,0 +1,144 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT


from collections import namedtuple
from itertools import product
from typing import Dict

import torch
from torch import Tensor

from se3_transformer.runtime.utils import degree_to_dim

FiberEl = namedtuple('FiberEl', ['degree', 'channels'])


class Fiber(dict):
    """
    Describes the structure of some set of features.
    Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
    Type-0 features: invariant scalars
    Type-1 features: equivariant 3D vectors
    Type-2 features: equivariant symmetric traceless matrices
    ...

    As inputs to an SE3 layer, there can be many features of the same types, and many features of different types.
    The 'multiplicity' or 'number of channels' is the number of features of a given type.
    This class puts together all the degrees and their multiplicities in order to describe
    the inputs, outputs or hidden features of SE3 layers.
    """

    def __init__(self, structure):
        if isinstance(structure, dict):
            structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
        elif not isinstance(structure[0], FiberEl):
            structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
        self.structure = structure
        super().__init__({d: m for d, m in self.structure})

    @property
    def degrees(self):
        return sorted([t.degree for t in self.structure])

    @property
    def channels(self):
        return [self[d] for d in self.degrees]

    @property
    def num_features(self):
        """ Size of the resulting tensor if all features were concatenated together """
        return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)

    @staticmethod
    def create(num_degrees: int, num_channels: int):
        """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
        return Fiber([(degree, num_channels) for degree in range(num_degrees)])

    @staticmethod
    def from_features(feats: Dict[str, Tensor]):
        """ Infer the Fiber structure from a feature dict """
        structure = {}
        for k, v in feats.items():
            degree = int(k)
            assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
            assert v.shape[-1] == degree_to_dim(degree)
            structure[degree] = v.shape[-2]
        return Fiber(structure)

    def __getitem__(self, degree: int):
        """ fiber[degree] returns the multiplicity for this degree """
        return dict(self.structure).get(degree, 0)

    def __iter__(self):
        """ Iterate over namedtuples (degree, channels) """
        return iter(self.structure)

    def __mul__(self, other):
        """
        If other is an int, multiplies all the multiplicities by other.
        If other is a fiber, returns the cartesian product.
        """
        if isinstance(other, Fiber):
            return product(self.structure, other.structure)
        elif isinstance(other, int):
            return Fiber({t.degree: t.channels * other for t in self.structure})

    def __add__(self, other):
        """
        If other is an int, add other to all the multiplicities.
        If other is a fiber, add the multiplicities of the fibers together.
        """
        if isinstance(other, Fiber):
            return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
        elif isinstance(other, int):
            return Fiber({t.degree: t.channels + other for t in self.structure})

    def __repr__(self):
        return str(self.structure)

    @staticmethod
    def combine_max(f1, f2):
        """ Combine two fibers by taking the maximum multiplicity for each degree in both fibers """
        new_dict = dict(f1.structure)
        for k, m in f2.structure:
            new_dict[k] = max(new_dict.get(k, 0), m)

        return Fiber(list(new_dict.items()))

    @staticmethod
    def combine_selectively(f1, f2):
        """ Combine two fibers by taking the sum of multiplicities for each degree in the first fiber """
        # only use orders which occur in fiber f1
        new_dict = dict(f1.structure)
        for k in f1.degrees:
            if k in f2.degrees:
                new_dict[k] += f2[k]
        return Fiber(list(new_dict.items()))

    def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
        # dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
        fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
                  self.degrees]
        fibers = torch.cat(fibers, -1)
        return fibers
se3_transformer/model/layers/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .linear import LinearSE3
from .norm import NormSE3
from .pooling import GPooling
from .convolution import ConvSE3
from .attention import AttentionBlockSE3
se3_transformer/model/layers/attention.py
ADDED
@@ -0,0 +1,180 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import dgl
import numpy as np
import torch
import torch.nn as nn
from dgl import DGLGraph
from dgl.ops import edge_softmax
from torch import Tensor
from typing import Dict, Optional, Union

from se3_transformer.model.fiber import Fiber
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
from torch.cuda.nvtx import range as nvtx_range


class AttentionSE3(nn.Module):
    """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """

    def __init__(
            self,
            num_heads: int,
            key_fiber: Fiber,
            value_fiber: Fiber
    ):
        """
        :param num_heads: Number of attention heads
        :param key_fiber: Fiber for the keys (and also for the queries)
        :param value_fiber: Fiber for the values
        """
        super().__init__()
        self.num_heads = num_heads
        self.key_fiber = key_fiber
        self.value_fiber = value_fiber

    def forward(
            self,
            value: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            key: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            query: Dict[str, Tensor],  # node features
            graph: DGLGraph
    ):
        with nvtx_range('AttentionSE3'):
            with nvtx_range('reshape keys and queries'):
                if isinstance(key, Tensor):
                    # case where features of all types are fused
                    key = key.reshape(key.shape[0], self.num_heads, -1)
                    # need to reshape queries that way to keep the same layout as keys
                    out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
                    query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
                else:
                    # features are not fused, need to fuse and reshape them
                    key = self.key_fiber.to_attention_heads(key, self.num_heads)
                    query = self.key_fiber.to_attention_heads(query, self.num_heads)

            with nvtx_range('attention dot product + softmax'):
                # Compute attention weights (softmax of inner product between key and query)
                edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
                edge_weights /= np.sqrt(self.key_fiber.num_features)
                edge_weights = edge_softmax(graph, edge_weights)
                edge_weights = edge_weights[..., None, None]

            with nvtx_range('weighted sum'):
                if isinstance(value, Tensor):
                    # features of all types are fused
                    v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
                    weights = edge_weights * v
                    feat_out = dgl.ops.copy_e_sum(graph, weights)
                    feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1])  # merge heads
                    out = unfuse_features(feat_out, self.value_fiber.degrees)
                else:
                    out = {}
                    for degree, channels in self.value_fiber:
                        v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
                                                    degree_to_dim(degree))
                        weights = edge_weights * v
                        res = dgl.ops.copy_e_sum(graph, weights)
                        out[str(degree)] = res.view(-1, channels, degree_to_dim(degree))  # merge heads

            return out


class AttentionBlockSE3(nn.Module):
    """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """

    def __init__(
            self,
            fiber_in: Fiber,
            fiber_out: Fiber,
            fiber_edge: Optional[Fiber] = None,
            num_heads: int = 4,
            channels_div: int = 2,
            use_layer_norm: bool = False,
            max_degree: int = 4,
            fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
            **kwargs
    ):
        """
        :param fiber_in: Fiber describing the input features
        :param fiber_out: Fiber describing the output features
        :param fiber_edge: Fiber describing the edge features (node distances excluded)
        :param num_heads: Number of attention heads
        :param channels_div: Divide the channels by this integer for computing values
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param max_degree: Maximum degree used in the bases computation
        :param fuse_level: Maximum fuse level to use in TFN convolutions
        """
        super().__init__()
        if fiber_edge is None:
            fiber_edge = Fiber({})
        self.fiber_in = fiber_in
        # value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
        value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
        # key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
        # (queries are merely projected, hence degrees have to match input)
        key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])

        self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
                                    use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
                                    allow_fused_output=True)
        self.to_query = LinearSE3(fiber_in, key_query_fiber)
        self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
        self.project = LinearSE3(value_fiber + fiber_in, fiber_out)

    def forward(
            self,
            node_features: Dict[str, Tensor],
            edge_features: Dict[str, Tensor],
            graph: DGLGraph,
            basis: Dict[str, Tensor]
    ):
        with nvtx_range('AttentionBlockSE3'):
            with nvtx_range('keys / values'):
                fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
                key, value = self._get_key_value_from_fused(fused_key_value)

            with nvtx_range('queries'):
                query = self.to_query(node_features)

            z = self.attention(value, key, query, graph)
            z_concat = aggregate_residual(node_features, z, 'cat')
            return self.project(z_concat)

    def _get_key_value_from_fused(self, fused_key_value):
        # Extract keys and queries features from fused features
        if isinstance(fused_key_value, Tensor):
            # Previous layer was a fully fused convolution
            value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
        else:
            key, value = {}, {}
            for degree, feat in fused_key_value.items():
                if int(degree) in self.fiber_in.degrees:
                    value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
                else:
                    value[degree] = feat

        return key, value
se3_transformer/model/layers/convolution.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# Permission is hereby granted, free of charge, to any person obtaining a
|
4 |
+
# copy of this software and associated documentation files (the "Software"),
|
5 |
+
# to deal in the Software without restriction, including without limitation
|
6 |
+
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
7 |
+
# and/or sell copies of the Software, and to permit persons to whom the
|
8 |
+
# Software is furnished to do so, subject to the following conditions:
|
9 |
+
#
|
10 |
+
# The above copyright notice and this permission notice shall be included in
|
11 |
+
# all copies or substantial portions of the Software.
|
12 |
+
#
|
13 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
16 |
+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
18 |
+
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
19 |
+
# DEALINGS IN THE SOFTWARE.
|
20 |
+
#
|
21 |
+
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
22 |
+
# SPDX-License-Identifier: MIT
|
23 |
+
|
24 |
+
from enum import Enum
|
25 |
+
from itertools import product
|
26 |
+
from typing import Dict
|
27 |
+
|
28 |
+
import dgl
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
import torch.nn as nn
|
32 |
+
from dgl import DGLGraph
|
33 |
+
from torch import Tensor
|
34 |
+
from torch.cuda.nvtx import range as nvtx_range
|
35 |
+
|
36 |
+
from se3_transformer.model.fiber import Fiber
|
37 |
+
from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
|
38 |
+
|
39 |
+
|
40 |
+
class ConvSE3FuseLevel(Enum):
|
41 |
+
"""
|
42 |
+
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
|
43 |
+
If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered.
|
44 |
+
A higher level means faster training, but also more memory usage.
|
45 |
+
If you are tight on memory and want to feed large inputs to the network, choose a low value.
|
46 |
+
If you want to train fast, choose a high value.
|
47 |
+
Recommended value is FULL with AMP.
|
48 |
+
|
49 |
+
Fully fused TFN convolutions requirements:
|
50 |
+
- all input channels are the same
|
51 |
+
- all output channels are the same
|
52 |
+
- input degrees span the range [0, ..., max_degree]
|
53 |
+
- output degrees span the range [0, ..., max_degree]
|
54 |
+
|
55 |
+
Partially fused TFN convolutions requirements:
|
56 |
+
* For fusing by output degree:
|
57 |
+
- all input channels are the same
|
58 |
+
- input degrees span the range [0, ..., max_degree]
|
59 |
+
* For fusing by input degree:
|
60 |
+
- all output channels are the same
|
61 |
+
- output degrees span the range [0, ..., max_degree]
|
62 |
+
|
63 |
+
Original TFN pairwise convolutions: no requirements
|
64 |
+
"""
|
65 |
+
|
66 |
+
FULL = 2
|
67 |
+
PARTIAL = 1
|
68 |
+
NONE = 0
|
69 |
+
|
70 |
+
|
71 |
+
class RadialProfile(nn.Module):
|
72 |
+
"""
|
73 |
+
Radial profile function.
|
74 |
+
Outputs weights used to weigh basis matrices in order to get convolution kernels.
|
75 |
+
In TFN notation: $R^{l,k}$
|
76 |
+
In SE(3)-Transformer notation: $\phi^{l,k}$
|
77 |
+
|
78 |
+
Note:
|
79 |
+
In the original papers, this function only depends on relative node distances ||x||.
|
80 |
+
Here, we allow this function to also take as input additional invariant edge features.
|
81 |
+
This does not break equivariance and adds expressive power to the model.
|
82 |
+
|
83 |
+
Diagram:
|
84 |
+
invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
num_freq: int,
|
90 |
+
channels_in: int,
|
91 |
+
channels_out: int,
|
92 |
+
edge_dim: int = 1,
|
93 |
+
mid_dim: int = 32,
|
94 |
+
use_layer_norm: bool = False
|
95 |
+
):
|
96 |
+
"""
|
97 |
+
:param num_freq: Number of frequencies
|
98 |
+
:param channels_in: Number of input channels
|
99 |
+
:param channels_out: Number of output channels
|
100 |
+
:param edge_dim: Number of invariant edge features (input to the radial function)
|
101 |
+
:param mid_dim: Size of the hidden MLP layers
|
102 |
+
:param use_layer_norm: Apply layer normalization between MLP layers
|
103 |
+
"""
|
104 |
+
super().__init__()
|
105 |
+
modules = [
|
106 |
+
nn.Linear(edge_dim, mid_dim),
|
107 |
+
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
108 |
+
nn.ReLU(),
|
109 |
+
nn.Linear(mid_dim, mid_dim),
|
110 |
+
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
111 |
+
nn.ReLU(),
|
112 |
+
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
|
113 |
+
]
|
114 |
+
|
115 |
+
self.net = nn.Sequential(*[m for m in modules if m is not None])
|
116 |
+
|
117 |
+
def forward(self, features: Tensor) -> Tensor:
|
118 |
+
return self.net(features)
|
119 |
+
|
120 |
+
|
121 |
+
class VersatileConvSE3(nn.Module):
|
122 |
+
"""
|
123 |
+
Building block for TFN convolutions.
|
124 |
+
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(self,
|
128 |
+
freq_sum: int,
|
129 |
+
channels_in: int,
|
130 |
+
channels_out: int,
|
131 |
+
edge_dim: int,
|
132 |
+
use_layer_norm: bool,
|
133 |
+
fuse_level: ConvSE3FuseLevel):
|
134 |
+
super().__init__()
|
135 |
+
self.freq_sum = freq_sum
|
136 |
+
self.channels_out = channels_out
|
137 |
+
self.channels_in = channels_in
|
138 |
+
self.fuse_level = fuse_level
|
139 |
+
self.radial_func = RadialProfile(num_freq=freq_sum,
|
140 |
+
channels_in=channels_in,
|
141 |
+
channels_out=channels_out,
|
142 |
+
edge_dim=edge_dim,
|
143 |
+
use_layer_norm=use_layer_norm)
|
144 |
+
|
145 |
+
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
|
146 |
+
with nvtx_range(f'VersatileConvSE3'):
|
147 |
+
num_edges = features.shape[0]
|
148 |
+
in_dim = features.shape[2]
|
149 |
+
with nvtx_range(f'RadialProfile'):
|
150 |
+
radial_weights = self.radial_func(invariant_edge_feats) \
|
151 |
+
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
|
152 |
+
|
153 |
+
if basis is not None:
|
154 |
+
# This block performs the einsum n i l, n o i f, n l f k -> n o k
|
155 |
+
out_dim = basis.shape[-1]
|
156 |
+
if self.fuse_level != ConvSE3FuseLevel.FULL:
|
157 |
+
out_dim += out_dim % 2 - 1 # Account for padded basis
|
158 |
+
basis_view = basis.view(num_edges, in_dim, -1)
|
159 |
+
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
|
160 |
+
return (radial_weights @ tmp)[:, :, :out_dim]
|
161 |
+
else:
|
162 |
+
# k = l = 0 non-fused case
|
163 |
+
return radial_weights @ features
|
164 |
+
|
165 |
+
|
166 |
+
class ConvSE3(nn.Module):
|
167 |
+
"""
|
168 |
+
SE(3)-equivariant graph convolution (Tensor Field Network convolution).
|
169 |
+
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
|
170 |
+
Features of different degrees interact together to produce output features.
|
171 |
+
|
172 |
+
Note 1:
|
173 |
+
The option is given to not pool the output. This means that the convolution sum over neighbors will not be
|
174 |
+
done, and the returned features will be edge features instead of node features.
|
175 |
+
|
176 |
+
Note 2:
|
177 |
+
Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
|
178 |
+
Input edge features are concatenated with input source node features before the kernel is applied.
|
179 |
+
"""
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
fiber_in: Fiber,
|
184 |
+
fiber_out: Fiber,
|
185 |
+
fiber_edge: Fiber,
|
186 |
+
pool: bool = True,
|
187 |
+
use_layer_norm: bool = False,
|
188 |
+
self_interaction: bool = False,
|
189 |
+
max_degree: int = 4,
|
190 |
+
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
191 |
+
allow_fused_output: bool = False
|
192 |
+
):
|
193 |
+
"""
|
194 |
+
:param fiber_in: Fiber describing the input features
|
195 |
+
:param fiber_out: Fiber describing the output features
|
196 |
+
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
197 |
+
:param pool: If True, compute final node features by averaging incoming edge features
|
198 |
+
:param use_layer_norm: Apply layer normalization between MLP layers
|
199 |
+
:param self_interaction: Apply self-interaction of nodes
|
200 |
+
:param max_degree: Maximum degree used in the bases computation
|
201 |
+
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
202 |
+
:param allow_fused_output: Allow the module to output a fused representation of features
|
203 |
+
"""
|
204 |
+
super().__init__()
|
205 |
+
self.pool = pool
|
206 |
+
self.fiber_in = fiber_in
|
207 |
+
self.fiber_out = fiber_out
|
208 |
+
self.self_interaction = self_interaction
|
209 |
+
self.max_degree = max_degree
|
210 |
+
self.allow_fused_output = allow_fused_output
|
211 |
+
|
212 |
+
# channels_in: account for the concatenation of edge features
|
213 |
+
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
|
214 |
+
channels_out_set = set([f.channels for f in self.fiber_out])
|
215 |
+
unique_channels_in = (len(channels_in_set) == 1)
|
216 |
+
unique_channels_out = (len(channels_out_set) == 1)
|
217 |
+
degrees_up_to_max = list(range(max_degree + 1))
|
218 |
+
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
|
219 |
+
|
220 |
+
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
|
221 |
+
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
|
222 |
+
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
223 |
+
# Single fused convolution
|
224 |
+
self.used_fuse_level = ConvSE3FuseLevel.FULL
|
225 |
+
|
226 |
+
sum_freq = sum([
|
227 |
+
degree_to_dim(min(d_in, d_out))
|
228 |
+
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
|
229 |
+
])
|
230 |
+
|
231 |
+
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
|
232 |
+
fuse_level=self.used_fuse_level, **common_args)
|
233 |
+
|
234 |
+
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
235 |
+
unique_channels_in and fiber_in.degrees == degrees_up_to_max:
|
236 |
+
# Convolutions fused per output degree
|
237 |
+
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
238 |
+
self.conv_out = nn.ModuleDict()
|
239 |
+
for d_out, c_out in fiber_out:
|
240 |
+
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
|
241 |
+
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
|
242 |
+
fuse_level=self.used_fuse_level, **common_args)
|
243 |
+
|
244 |
+
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
245 |
+
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
246 |
+
# Convolutions fused per input degree
|
247 |
+
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
248 |
+
self.conv_in = nn.ModuleDict()
|
249 |
+
for d_in, c_in in fiber_in:
|
250 |
+
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
|
251 |
+
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
|
252 |
+
fuse_level=ConvSE3FuseLevel.FULL, **common_args)
|
253 |
+
#fuse_level=self.used_fuse_level, **common_args)
|
254 |
+
else:
|
255 |
+
# Use pairwise TFN convolutions
|
256 |
+
            self.used_fuse_level = ConvSE3FuseLevel.NONE
            self.conv = nn.ModuleDict()
            for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
                dict_key = f'{degree_in},{degree_out}'
                channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
                sum_freq = degree_to_dim(min(degree_in, degree_out))
                self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
                                                       fuse_level=self.used_fuse_level, **common_args)

        if self_interaction:
            self.to_kernel_self = nn.ParameterDict()
            for degree_out, channels_out in fiber_out:
                if fiber_in[degree_out]:
                    self.to_kernel_self[str(degree_out)] = nn.Parameter(
                        torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))

    def forward(
            self,
            node_feats: Dict[str, Tensor],
            edge_feats: Dict[str, Tensor],
            graph: DGLGraph,
            basis: Dict[str, Tensor]
    ):
        with nvtx_range('ConvSE3'):
            invariant_edge_feats = edge_feats['0'].squeeze(-1)
            src, dst = graph.edges()
            out = {}
            in_features = []

            # Fetch all input features from edge and node features
            for degree_in in self.fiber_in.degrees:
                src_node_features = node_feats[str(degree_in)][src]
                if degree_in > 0 and str(degree_in) in edge_feats:
                    # Handle edge features of any type by concatenating them to node features
                    src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
                in_features.append(src_node_features)

            if self.used_fuse_level == ConvSE3FuseLevel.FULL:
                in_features_fused = torch.cat(in_features, dim=-1)
                out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])

                if not self.allow_fused_output or self.self_interaction or self.pool:
                    out = unfuse_features(out, self.fiber_out.degrees)

            elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
                in_features_fused = torch.cat(in_features, dim=-1)
                for degree_out in self.fiber_out.degrees:
                    out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats,
                                                                          basis[f'out{degree_out}_fused'])

            elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
                out = 0
                for degree_in, feature in zip(self.fiber_in.degrees, in_features):
                    out += self.conv_in[str(degree_in)](feature, invariant_edge_feats,
                                                        basis[f'in{degree_in}_fused'])
                if not self.allow_fused_output or self.self_interaction or self.pool:
                    out = unfuse_features(out, self.fiber_out.degrees)
            else:
                # Fall back to pairwise TFN convolutions
                for degree_out in self.fiber_out.degrees:
                    out_feature = 0
                    for degree_in, feature in zip(self.fiber_in.degrees, in_features):
                        dict_key = f'{degree_in},{degree_out}'
                        out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats,
                                                                        basis.get(dict_key, None))
                    out[str(degree_out)] = out_feature

            for degree_out in self.fiber_out.degrees:
                if self.self_interaction and str(degree_out) in self.to_kernel_self:
                    with nvtx_range('self interaction'):
                        dst_features = node_feats[str(degree_out)][dst]
                        kernel_self = self.to_kernel_self[str(degree_out)]
                        out[str(degree_out)] += kernel_self @ dst_features

                if self.pool:
                    with nvtx_range('pooling'):
                        if isinstance(out, dict):
                            out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
                        else:
                            out = dgl.ops.copy_e_sum(graph, out)
        return out
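As an aside on the kernel-building loop above: a minimal sketch of which pairwise kernels the un-fused ConvSE3FuseLevel.NONE branch registers, assuming only what the loop itself relies on (the Fiber product iterates (degree, channels) pairs, and indexing an absent degree yields 0). The channel counts below are arbitrary illustration values, not anything from this commit.

from se3_transformer.model.fiber import Fiber

fiber_in, fiber_out = Fiber({0: 8, 1: 8}), Fiber({0: 16, 1: 16})
fiber_edge = Fiber({1: 2})  # hypothetical: 2 channels of type-1 edge features

for (degree_in, channels_in), (degree_out, channels_out) in (fiber_in * fiber_out):
    # edge channels are concatenated to source node features, but only for degree > 0
    channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
    print(f'{degree_in},{degree_out}: {channels_in_new} -> {channels_out} channels')
# prints the four kernel keys '0,0', '0,1', '1,0', '1,1'; type-1 inputs grow 8 -> 10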
se3_transformer/model/layers/linear.py
ADDED
@@ -0,0 +1,59 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT


from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from se3_transformer.model.fiber import Fiber


class LinearSE3(nn.Module):
    """
    Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
    Maps a fiber to a fiber with the same degrees (channels may be different).
    No interaction between degrees, but interaction between channels.

    type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
    type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
                                                 :
    type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
    """

    def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
        super().__init__()
        self.weights = nn.ParameterDict({
            str(degree_out): nn.Parameter(
                torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
            for degree_out, channels_out in fiber_out
        })

    def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        return {
            degree: weight @ features[degree]
            for degree, weight in self.weights.items()
        }
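A minimal usage sketch for LinearSE3 (illustrative, not part of the upload). It assumes the [num_nodes, channels, 2 * degree + 1] feature layout used throughout this repo; the node and channel counts are arbitrary.

import torch
from se3_transformer.model.fiber import Fiber

lin = LinearSE3(fiber_in=Fiber({0: 16, 1: 16}), fiber_out=Fiber({0: 32, 1: 8}))
feats = {'0': torch.randn(10, 16, 1),  # 10 nodes, 16 type-0 channels
         '1': torch.randn(10, 16, 3)}  # 10 nodes, 16 type-1 channels
out = lin(feats)
# each weight [C'_k, C_k] broadcasts over nodes: [32, 16] @ [10, 16, 1] -> [10, 32, 1]
assert out['0'].shape == (10, 32, 1) and out['1'].shape == (10, 8, 3)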
se3_transformer/model/layers/norm.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT


from typing import Dict

import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range

from se3_transformer.model.fiber import Fiber


class NormSE3(nn.Module):
    """
    Norm-based SE(3)-equivariant nonlinearity.

                 ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
    feature_in ──┤                                              * ──> feature_out
                 └──> feature_phase ────────────────────────────┘
    """

    NORM_CLAMP = 2 ** -24  # Minimum positive subnormal for FP16

    def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
        super().__init__()
        self.fiber = fiber
        self.nonlinearity = nonlinearity

        if len(set(fiber.channels)) == 1:
            # Fuse all the layer normalizations into a group normalization
            self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
        else:
            # Use multiple layer normalizations
            self.layer_norms = nn.ModuleDict({
                str(degree): nn.LayerNorm(channels)
                for degree, channels in fiber
            })

    def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        with nvtx_range('NormSE3'):
            output = {}
            if hasattr(self, 'group_norm'):
                # Compute per-degree norms of features
                norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                         for d in self.fiber.degrees]
                fused_norms = torch.cat(norms, dim=-2)

                # Transform the norms only
                new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
                new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)

                # Scale features to the new norms
                for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
                    output[str(d)] = features[str(d)] / norm * new_norm
            else:
                for degree, feat in features.items():
                    norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                    new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
                    output[degree] = new_norm * feat / norm

        return output
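An illustrative check of the property the NormSE3 docstring describes: only the norms pass through LayerNorm/ReLU, so feature directions are untouched. Equal channel counts are chosen deliberately to exercise the fused GroupNorm path; all sizes here are arbitrary.

import torch
from se3_transformer.model.fiber import Fiber

norm = NormSE3(Fiber({0: 8, 1: 8}))  # equal channels -> single fused GroupNorm
feats = {'0': torch.randn(10, 8, 1), '1': torch.randn(10, 8, 3)}
out = norm(feats)

# Output type-1 vectors stay parallel to the inputs wherever the rescaled norm is
# nonzero; the ReLU can zero out whole channels, where the cosine reads as 0.
cos = torch.cosine_similarity(out['1'], feats['1'], dim=-1)
print(cos)  # entries are 1.0 (direction preserved) or 0.0 (channel gated off)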
se3_transformer/model/layers/pooling.py
ADDED
@@ -0,0 +1,53 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

from typing import Dict, Literal

import torch.nn as nn
from dgl import DGLGraph
from dgl.nn.pytorch import AvgPooling, MaxPooling
from torch import Tensor


class GPooling(nn.Module):
    """
    Graph max/average pooling on a given feature type.
    The average can be taken for any feature type, and equivariance will be maintained.
    The maximum can only be taken for invariant features (type 0).
    If you want max-pooling for type > 0 features, look into Vector Neurons.
    """

    def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
        """
        :param feat_type: Feature type to pool
        :param pool: Type of pooling: max or avg
        """
        super().__init__()
        assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
        assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
        self.feat_type = feat_type
        self.pool = MaxPooling() if pool == 'max' else AvgPooling()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
        pooled = self.pool(graph, features[str(self.feat_type)])
        return pooled.squeeze(dim=-1)
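A toy usage sketch (illustrative): max-pooling invariant (type-0) features of a random DGL graph down to one vector per graph; the graph and feature sizes are made up.

import dgl
import torch

pool = GPooling(feat_type=0, pool='max')
g = dgl.rand_graph(6, 12)                # hypothetical graph: 6 nodes, 12 edges
feats = {'0': torch.randn(6, 4, 1)}      # 4 invariant channels per node
pooled = pool(feats, graph=g)
assert pooled.shape == (1, 4)            # one pooled vector for the single graph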
se3_transformer/model/transformer.py
ADDED
@@ -0,0 +1,222 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import logging
from typing import Optional, Literal, Dict

import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor

from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.layers.attention import AttentionBlockSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool
from se3_transformer.model.fiber import Fiber


class Sequential(nn.Sequential):
    """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """

    def forward(self, input, *args, **kwargs):
        for module in self:
            input = module(input, *args, **kwargs)
        return input


def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
    """ Add relative positions to existing edge features """
    edge_features = edge_features.copy() if edge_features else {}
    r = relative_pos.norm(dim=-1, keepdim=True)
    if '0' in edge_features:
        edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
    else:
        edge_features['0'] = r[..., None]

    return edge_features


class SE3Transformer(nn.Module):
    def __init__(self,
                 num_layers: int,
                 fiber_in: Fiber,
                 fiber_hidden: Fiber,
                 fiber_out: Fiber,
                 num_heads: int,
                 channels_div: int,
                 fiber_edge: Fiber = Fiber({}),
                 return_type: Optional[int] = None,
                 pooling: Optional[Literal['avg', 'max']] = None,
                 norm: bool = True,
                 use_layer_norm: bool = True,
                 tensor_cores: bool = False,
                 low_memory: bool = False,
                 **kwargs):
        """
        :param num_layers:     Number of attention layers
        :param fiber_in:       Input fiber description
        :param fiber_hidden:   Hidden fiber description
        :param fiber_out:      Output fiber description
        :param fiber_edge:     Input edge fiber description
        :param num_heads:      Number of attention heads
        :param channels_div:   Channels division before feeding to attention layer
        :param return_type:    Return only features of this type
        :param pooling:        'avg' or 'max' graph pooling before MLP layers
        :param norm:           Apply a normalization layer after each attention block
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param tensor_cores:   True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
        :param low_memory:     If True, will use slower ops that use less memory
        """
        super().__init__()
        self.num_layers = num_layers
        self.fiber_edge = fiber_edge
        self.num_heads = num_heads
        self.channels_div = channels_div
        self.return_type = return_type
        self.pooling = pooling
        self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
        self.tensor_cores = tensor_cores
        self.low_memory = low_memory

        if low_memory and not tensor_cores:
            logging.warning('Low memory mode will have no effect with no Tensor Cores')

        # Fully fused convolutions when using Tensor Cores (and not low memory mode)
        fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL

        graph_modules = []
        for i in range(num_layers):
            graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
                                                   fiber_out=fiber_hidden,
                                                   fiber_edge=fiber_edge,
                                                   num_heads=num_heads,
                                                   channels_div=channels_div,
                                                   use_layer_norm=use_layer_norm,
                                                   max_degree=self.max_degree,
                                                   fuse_level=fuse_level))
            if norm:
                graph_modules.append(NormSE3(fiber_hidden))
            fiber_in = fiber_hidden

        graph_modules.append(ConvSE3(fiber_in=fiber_in,
                                     fiber_out=fiber_out,
                                     fiber_edge=fiber_edge,
                                     self_interaction=True,
                                     use_layer_norm=use_layer_norm,
                                     max_degree=self.max_degree))
        self.graph_modules = Sequential(*graph_modules)

        if pooling is not None:
            assert return_type is not None, 'return_type must be specified when pooling'
            self.pooling_module = GPooling(pool=pooling, feat_type=return_type)

    def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
                edge_feats: Optional[Dict[str, Tensor]] = None,
                basis: Optional[Dict[str, Tensor]] = None):
        # Compute bases in case they weren't precomputed as part of the data loading
        basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
                                   use_pad_trick=self.tensor_cores and not self.low_memory,
                                   amp=torch.is_autocast_enabled())

        # Add fused bases (per output degree, per input degree, and fully fused) to the dict
        basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
                                        fully_fused=self.tensor_cores and not self.low_memory)

        edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)

        node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)

        if self.pooling is not None:
            return self.pooling_module(node_feats, graph=graph)

        if self.return_type is not None:
            return node_feats[str(self.return_type)]

        return node_feats

    @staticmethod
    def add_argparse_args(parser):
        parser.add_argument('--num_layers', type=int, default=7,
                            help='Number of stacked Transformer layers')
        parser.add_argument('--num_heads', type=int, default=8,
                            help='Number of heads in self-attention')
        parser.add_argument('--channels_div', type=int, default=2,
                            help='Channels division before feeding to attention layer')
        parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
                            help='Type of graph pooling')
        parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply a normalization layer after each attention block')
        parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply layer normalization between MLP layers')
        parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
                            help='If true, will use fused ops that are slower but that use less memory '
                                 '(expect 25 percent less memory). '
                                 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')

        return parser


class SE3TransformerPooled(nn.Module):
    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Fiber,
                 num_degrees: int,
                 num_channels: int,
                 output_dim: int,
                 **kwargs):
        super().__init__()
        kwargs['pooling'] = kwargs['pooling'] or 'max'
        self.transformer = SE3Transformer(
            fiber_in=fiber_in,
            fiber_hidden=Fiber.create(num_degrees, num_channels),
            fiber_out=fiber_out,
            fiber_edge=fiber_edge,
            return_type=0,
            **kwargs
        )

        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(n_out_features, n_out_features),
            nn.ReLU(),
            nn.Linear(n_out_features, output_dim)
        )

    def forward(self, graph, node_feats, edge_feats, basis=None):
        feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
        y = self.mlp(feats).squeeze(-1)
        return y

    @staticmethod
    def add_argparse_args(parent_parser):
        parser = parent_parser.add_argument_group("Model architecture")
        SE3Transformer.add_argparse_args(parser)
        parser.add_argument('--num_degrees',
                            help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
                            type=int, default=4)
        parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
        return parent_parser
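For orientation, a sketch of constructing the model defined above directly (all values here are arbitrary; nothing is prescribed by the commit). The forward pass then expects a DGLGraph whose edata['rel_pos'] holds relative positions, plus per-degree node and edge feature dicts keyed by degree string.

from se3_transformer.model.fiber import Fiber

model = SE3Transformer(
    num_layers=2,
    fiber_in=Fiber({0: 6}),            # e.g. 6 invariant node channels
    fiber_hidden=Fiber.create(3, 16),  # degrees 0..2, 16 channels each
    fiber_out=Fiber({0: 16}),
    num_heads=4,
    channels_div=2,
    fiber_edge=Fiber({0: 4}),          # 4 invariant edge channels
    return_type=0,
    pooling='avg',
)
# out = model(graph, node_feats={'0': ...}, edge_feats={'0': ...})  # pooled type-0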
se3_transformer/runtime/__init__.py
ADDED
File without changes
se3_transformer/runtime/arguments.py
ADDED
@@ -0,0 +1,70 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import argparse
import pathlib

from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.runtime.utils import str2bool

PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')

paths = PARSER.add_argument_group('Paths')
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
                   help='Directory where the data is located or should be downloaded')
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
                   help='Directory where the results logs should be saved')
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
                   help='Name for the resulting DLLogger JSON file')
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
                   help='File where the checkpoint should be saved')
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
                   help='File of the checkpoint to be loaded')

optimizer = PARSER.add_argument_group('Optimizer')
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
optimizer.add_argument('--momentum', type=float, default=0.9)
optimizer.add_argument('--weight_decay', type=float, default=0.1)

PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')

PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
                    help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
                    help='Minimize stdout output')

PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
                    help='Benchmark mode')

QM9DataModule.add_argparse_args(PARSER)
SE3TransformerPooled.add_argparse_args(PARSER)
se3_transformer/runtime/callbacks.py
ADDED
@@ -0,0 +1,160 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import logging
import time
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import torch

from se3_transformer.runtime.loggers import Logger
from se3_transformer.runtime.metrics import MeanAbsoluteError


class BaseCallback(ABC):
    def on_fit_start(self, optimizer, args):
        pass

    def on_fit_end(self):
        pass

    def on_epoch_end(self):
        pass

    def on_batch_start(self):
        pass

    def on_validation_step(self, input, target, pred):
        pass

    def on_validation_end(self, epoch=None):
        pass

    def on_checkpoint_load(self, checkpoint):
        pass

    def on_checkpoint_save(self, checkpoint):
        pass


class LRSchedulerCallback(BaseCallback):
    def __init__(self, logger: Optional[Logger] = None):
        self.logger = logger
        self.scheduler = None

    @abstractmethod
    def get_scheduler(self, optimizer, args):
        pass

    def on_fit_start(self, optimizer, args):
        self.scheduler = self.get_scheduler(optimizer, args)

    def on_checkpoint_load(self, checkpoint):
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    def on_checkpoint_save(self, checkpoint):
        checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

    def on_epoch_end(self):
        if self.logger is not None:
            self.logger.log_metrics({'learning rate': self.scheduler.get_last_lr()[0]}, step=self.scheduler.last_epoch)
        self.scheduler.step()


class QM9MetricCallback(BaseCallback):
    """ Logs the rescaled mean absolute error for QM9 regression tasks """

    def __init__(self, logger, targets_std, prefix=''):
        self.mae = MeanAbsoluteError()
        self.logger = logger
        self.targets_std = targets_std
        self.prefix = prefix
        self.best_mae = float('inf')

    def on_validation_step(self, input, target, pred):
        self.mae(pred.detach(), target.detach())

    def on_validation_end(self, epoch=None):
        mae = self.mae.compute() * self.targets_std
        logging.info(f'{self.prefix} MAE: {mae}')
        self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
        self.best_mae = min(self.best_mae, mae)

    def on_fit_end(self):
        if self.best_mae != float('inf'):
            self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})


class QM9LRSchedulerCallback(LRSchedulerCallback):
    def __init__(self, logger, epochs):
        super().__init__(logger)
        self.epochs = epochs

    def get_scheduler(self, optimizer, args):
        min_lr = args.min_learning_rate if args.min_learning_rate else args.learning_rate / 10.0
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)


class PerformanceCallback(BaseCallback):
    def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
        self.batch_size = batch_size
        self.warmup_epochs = warmup_epochs
        self.epoch = 0
        self.timestamps = []
        self.mode = mode
        self.logger = logger

    def on_batch_start(self):
        if self.epoch >= self.warmup_epochs:
            self.timestamps.append(time.time() * 1000.0)

    def _log_perf(self):
        stats = self.process_performance_stats()
        for k, v in stats.items():
            logging.info(f'performance {k}: {v}')

        self.logger.log_metrics(stats)

    def on_epoch_end(self):
        self.epoch += 1

    def on_fit_end(self):
        if self.epoch > self.warmup_epochs:
            self._log_perf()
            self.timestamps = []

    def process_performance_stats(self):
        timestamps = np.asarray(self.timestamps)
        deltas = np.diff(timestamps)
        throughput = (self.batch_size / deltas).mean()
        stats = {
            f"throughput_{self.mode}": throughput,
            f"latency_{self.mode}_mean": deltas.mean(),
            f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
        }
        for level in [90, 95, 99]:
            stats.update({f"latency_{self.mode}_{level}": np.percentile(deltas, level)})

        return stats
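A worked example of what process_performance_stats reports (illustrative). Timestamps are recorded in milliseconds (time.time() * 1000), so the throughput figure is samples per millisecond.

cb = PerformanceCallback(logger=None, batch_size=8, warmup_epochs=0, mode='train')
cb.timestamps = [0.0, 100.0, 200.0, 300.0]  # four batch starts, 100 ms apart
stats = cb.process_performance_stats()
print(stats['throughput_train'])    # 0.08 samples/ms, i.e. 80 samples/s
print(stats['latency_train_mean'])  # 100.0 ms between consecutive batch starts
print(stats['total_time_train'])    # 300.0 ms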
se3_transformer/runtime/gpu_affinity.py
ADDED
@@ -0,0 +1,325 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import collections
import itertools
import math
import os
import pathlib
import re

import pynvml


class Device:
    # assumes nvml returns a list of 64-bit ints
    _nvml_affinity_elements = math.ceil(os.cpu_count() / 64)

    def __init__(self, device_idx):
        super().__init__()
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)

    def get_name(self):
        return pynvml.nvmlDeviceGetName(self.handle)

    def get_uuid(self):
        return pynvml.nvmlDeviceGetUUID(self.handle)

    def get_cpu_affinity(self):
        affinity_string = ""
        for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
            # assume nvml returns a list of 64-bit ints
            affinity_string = "{:064b}".format(j) + affinity_string

        affinity_list = [int(x) for x in affinity_string]
        affinity_list.reverse()  # so core 0 is in the 0th element of the list

        ret = [i for i, e in enumerate(affinity_list) if e != 0]
        return ret


def get_thread_siblings_list():
    """
    Returns a list of 2-element integer tuples representing pairs of
    hyperthreading cores.
    """
    path = "/sys/devices/system/cpu/cpu*/topology/thread_siblings_list"
    thread_siblings_list = []
    pattern = re.compile(r"(\d+)\D(\d+)")
    for fname in pathlib.Path(path[0]).glob(path[1:]):
        with open(fname) as f:
            content = f.read().strip()
            res = pattern.findall(content)
            if res:
                pair = tuple(map(int, res[0]))
                thread_siblings_list.append(pair)
    return thread_siblings_list


def check_socket_affinities(socket_affinities):
    # sets of cores should be either identical or disjoint
    for i, j in itertools.product(socket_affinities, socket_affinities):
        if not set(i) == set(j) and not set(i).isdisjoint(set(j)):
            raise RuntimeError(f"Sets of cores should be either identical or disjoint, " f"but got {i} and {j}.")


def get_socket_affinities(nproc_per_node, exclude_unavailable_cores=True):
    devices = [Device(i) for i in range(nproc_per_node)]
    socket_affinities = [dev.get_cpu_affinity() for dev in devices]

    if exclude_unavailable_cores:
        available_cores = os.sched_getaffinity(0)
        socket_affinities = [list(set(affinity) & available_cores) for affinity in socket_affinities]

    check_socket_affinities(socket_affinities)

    return socket_affinities


def set_socket_affinity(gpu_id):
    """
    The process is assigned all available logical CPU cores from the CPU
    socket connected to the GPU with the given id.

    Args:
        gpu_id: index of a GPU
    """
    dev = Device(gpu_id)
    affinity = dev.get_cpu_affinity()
    os.sched_setaffinity(0, affinity)


def set_single_affinity(gpu_id):
    """
    The process is assigned the first available logical CPU core from the
    list of all CPU cores of the CPU socket connected to the GPU with the given
    id.

    Args:
        gpu_id: index of a GPU
    """
    dev = Device(gpu_id)
    affinity = dev.get_cpu_affinity()

    # exclude unavailable cores
    available_cores = os.sched_getaffinity(0)
    affinity = list(set(affinity) & available_cores)
    os.sched_setaffinity(0, affinity[:1])


def set_single_unique_affinity(gpu_id, nproc_per_node):
    """
    The process is assigned a single unique available physical CPU core
    from the list of all CPU cores of the CPU socket connected to the GPU with
    the given id.

    Args:
        gpu_id: index of a GPU
    """
    socket_affinities = get_socket_affinities(nproc_per_node)

    siblings_list = get_thread_siblings_list()
    siblings_dict = dict(siblings_list)

    # remove siblings
    for idx, socket_affinity in enumerate(socket_affinities):
        socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))

    affinities = []
    assigned = []

    for socket_affinity in socket_affinities:
        for core in socket_affinity:
            if core not in assigned:
                affinities.append([core])
                assigned.append(core)
                break
    os.sched_setaffinity(0, affinities[gpu_id])


def set_socket_unique_affinity(gpu_id, nproc_per_node, mode, balanced=True):
    """
    The process is assigned a unique subset of available physical CPU
    cores from the CPU socket connected to the GPU with the given id.
    Assignment automatically includes hyperthreading siblings (if siblings are
    available).

    Args:
        gpu_id: index of a GPU
        nproc_per_node: total number of processes per node
        mode: core indexing pattern, 'interleaved' or 'continuous'
        balanced: assign an equal number of physical cores to each process
    """
    socket_affinities = get_socket_affinities(nproc_per_node)

    siblings_list = get_thread_siblings_list()
    siblings_dict = dict(siblings_list)

    # remove hyperthreading siblings
    for idx, socket_affinity in enumerate(socket_affinities):
        socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))

    socket_affinities_to_device_ids = collections.defaultdict(list)

    for idx, socket_affinity in enumerate(socket_affinities):
        socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)

    # compute the minimal number of physical cores per GPU across all GPUs and
    # sockets; the code assigns this number of cores per GPU if balanced == True
    min_physical_cores_per_gpu = min(
        [len(cores) // len(gpus) for cores, gpus in socket_affinities_to_device_ids.items()]
    )

    for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
        devices_per_group = len(device_ids)
        if balanced:
            cores_per_device = min_physical_cores_per_gpu
            socket_affinity = socket_affinity[: devices_per_group * min_physical_cores_per_gpu]
        else:
            cores_per_device = len(socket_affinity) // devices_per_group

        for group_id, device_id in enumerate(device_ids):
            if device_id == gpu_id:

                # In theory there should be no difference in performance between
                # the 'interleaved' and 'continuous' patterns on Intel-based DGX-1,
                # but 'continuous' should be better for DGX A100 because on AMD
                # Rome 4 consecutive cores share an L3 cache.
                # TODO: the code doesn't attempt to automatically detect the layout
                # of the L3 cache; also, the external environment may already
                # exclude some cores, and this code makes no attempt to detect that
                # and align the mapping to multiples of 4.

                if mode == "interleaved":
                    affinity = list(socket_affinity[group_id::devices_per_group])
                elif mode == "continuous":
                    affinity = list(socket_affinity[group_id * cores_per_device: (group_id + 1) * cores_per_device])
                else:
                    raise RuntimeError("Unknown set_socket_unique_affinity mode")

                # unconditionally reintroduce hyperthreading siblings; this step
                # may result in different numbers of logical cores assigned to
                # each GPU even if balanced == True (if hyperthreading siblings
                # aren't available for a subset of cores due to some external
                # constraints, siblings are re-added unconditionally; in the
                # worst case an unavailable logical core will be ignored by
                # os.sched_setaffinity())
                affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
                os.sched_setaffinity(0, affinity)


def set_affinity(gpu_id, nproc_per_node, mode="socket_unique_continuous", balanced=True):
    """
    The process is assigned a CPU affinity that matches the hardware
    architecture of the given platform. This usually improves and stabilizes
    the performance of deep learning training workloads.

    This function assumes that the workload is running in multi-process
    single-device mode (there are multiple training processes and each process
    is running on a single GPU), which is typical for multi-GPU training
    workloads using `torch.nn.parallel.DistributedDataParallel`.

    Available affinity modes:
    * 'socket' - the process is assigned all available logical CPU cores
    from the CPU socket connected to the GPU with the given id.
    * 'single' - the process is assigned the first available logical CPU
    core from the list of all CPU cores of the CPU socket connected to the GPU
    with the given id (multiple GPUs could be assigned the same CPU core).
    * 'single_unique' - the process is assigned a single unique available
    physical CPU core from the list of all CPU cores of the CPU socket
    connected to the GPU with the given id.
    * 'socket_unique_interleaved' - the process is assigned a unique
    subset of available physical CPU cores from the CPU socket connected to the
    GPU with the given id; hyperthreading siblings are included automatically
    and cores are assigned with an interleaved indexing pattern.
    * 'socket_unique_continuous' - (the default) the process is assigned a
    unique subset of available physical CPU cores from the CPU socket connected
    to the GPU with the given id; hyperthreading siblings are included
    automatically and cores are assigned with a continuous indexing pattern.

    'socket_unique_continuous' is the recommended mode for deep learning
    training workloads on NVIDIA DGX machines.

    Args:
        gpu_id: integer index of a GPU
        nproc_per_node: number of processes per node
        mode: affinity mode
        balanced: assign an equal number of physical cores to each process;
            affects only the 'socket_unique_interleaved' and
            'socket_unique_continuous' affinity modes

    Returns a set of logical CPU cores on which the process is eligible to run.

    Example:

    import argparse
    import os

    import gpu_affinity
    import torch


    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--local_rank',
            type=int,
            default=os.getenv('LOCAL_RANK', 0),
        )
        args = parser.parse_args()

        nproc_per_node = torch.cuda.device_count()

        affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node)
        print(f'{args.local_rank}: core affinity: {affinity}')


    if __name__ == "__main__":
        main()

    Launch the example with:
    python -m torch.distributed.launch --nproc_per_node <#GPUs> example.py


    WARNING: On DGX A100 only half of the CPU cores have direct access to GPUs.
    This function restricts execution to the CPU cores directly connected
    to GPUs, so on DGX A100 it will limit the code to half of the CPU cores and
    half of the CPU memory bandwidth (which may be fine for many DL models).
    """
    pynvml.nvmlInit()

    if mode == "socket":
        set_socket_affinity(gpu_id)
    elif mode == "single":
        set_single_affinity(gpu_id)
    elif mode == "single_unique":
        set_single_unique_affinity(gpu_id, nproc_per_node)
    elif mode == "socket_unique_interleaved":
        set_socket_unique_affinity(gpu_id, nproc_per_node, "interleaved", balanced)
    elif mode == "socket_unique_continuous":
        set_socket_unique_affinity(gpu_id, nproc_per_node, "continuous", balanced)
    else:
        raise RuntimeError("Unknown affinity mode")

    affinity = os.sched_getaffinity(0)
    return affinity
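For reference, a sketch of the file contents get_thread_siblings_list expects: on a hyperthreaded Linux host each thread_siblings_list file holds something like '3,67' (or '3-67'), and the permissive \D separator in the regex above matches either form.

import re

pattern = re.compile(r"(\d+)\D(\d+)")
for content in ("3,67", "3-67"):
    print(tuple(map(int, pattern.findall(content)[0])))  # (3, 67) both times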
se3_transformer/runtime/inference.py
ADDED
@@ -0,0 +1,131 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

from typing import List

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm

from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import BaseCallback
from se3_transformer.runtime.loggers import DLLogger
from se3_transformer.runtime.utils import to_cuda, get_local_rank


@torch.inference_mode()
def evaluate(model: nn.Module,
             dataloader: DataLoader,
             callbacks: List[BaseCallback],
             args):
    model.eval()
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), unit='batch', desc='Evaluation',
                         leave=False, disable=(args.silent or get_local_rank() != 0)):
        *input, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*input)

        for callback in callbacks:
            callback.on_validation_step(input, target, pred)


if __name__ == '__main__':
    from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
    from se3_transformer.runtime.utils import init_distributed, seed_everything
    from se3_transformer.model import SE3TransformerPooled, Fiber
    from se3_transformer.data_loading import QM9DataModule
    import torch.distributed as dist
    import logging
    import sys

    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('|  Inference on the test set  |')
    logging.info('===============================')

    if not args.benchmark and args.load_ckpt_path is None:
        logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
        sys.exit(1)

    if args.benchmark:
        logging.info('Running benchmark mode with one warmup pass')

    if args.seed is not None:
        seed_everything(args.seed)

    major_cc, minor_cc = torch.cuda.get_device_capability()

    logger = DLLogger(args.log_dir, filename=args.dllogger_name)
    datamodule = QM9DataModule(**vars(args))
    model = SE3TransformerPooled(
        fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
        fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=(args.amp and major_cc >= 7) or major_cc >= 8,  # use Tensor Cores more effectively
        **vars(args)
    )
    callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='test')]

    model.to(device=torch.cuda.current_device())
    if args.load_ckpt_path is not None:
        checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
        model.load_state_dict(checkpoint['state_dict'])

    if is_distributed:
        nproc_per_node = torch.cuda.device_count()
        affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
        model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
    evaluate(model,
             test_dataloader,
             callbacks,
             args)

    for callback in callbacks:
        callback.on_validation_end()

    if args.benchmark:
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
        for _ in range(6):
            evaluate(model,
                     test_dataloader,
                     callbacks,
                     args)
            callbacks[0].on_epoch_end()

        callbacks[0].on_fit_end()
se3_transformer/runtime/loggers.py
ADDED
@@ -0,0 +1,134 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import pathlib
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Any, Callable, Optional

import dllogger
import torch.distributed as dist
import wandb
from dllogger import Verbosity

from se3_transformer.runtime.utils import rank_zero_only


class Logger(ABC):
    @rank_zero_only
    @abstractmethod
    def log_hyperparams(self, params):
        pass

    @rank_zero_only
    @abstractmethod
    def log_metrics(self, metrics, step=None):
        pass

    @staticmethod
    def _sanitize_params(params):
        def _sanitize(val):
            if isinstance(val, Callable):
                try:
                    _val = val()
                    if isinstance(_val, Callable):
                        return val.__name__
                    return _val
                except Exception:
                    return getattr(val, "__name__", None)
            elif isinstance(val, pathlib.Path) or isinstance(val, Enum):
                return str(val)
            return val

        return {key: _sanitize(val) for key, val in params.items()}


class LoggerCollection(Logger):
    def __init__(self, loggers):
        super().__init__()
        self.loggers = loggers

    def __getitem__(self, index):
        return [logger for logger in self.loggers][index]

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        for logger in self.loggers:
            logger.log_metrics(metrics, step)

    @rank_zero_only
    def log_hyperparams(self, params):
        for logger in self.loggers:
            logger.log_hyperparams(params)


class DLLogger(Logger):
    def __init__(self, save_dir: pathlib.Path, filename: str):
        super().__init__()
        if not dist.is_initialized() or dist.get_rank() == 0:
            save_dir.mkdir(parents=True, exist_ok=True)
            dllogger.init(
                backends=[dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))])

    @rank_zero_only
    def log_hyperparams(self, params):
        params = self._sanitize_params(params)
        dllogger.log(step="PARAMETER", data=params)

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        if step is None:
            step = tuple()

        dllogger.log(step=step, data=metrics)


class WandbLogger(Logger):
    def __init__(
            self,
            name: str,
            save_dir: pathlib.Path,
            id: Optional[str] = None,
            project: Optional[str] = None
    ):
        super().__init__()
        if not dist.is_initialized() or dist.get_rank() == 0:
            save_dir.mkdir(parents=True, exist_ok=True)
            self.experiment = wandb.init(name=name,
                                         project=project,
                                         id=id,
                                         dir=str(save_dir),
                                         resume='allow',
                                         anonymous='must')

    @rank_zero_only
    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        params = self._sanitize_params(params)
        self.experiment.config.update(params, allow_val_change=True)

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        if step is not None:
            self.experiment.log({**metrics, 'epoch': step})
        else:
            self.experiment.log(metrics)
se3_transformer/runtime/metrics.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

from abc import ABC, abstractmethod

import torch
import torch.distributed as dist
from torch import Tensor


class Metric(ABC):
    """ Metric class with synchronization capabilities similar to TorchMetrics """

    def __init__(self):
        self.states = {}

    def add_state(self, name: str, default: Tensor):
        assert name not in self.states
        self.states[name] = default.clone()
        setattr(self, name, default)

    def synchronize(self):
        if dist.is_initialized():
            for state in self.states:
                dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)

    def __call__(self, *args, **kwargs):
        self.update(*args, **kwargs)

    def reset(self):
        for name, default in self.states.items():
            setattr(self, name, default.clone())

    def compute(self):
        self.synchronize()
        value = self._compute().item()
        self.reset()
        return value

    @abstractmethod
    def _compute(self):
        pass

    @abstractmethod
    def update(self, preds: Tensor, targets: Tensor):
        pass


class MeanAbsoluteError(Metric):
    def __init__(self):
        super().__init__()
        self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda'))
        self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))

    def update(self, preds: Tensor, targets: Tensor):
        preds = preds.detach()
        n = preds.shape[0]
        error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
        self.total += n
        self.error += error

    def _compute(self):
        return self.error / self.total
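New metrics follow the same three-step recipe as MeanAbsoluteError: register accumulator tensors with add_state, fold each batch into them in update, and reduce in _compute (cross-rank synchronization is handled by the base class). A minimal sketch of a mean-squared-error variant under the same assumptions as the built-in metric (a CUDA device is available):

import torch
from torch import Tensor

from se3_transformer.runtime.metrics import Metric


class MeanSquaredError(Metric):
    """ A minimal sketch of a custom metric, mirroring MeanAbsoluteError above. """

    def __init__(self):
        super().__init__()
        # Like the built-in metric, states live on 'cuda' so that synchronize()
        # can all-reduce them across ranks; assumes a CUDA device is available.
        self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda'))
        self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))

    def update(self, preds: Tensor, targets: Tensor):
        preds = preds.detach()
        n = preds.shape[0]
        self.error += torch.square(preds.view(n, -1) - targets.view(n, -1)).sum()
        self.total += n

    def _compute(self):
        return self.error / self.total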
se3_transformer/runtime/training.py
ADDED
@@ -0,0 +1,238 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import logging
import pathlib
from typing import List

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
    PerformanceCallback
from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
    using_tensor_cores, increase_l2_fetch_granularity


def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Saves model, optimizer and epoch states to path (only once per node) """
    if get_local_rank() == 0:
        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
        checkpoint = {
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        for callback in callbacks:
            callback.on_checkpoint_save(checkpoint)

        torch.save(checkpoint, str(path))
        logging.info(f'Saved checkpoint to {str(path)}')


def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Loads model, optimizer and epoch states from path """
    checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
    if isinstance(model, DistributedDataParallel):
        model.module.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for callback in callbacks:
        callback.on_checkpoint_load(checkpoint)

    logging.info(f'Loaded checkpoint from {str(path)}')
    return checkpoint['epoch']


def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
    losses = []
    for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
                         desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
        *inputs, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*inputs)
            loss = loss_fn(pred, target) / args.accumulate_grad_batches

        grad_scaler.scale(loss).backward()

        # gradient accumulation
        if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
            if args.gradient_clip:
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)

            grad_scaler.step(optimizer)
            grad_scaler.update()
            optimizer.zero_grad()

        losses.append(loss.item())

    return np.mean(losses)


def train(model: nn.Module,
          loss_fn: _Loss,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          callbacks: List[BaseCallback],
          logger: Logger,
          args):
    device = torch.cuda.current_device()
    model.to(device=device)
    local_rank = get_local_rank()
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    model.train()
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
        if dist.is_initialized():
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            torch.distributed.all_reduce(loss)
            loss = (loss / world_size).item()

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)

        for callback in callbacks:
            callback.on_epoch_end()

        if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
                and (epoch_idx + 1) % args.ckpt_interval == 0:
            save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)

        if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
            evaluate(model, val_dataloader, callbacks, args)
            model.train()

            for callback in callbacks:
                callback.on_validation_end(epoch_idx)

    if args.save_ckpt_path is not None and not args.benchmark:
        save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)

    for callback in callbacks:
        callback.on_fit_end()


def print_parameters_count(model):
    num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f'Number of trainable parameters: {num_params_trainable}')


if __name__ == '__main__':
    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('| Training procedure |')
    logging.info('===============================')

    if args.seed is not None:
        logging.info(f'Using seed {args.seed}')
        seed_everything(args.seed)

    logger = LoggerCollection([
        DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
        WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
    ])

    datamodule = QM9DataModule(**vars(args))
    model = SE3TransformerPooled(
        fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
        fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=using_tensor_cores(args.amp),  # use Tensor Cores more effectively
        **vars(args)
    )
    loss_fn = nn.L1Loss()

    if args.benchmark:
        logging.info('Running benchmark mode')
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
    else:
        callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
                     QM9LRSchedulerCallback(logger, epochs=args.epochs)]

    if is_distributed:
        gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())

    print_parameters_count(model)
    logger.log_hyperparams(vars(args))
    increase_l2_fetch_granularity()
    train(model,
          loss_fn,
          datamodule.train_dataloader(),
          datamodule.val_dataloader(),
          callbacks,
          logger,
          args)

    logging.info('Training finished successfully')
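save_state and load_state above agree on a three-key checkpoint layout, which is what makes resuming via epoch_start work. A self-contained sketch of that round trip with a toy model follows; the model and the path are placeholders, only the dict keys come from the code above.

import torch
import torch.nn as nn

# Hedged sketch of the checkpoint layout written by save_state above.
# The toy model and '/tmp/ckpt.pth' are hypothetical; only the three keys
# ('state_dict', 'optimizer_state_dict', 'epoch') come from the code.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

checkpoint = {
    'state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 10,
}
torch.save(checkpoint, '/tmp/ckpt.pth')

restored = torch.load('/tmp/ckpt.pth')
model.load_state_dict(restored['state_dict'])
optimizer.load_state_dict(restored['optimizer_state_dict'])
epoch_start = restored['epoch']  # train() resumes its epoch loop from here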
se3_transformer/runtime/utils.py
ADDED
@@ -0,0 +1,130 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import argparse
import ctypes
import logging
import os
import random
from functools import wraps
from typing import Union, List, Dict

import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor


def aggregate_residual(feats1, feats2, method: str):
    """ Add or concatenate two fiber features together. If degrees don't match, will use the ones of feats2. """
    if method in ['add', 'sum']:
        return {k: (v + feats1[k]) if k in feats1 else v for k, v in feats2.items()}
    elif method in ['cat', 'concat']:
        return {k: torch.cat([v, feats1[k]], dim=1) if k in feats1 else v for k, v in feats2.items()}
    else:
        raise ValueError('Method must be add/sum or cat/concat')


def degree_to_dim(degree: int) -> int:
    return 2 * degree + 1


def unfuse_features(features: Tensor, degrees: List[int]) -> Dict[str, Tensor]:
    return dict(zip(map(str, degrees), features.split([degree_to_dim(deg) for deg in degrees], dim=-1)))


def str2bool(v: Union[bool, str]) -> bool:
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def to_cuda(x):
    """ Try to convert a Tensor, a collection of Tensors or a DGLGraph to CUDA """
    if isinstance(x, Tensor):
        return x.cuda(non_blocking=True)
    elif isinstance(x, tuple):
        return (to_cuda(v) for v in x)
    elif isinstance(x, list):
        return [to_cuda(v) for v in x]
    elif isinstance(x, dict):
        return {k: to_cuda(v) for k, v in x.items()}
    else:
        # DGLGraph or other objects
        return x.to(device=torch.cuda.current_device())


def get_local_rank() -> int:
    return int(os.environ.get('LOCAL_RANK', 0))


def init_distributed() -> bool:
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    distributed = world_size > 1
    if distributed:
        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(backend=backend, init_method='env://')
        if backend == 'nccl':
            torch.cuda.set_device(get_local_rank())
        else:
            logging.warning('Running on CPU only!')
        assert torch.distributed.is_initialized()
    return distributed


def increase_l2_fetch_granularity():
    # maximum fetch granularity of L2: 128 bytes
    _libcudart = ctypes.CDLL('libcudart.so')
    # set device limit on the current device
    # cudaLimitMaxL2FetchGranularity = 0x05
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    assert pValue.contents.value == 128


def seed_everything(seed):
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def rank_zero_only(fn):
    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if not dist.is_initialized() or dist.get_rank() == 0:
            return fn(*args, **kwargs)

    return wrapped_fn


def using_tensor_cores(amp: bool) -> bool:
    major_cc, minor_cc = torch.cuda.get_device_capability()
    return (amp and major_cc >= 7) or major_cc >= 8
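The degree helpers above encode the fact that a degree-d feature has 2d + 1 components, and unfuse_features splits a fused tensor back into a per-degree dict keyed by the stringified degree. A short sketch, assuming the package is importable (the two functions are exactly those defined above, so the snippet can also be run by inlining them); the tensor shape is a placeholder.

import torch

from se3_transformer.runtime.utils import degree_to_dim, unfuse_features

# Fused degrees [0, 1, 2] occupy 1 + 3 + 5 = 9 channels on the last dimension.
fused = torch.randn(10, 16, 9)  # (nodes, channels, sum of degree dims)
feats = unfuse_features(fused, [0, 1, 2])

assert feats['0'].shape[-1] == degree_to_dim(0)  # 1
assert feats['1'].shape[-1] == degree_to_dim(1)  # 3
assert feats['2'].shape[-1] == degree_to_dim(2)  # 5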