GlandVergil committed on
Commit
a507bdb
1 Parent(s): a93e585

Upload 23 files

Browse files
se3_transformer/__init__.py ADDED
File without changes
se3_transformer/data_loading/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .qm9 import QM9DataModule
se3_transformer/data_loading/data_module.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import torch.distributed as dist
25
+ from abc import ABC
26
+ from torch.utils.data import DataLoader, DistributedSampler, Dataset
27
+
28
+ from se3_transformer.runtime.utils import get_local_rank
29
+
30
+
31
+ def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:
32
+ # Classic or distributed dataloader depending on the context
33
+ sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None
34
+ return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs)
35
+
36
+
37
class DataModule(ABC):
    """ Abstract DataModule. Children must define self.ds_{train | val | test}. """

    def __init__(self, **dataloader_kwargs):
        """
        :param dataloader_kwargs: Keyword arguments forwarded to every DataLoader created by this module.
                                  They override the defaults set here (pin_memory, persistent_workers).
        """
        super().__init__()
        if get_local_rank() == 0:
            self.prepare_data()

        # Wait until rank zero has prepared the data (download, preprocessing, ...)
        if dist.is_initialized():
            dist.barrier(device_ids=[get_local_rank()])

        # DataLoader rejects persistent_workers=True when num_workers is 0 (its default),
        # so only enable persistent workers when worker processes are actually requested.
        has_workers = (dataloader_kwargs.get('num_workers') or 0) > 0
        self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': has_workers, **dataloader_kwargs}
        self.ds_train, self.ds_val, self.ds_test = None, None, None

    def prepare_data(self):
        """ Method called only once per node. Put here any downloading or preprocessing """
        pass

    def train_dataloader(self) -> DataLoader:
        """ DataLoader over self.ds_train, with shuffling enabled """
        return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs)

    def val_dataloader(self) -> DataLoader:
        """ DataLoader over self.ds_val, without shuffling """
        return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs)

    def test_dataloader(self) -> DataLoader:
        """ DataLoader over self.ds_test, without shuffling """
        return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs)
se3_transformer/data_loading/qm9.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+ from typing import Tuple
24
+
25
+ import dgl
26
+ import pathlib
27
+ import torch
28
+ from dgl.data import QM9EdgeDataset
29
+ from dgl import DGLGraph
30
+ from torch import Tensor
31
+ from torch.utils.data import random_split, DataLoader, Dataset
32
+ from tqdm import tqdm
33
+
34
+ from se3_transformer.data_loading.data_module import DataModule
35
+ from se3_transformer.model.basis import get_basis
36
+ from se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores
37
+
38
+
39
def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
    """ Return, for every edge of the graph, the destination node's 'pos' minus the source node's 'pos' """
    positions = qm9_graph.ndata['pos']
    src_ids, dst_ids = qm9_graph.edges()
    return positions[dst_ids] - positions[src_ids]
44
+
45
+
46
+ def _get_split_sizes(full_dataset: Dataset) -> Tuple[int, int, int]:
47
+ len_full = len(full_dataset)
48
+ len_train = 100_000
49
+ len_test = int(0.1 * len_full)
50
+ len_val = len_full - len_train - len_test
51
+ return len_train, len_val, len_test
52
+
53
+
54
class QM9DataModule(DataModule):
    """
    Datamodule wrapping https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset
    Training set is 100k molecules. Test set is 10% of the dataset. Validation set is the rest.
    This includes all the molecules from QM9 except the ones that are uncharacterized.
    """

    NODE_FEATURE_DIM = 6
    EDGE_FEATURE_DIM = 4

    def __init__(self,
                 data_dir: pathlib.Path,
                 task: str = 'homo',
                 batch_size: int = 240,
                 num_workers: int = 8,
                 num_degrees: int = 4,
                 amp: bool = False,
                 precompute_bases: bool = False,
                 **kwargs):
        """
        :param data_dir:         Directory passed to the DGL dataset as raw_dir
        :param task:             Label key of the regression target
        :param batch_size:       Batch size of the dataloaders (also used when precomputing bases)
        :param num_workers:      Number of dataloader workers
        :param num_degrees:      Number of degrees; bases are precomputed up to degree num_degrees - 1
        :param amp:              Forwarded to the bases computation when precompute_bases is set
        :param precompute_bases: Use CachedBasesQM9EdgeDataset instead of the plain QM9EdgeDataset
        :param kwargs:           Extra arguments, accepted and ignored
        """
        self.data_dir = data_dir  # This needs to be before __init__ so that prepare_data has access to it
        super().__init__(batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
        self.amp = amp
        self.task = task
        self.batch_size = batch_size
        self.num_degrees = num_degrees

        qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
        if precompute_bases:
            bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
            full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
                                                     num_workers=num_workers, **qm9_kwargs)
        else:
            full_dataset = QM9EdgeDataset(**qm9_kwargs)

        # Fixed seed so that the split is the same on every run
        self.ds_train, self.ds_val, self.ds_test = random_split(full_dataset, _get_split_sizes(full_dataset),
                                                                generator=torch.Generator().manual_seed(0))

        # Normalization statistics are computed on the training targets only
        train_targets = full_dataset.targets[self.ds_train.indices, full_dataset.label_keys[0]]
        self.targets_mean = train_targets.mean()
        self.targets_std = train_targets.std()

    def prepare_data(self):
        # Download the QM9 preprocessed data
        QM9EdgeDataset(verbose=True, raw_dir=str(self.data_dir))

    def _collate(self, samples):
        """
        Collate samples of the form (graph, label) or (graph, label, bases) into a batch.
        :return: (batched_graph, node_feats, edge_feats, targets), with the concatenated bases inserted
                 before the targets when the samples carry them
        """
        graphs, y, *bases = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
        # get node features (class constant instead of a duplicated magic number)
        node_feats = {'0': batched_graph.ndata['attr'][:, :self.NODE_FEATURE_DIM, None]}
        # Standardize the targets with the training-set statistics
        targets = (torch.cat(y) - self.targets_mean) / self.targets_std

        if bases:
            # collate bases
            all_bases = {
                key: torch.cat([b[key] for b in bases[0]], dim=0)
                for key in bases[0][0].keys()
            }

            return batched_graph, node_feats, edge_feats, all_bases, targets
        else:
            return batched_graph, node_feats, edge_feats, targets

    @staticmethod
    def add_argparse_args(parent_parser):
        """ Register the QM9-specific command line arguments on parent_parser and return it """
        parser = parent_parser.add_argument_group("QM9 dataset")
        parser.add_argument('--task', type=str, default='homo', const='homo', nargs='?',
                            choices=['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
                                     'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'],
                            help='Regression task to train on')
        parser.add_argument('--precompute_bases', type=str2bool, nargs='?', const=True, default=False,
                            help='Precompute bases at the beginning of the script during dataset initialization,'
                                 ' instead of computing them at the beginning of each forward pass.')
        return parent_parser

    def __repr__(self):
        return f'QM9({self.task})'
133
+
134
+
135
class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
    """ QM9 dataset from DGL, extended with pairwise bases that are precomputed and kept in RAM """

    def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
        """
        :param bases_kwargs: Arguments to feed the bases computation function
        :param batch_size:   Batch size to use when iterating over the dataset for computing bases
        :param num_workers:  Number of DataLoader workers to use for that iteration
        """
        self.bases = None
        self.bases_kwargs = bases_kwargs
        self.batch_size = batch_size
        self.num_workers = num_workers
        # All attributes are set before the parent constructor runs
        super().__init__(*args, **kwargs)

    def load(self):
        super().load()

        def _batch_graphs_only(samples):
            # Only the graph (first element of each sample) is needed to compute bases
            return dgl.batch([sample[0] for sample in samples])

        # Iterate through the dataset in order and compute the pairwise bases of every batch
        # Potential improvement: use multi-GPU and gather
        loader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
                            collate_fn=_batch_graphs_only)
        progress = tqdm(enumerate(loader), total=len(loader), desc='Precomputing QM9 bases',
                        disable=get_local_rank() != 0)
        computed = []
        for _, batched_graph in progress:
            rel_pos = _get_relative_pos(batched_graph)
            # Bases are computed on the GPU, then moved to the CPU to be stored in RAM
            gpu_bases = get_basis(rel_pos.cuda(), **self.bases_kwargs)
            computed.append({name: tensor.cpu() for name, tensor in gpu_bases.items()})
        # Only assigned once complete, so __getitem__ returns plain samples during the loop above
        self.bases = computed

    def __getitem__(self, idx: int):
        graph, label = super().__getitem__(idx)
        if not self.bases:
            return graph, label

        # Bases are stored per precomputation batch: find the batch, then the edge range inside it
        chunk = idx // self.batch_size
        lo = self.ne_cumsum[idx] - self.ne_cumsum[chunk * self.batch_size]
        hi = self.ne_cumsum[idx + 1] - self.ne_cumsum[chunk * self.batch_size]
        sample_bases = {key: basis[lo:hi] for key, basis in self.bases[chunk].items()}
        return graph, label, sample_bases
se3_transformer/model/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .transformer import SE3Transformer, SE3TransformerPooled
2
+ from .fiber import Fiber
se3_transformer/model/basis.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+
25
+ from functools import lru_cache
26
+ from typing import Dict, List
27
+
28
+ import e3nn.o3 as o3
29
+ import torch
30
+ import torch.nn.functional as F
31
+ from torch import Tensor
32
+ from torch.cuda.nvtx import range as nvtx_range
33
+
34
+ from se3_transformer.runtime.utils import degree_to_dim
35
+
36
+
37
@lru_cache(maxsize=None)
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
    """
    Get the (cached) Q^{d_out,d_in}_J matrices from equation (8).
    The float64 Wigner 3j tensor for (J, d_in, d_out) is returned with its three axes in reversed order.
    """
    wigner = o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device)
    return wigner.permute(2, 1, 0)
41
+
42
+
43
@lru_cache(maxsize=None)
def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
    """
    Get the (cached) coefficients for every (d_in, d_out) pair of degrees up to max_degree.
    The outer list is ordered with d_in as the slow index and d_out as the fast index;
    each inner list holds one tensor per J in |d_in - d_out| .. d_in + d_out.
    """
    degrees = range(max_degree + 1)
    return [
        [get_clebsch_gordon(J, d_in, d_out, device) for J in range(abs(d_in - d_out), d_in + d_out + 1)]
        for d_in in degrees
        for d_out in degrees
    ]
53
+
54
+
55
def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
    """
    Compute the spherical harmonics of relative_pos for degrees 0 .. 2 * max_degree.
    The result is split along dim 1 into one chunk per degree.
    """
    degrees = list(range(2 * max_degree + 1))
    with nvtx_range('spherical harmonics'):
        harmonics = o3.spherical_harmonics(degrees, relative_pos, normalize=True)
        split_sizes = [degree_to_dim(degree) for degree in degrees]
        return torch.split(harmonics, split_sizes, dim=1)
60
+
61
+
62
@torch.jit.script
def get_basis_script(max_degree: int,
                     use_pad_trick: bool,
                     spherical_harmonics: List[Tensor],
                     clebsch_gordon: List[List[Tensor]],
                     amp: bool) -> Dict[str, Tensor]:
    """
    Compute pairwise bases matrices for degrees up to max_degree
    :param max_degree:          Maximum input or output degree
    :param use_pad_trick:       Pad some of the odd dimensions for a better use of Tensor Cores
    :param spherical_harmonics: List of computed spherical harmonics, indexed by J
    :param clebsch_gordon:      List of computed CB-coefficients, one inner list per (d_in, d_out) pair
    :param amp:                 When true, return bases in FP16 precision
    :return: Dict mapping 'd_in,d_out' to the basis tensor of that pair of degrees
    """
    basis = {}
    pair_idx = 0  # Position of the current (d_in, d_out) pair in clebsch_gordon
    # Double for loop instead of product() because of JIT script
    for d_in in range(max_degree + 1):
        for d_out in range(max_degree + 1):
            pair_key = f'{d_in},{d_out}'
            per_freq = []
            for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
                harmonics = spherical_harmonics[J].float()
                coeffs = clebsch_gordon[pair_idx][freq_idx].float()
                per_freq.append(torch.einsum('n f, k l f -> n l k', harmonics, coeffs))

            # Frequencies are stacked at dim index 2, so the resulting order is n l f k
            stacked = torch.stack(per_freq, 2)
            if amp:
                stacked = stacked.half()
            if use_pad_trick:
                # Pad the k dimension, that can be sliced later
                stacked = F.pad(stacked, (0, 1))

            basis[pair_key] = stacked
            pair_idx += 1

    return basis
96
+
97
+
98
@torch.jit.script
def update_basis_with_fused(basis: Dict[str, Tensor],
                            max_degree: int,
                            use_pad_trick: bool,
                            fully_fused: bool) -> Dict[str, Tensor]:
    """
    Update the basis dict with partially and optionally fully fused bases.
    Adds 'out{d}_fused' and 'in{d}_fused' entries for every degree d up to max_degree, adds
    'fully_fused' when requested, and removes the '0,0' entry. The dict is modified in place and returned.
    :param basis:         Pairwise bases, keyed by 'd_in,d_out'
    :param max_degree:    Maximum input or output degree
    :param use_pad_trick: When true, the per-output-degree fused bases get one extra element on the last dim
    :param fully_fused:   When true, also build the 'fully_fused' basis
    """
    reference = basis['0,0']
    num_edges = reference.shape[0]
    device = reference.device
    dtype = reference.dtype
    sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])

    # Fused per output degree
    for d_out in range(max_degree + 1):
        dim_out = degree_to_dim(d_out)
        sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
        fused_out = torch.zeros(num_edges, sum_dim, sum_freq, dim_out + int(use_pad_trick),
                                device=device, dtype=dtype)
        row_off, freq_off = 0, 0
        for d_in in range(max_degree + 1):
            dim_in = degree_to_dim(d_in)
            num_freq = degree_to_dim(min(d_out, d_in))
            # Only the first dim_out elements of the last dim are copied (the pad, if any, is left at zero)
            fused_out[:, row_off:row_off + dim_in, freq_off:freq_off + num_freq, :dim_out] \
                = basis[f'{d_in},{d_out}'][:, :, :, :dim_out]

            row_off += dim_in
            freq_off += num_freq

        basis[f'out{d_out}_fused'] = fused_out

    # Fused per input degree
    for d_in in range(max_degree + 1):
        dim_in = degree_to_dim(d_in)
        sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
        fused_in = torch.zeros(num_edges, dim_in, sum_freq, sum_dim,
                               device=device, dtype=dtype)
        col_off, freq_off = 0, 0
        for d_out in range(max_degree + 1):
            dim_out = degree_to_dim(d_out)
            num_freq = degree_to_dim(min(d_out, d_in))
            fused_in[:, :, freq_off:freq_off + num_freq, col_off:col_off + dim_out] \
                = basis[f'{d_in},{d_out}'][:, :, :, :dim_out]

            col_off += dim_out
            freq_off += num_freq

        basis[f'in{d_in}_fused'] = fused_in

    if fully_fused:
        # Fully fused
        # Double sum this way because of JIT script
        sum_freq = sum([
            sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
        ])
        fused_all = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)

        col_off, freq_off = 0, 0
        for d_out in range(max_degree + 1):
            dim_out = degree_to_dim(d_out)
            partial = basis[f'out{d_out}_fused']
            fused_all[:, :, freq_off:freq_off + partial.shape[2], col_off:col_off + dim_out] \
                = partial[:, :, :, :dim_out]
            freq_off += partial.shape[2]
            col_off += dim_out

        basis['fully_fused'] = fused_all

    del basis['0,0']  # We know that the basis for l = k = 0 is filled with a constant
    return basis
159
+
160
+
161
def get_basis(relative_pos: Tensor,
              max_degree: int = 4,
              compute_gradients: bool = False,
              use_pad_trick: bool = False,
              amp: bool = False) -> Dict[str, Tensor]:
    """
    Compute the pairwise bases for the given relative positions.
    :param relative_pos:      Relative positions; also determines the device of the CB coefficients
    :param max_degree:        Maximum input or output degree
    :param compute_gradients: Whether autograd is enabled while the bases are assembled
    :param use_pad_trick:     Forwarded to get_basis_script
    :param amp:               Forwarded to get_basis_script
    """
    with nvtx_range('spherical harmonics'):
        sh = get_spherical_harmonics(relative_pos, max_degree)
    with nvtx_range('CB coefficients'):
        cg = get_all_clebsch_gordon(max_degree, relative_pos.device)

    # Only the assembly of the bases is covered by the gradient switch
    with torch.autograd.set_grad_enabled(compute_gradients), nvtx_range('bases'):
        return get_basis_script(max_degree=max_degree,
                                use_pad_trick=use_pad_trick,
                                spherical_harmonics=sh,
                                clebsch_gordon=cg,
                                amp=amp)
se3_transformer/model/fiber.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+
25
+ from collections import namedtuple
26
+ from itertools import product
27
+ from typing import Dict
28
+
29
+ import torch
30
+ from torch import Tensor
31
+
32
+ from se3_transformer.runtime.utils import degree_to_dim
33
+
34
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])


class Fiber(dict):
    """
    Describes the structure of some set of features.
    Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
    Type-0 features: invariant scalars
    Type-1 features: equivariant 3D vectors
    Type-2 features: equivariant symmetric traceless matrices
    ...

    As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
    The 'multiplicity' or 'number of channels' is the number of features of a given type.
    This class puts together all the degrees and their multiplicities in order to describe
    the inputs, outputs or hidden features of SE3 layers.
    """

    def __init__(self, structure):
        """
        :param structure: Either a dict {degree: channels}, a sequence of (degree, channels) pairs,
                          or a sequence of FiberEl (stored as is). An empty dict or sequence gives an empty fiber.
        """
        if isinstance(structure, dict):
            structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
        elif len(structure) > 0 and not isinstance(structure[0], FiberEl):
            # The emptiness check avoids an IndexError on structure[0] for an empty sequence
            structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
        self.structure = structure
        super().__init__({d: m for d, m in self.structure})

    @property
    def degrees(self):
        """ Degrees of the fiber, in increasing order """
        return sorted([t.degree for t in self.structure])

    @property
    def channels(self):
        """ Multiplicities of the fiber, in the same order as self.degrees """
        return [self[d] for d in self.degrees]

    @property
    def num_features(self):
        """ Size of the resulting tensor if all features were concatenated together """
        return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)

    @staticmethod
    def create(num_degrees: int, num_channels: int):
        """ Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
        return Fiber([(degree, num_channels) for degree in range(num_degrees)])

    @staticmethod
    def from_features(feats: Dict[str, Tensor]):
        """ Infer the Fiber structure from a feature dict """
        structure = {}
        for k, v in feats.items():
            degree = int(k)
            assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
            assert v.shape[-1] == degree_to_dim(degree)
            structure[degree] = v.shape[-2]
        return Fiber(structure)

    def __getitem__(self, degree: int):
        """ fiber[degree] returns the multiplicity for this degree (0 if the degree is absent) """
        # The underlying dict already maps degree -> channels, no need to rebuild it on every lookup
        return super().get(degree, 0)

    def __iter__(self):
        """ Iterate over namedtuples (degree, channels) """
        return iter(self.structure)

    def __mul__(self, other):
        """
        If other is an int, multiplies all the multiplicities by other.
        If other is a fiber, returns the cartesian product.
        """
        if isinstance(other, Fiber):
            return product(self.structure, other.structure)
        elif isinstance(other, int):
            return Fiber({t.degree: t.channels * other for t in self.structure})
        # Unsupported operand: let Python raise a TypeError instead of silently returning None
        return NotImplemented

    def __add__(self, other):
        """
        If other is an int, add other to all the multiplicities.
        If other is a fiber, add the multiplicities of the fibers together.
        Only the degrees of self are kept: degrees present only in other are not added to the result.
        """
        if isinstance(other, Fiber):
            return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
        elif isinstance(other, int):
            return Fiber({t.degree: t.channels + other for t in self.structure})
        # Unsupported operand: let Python raise a TypeError instead of silently returning None
        return NotImplemented

    def __repr__(self):
        return str(self.structure)

    @staticmethod
    def combine_max(f1, f2):
        """ Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
        new_dict = dict(f1.structure)
        for k, m in f2.structure:
            new_dict[k] = max(new_dict.get(k, 0), m)

        return Fiber(list(new_dict.items()))

    @staticmethod
    def combine_selectively(f1, f2):
        """ Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
        # only use orders which occur in fiber f1
        new_dict = dict(f1.structure)
        for k in f1.degrees:
            if k in f2.degrees:
                new_dict[k] += f2[k]
        return Fiber(list(new_dict.items()))

    def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
        """ Split the features of every degree into num_heads heads and concatenate them on the last dim """
        # dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
        fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
                  self.degrees]
        fibers = torch.cat(fibers, -1)
        return fibers
se3_transformer/model/layers/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .linear import LinearSE3
2
+ from .norm import NormSE3
3
+ from .pooling import GPooling
4
+ from .convolution import ConvSE3
5
+ from .attention import AttentionBlockSE3
se3_transformer/model/layers/attention.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import dgl
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ from dgl import DGLGraph
29
+ from dgl.ops import edge_softmax
30
+ from torch import Tensor
31
+ from typing import Dict, Optional, Union
32
+
33
+ from se3_transformer.model.fiber import Fiber
34
+ from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
35
+ from se3_transformer.model.layers.linear import LinearSE3
36
+ from se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
37
+ from torch.cuda.nvtx import range as nvtx_range
38
+
39
+
40
class AttentionSE3(nn.Module):
    """ Multi-headed sparse graph self-attention (SE(3)-equivariant) """

    def __init__(
            self,
            num_heads: int,
            key_fiber: Fiber,
            value_fiber: Fiber
    ):
        """
        :param num_heads:   Number of attention heads
        :param key_fiber:   Fiber for the keys (and also for the queries)
        :param value_fiber: Fiber for the values
        """
        super().__init__()
        self.num_heads = num_heads
        self.key_fiber = key_fiber
        self.value_fiber = value_fiber

    def forward(
            self,
            value: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            key: Union[Tensor, Dict[str, Tensor]],  # edge features (may be fused)
            query: Dict[str, Tensor],  # node features
            graph: DGLGraph
    ):
        """
        Attend over the edges of the graph.
        :return: Dict of per-degree features, keyed by the degree as a string
        """
        heads = self.num_heads
        with nvtx_range('AttentionSE3'):
            with nvtx_range('reshape keys and queries'):
                if not isinstance(key, Tensor):
                    # Per-degree dicts: fuse and split into heads through the key fiber
                    key = self.key_fiber.to_attention_heads(key, heads)
                    query = self.key_fiber.to_attention_heads(query, heads)
                else:
                    # Keys of all degrees are already fused in a single tensor
                    key = key.reshape(key.shape[0], heads, -1)
                    # Queries are concatenated first and then split, to keep the same layout as the keys
                    fused_query = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
                    num_nodes = next(iter(query.values())).shape[0]
                    query = fused_query.reshape(num_nodes, heads, -1)

            with nvtx_range('attention dot product + softmax'):
                # Attention weights: softmax of the scaled inner product between key and query
                edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
                edge_weights /= np.sqrt(self.key_fiber.num_features)
                edge_weights = edge_softmax(graph, edge_weights)[..., None, None]

            with nvtx_range('weighted sum'):
                if not isinstance(value, Tensor):
                    out = {}
                    for degree, channels in self.value_fiber:
                        dim = degree_to_dim(degree)
                        per_head = value[str(degree)].view(-1, heads, channels // heads, dim)
                        summed = dgl.ops.copy_e_sum(graph, edge_weights * per_head)
                        out[str(degree)] = summed.view(-1, channels, dim)  # merge heads
                    return out

                # Values of all degrees are fused in a single tensor
                per_head = value.view(value.shape[0], heads, -1, value.shape[-1])
                summed = dgl.ops.copy_e_sum(graph, edge_weights * per_head)
                merged = summed.view(summed.shape[0], -1, summed.shape[-1])  # merge heads
                return unfuse_features(merged, self.value_fiber.degrees)
104
+
105
+
106
class AttentionBlockSE3(nn.Module):
    """ Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """

    def __init__(
            self,
            fiber_in: Fiber,
            fiber_out: Fiber,
            fiber_edge: Optional[Fiber] = None,
            num_heads: int = 4,
            channels_div: int = 2,
            use_layer_norm: bool = False,
            max_degree: int = 4,
            fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
            **kwargs
    ):
        """
        :param fiber_in:         Fiber describing the input features
        :param fiber_out:        Fiber describing the output features
        :param fiber_edge:       Fiber describing the edge features (node distances excluded)
        :param num_heads:        Number of attention heads
        :param channels_div:     Divide the channels by this integer for computing values
        :param use_layer_norm:   Apply layer normalization between MLP layers
        :param max_degree:       Maximum degree used in the bases computation
        :param fuse_level:       Maximum fuse level to use in TFN convolutions
        :param kwargs:           Extra arguments, accepted and ignored
        """
        super().__init__()
        if fiber_edge is None:
            fiber_edge = Fiber({})
        self.fiber_in = fiber_in
        # value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
        value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
        # key_query_fiber has the same structure as value_fiber, but only degrees which are in fiber_in
        # (queries are merely projected, hence degrees have to match input)
        key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])

        # A single convolution produces keys and values together; they are split apart in forward()
        self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
                                    use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
                                    allow_fused_output=True)
        self.to_query = LinearSE3(fiber_in, key_query_fiber)
        self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
        self.project = LinearSE3(value_fiber + fiber_in, fiber_out)

    def forward(
            self,
            node_features: Dict[str, Tensor],
            edge_features: Dict[str, Tensor],
            graph: DGLGraph,
            basis: Dict[str, Tensor]
    ):
        """ Compute keys, values and queries, attend, concatenate with the input features and project """
        with nvtx_range('AttentionBlockSE3'):
            with nvtx_range('keys / values'):
                fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
                key, value = self._get_key_value_from_fused(fused_key_value)

            with nvtx_range('queries'):
                query = self.to_query(node_features)

            z = self.attention(value, key, query, graph)
            z_concat = aggregate_residual(node_features, z, 'cat')
            return self.project(z_concat)

    def _get_key_value_from_fused(self, fused_key_value):
        """ Split the output of to_key_value into (key, value) """
        # Extract keys and values features from fused features
        if isinstance(fused_key_value, Tensor):
            # Previous layer was a fully fused convolution
            value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
        else:
            key, value = {}, {}
            for degree, feat in fused_key_value.items():
                if int(degree) in self.fiber_in.degrees:
                    # Degrees shared with the input carry both values and keys, in that order
                    value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
                else:
                    # Other degrees only carry values
                    value[degree] = feat

        return key, value
se3_transformer/model/layers/convolution.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ from enum import Enum
25
+ from itertools import product
26
+ from typing import Dict
27
+
28
+ import dgl
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn as nn
32
+ from dgl import DGLGraph
33
+ from torch import Tensor
34
+ from torch.cuda.nvtx import range as nvtx_range
35
+
36
+ from se3_transformer.model.fiber import Fiber
37
+ from se3_transformer.runtime.utils import degree_to_dim, unfuse_features
38
+
39
+
40
class ConvSE3FuseLevel(Enum):
    """
    Upper bound on the fusing optimizations a TFN convolution may apply.

    The chosen level is a maximum: if a layer does not satisfy the requirements of the requested level,
    lower levels are considered instead.
    Higher levels train faster but use more memory. Pick a low value when memory is tight or inputs
    are large, a high value when speed matters. The recommended setting is FULL together with AMP.

    Requirements for a fully fused TFN convolution (FULL):
        - every input degree has the same number of channels
        - every output degree has the same number of channels
        - input degrees are exactly [0, ..., max_degree]
        - output degrees are exactly [0, ..., max_degree]

    Requirements for a partially fused TFN convolution (PARTIAL):
        * Fusing by output degree:
            - every input degree has the same number of channels
            - input degrees are exactly [0, ..., max_degree]
        * Fusing by input degree:
            - every output degree has the same number of channels
            - output degrees are exactly [0, ..., max_degree]

    Original pairwise TFN convolutions (NONE) have no requirements.
    """

    FULL = 2
    PARTIAL = 1
    NONE = 0
69
+
70
+
71
class RadialProfile(nn.Module):
    r"""
    Radial profile function.
    Outputs weights used to weigh basis matrices in order to get convolution kernels.
    In TFN notation: $R^{l,k}$
    In SE(3)-Transformer notation: $\phi^{l,k}$

    Note:
        In the original papers, this function only depends on relative node distances ||x||.
        Here, we allow this function to also take as input additional invariant edge features.
        This does not break equivariance and adds expressive power to the model.

    Diagram:
        invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
    """

    # NOTE: the class docstring is a raw string because it contains LaTeX backslashes (``\phi``);
    # in a regular string literal ``\p`` is an invalid escape sequence and triggers a warning.

    def __init__(
            self,
            num_freq: int,
            channels_in: int,
            channels_out: int,
            edge_dim: int = 1,
            mid_dim: int = 32,
            use_layer_norm: bool = False
    ):
        """
        :param num_freq:         Number of frequencies
        :param channels_in:      Number of input channels
        :param channels_out:     Number of output channels
        :param edge_dim:         Number of invariant edge features (input to the radial function)
        :param mid_dim:          Size of the hidden MLP layers
        :param use_layer_norm:   Apply layer normalization between MLP layers
        """
        super().__init__()
        # Layers are created in network order; None placeholders mark the optional layer norms
        # and are filtered out below.
        modules = [
            nn.Linear(edge_dim, mid_dim),
            nn.LayerNorm(mid_dim) if use_layer_norm else None,
            nn.ReLU(),
            nn.Linear(mid_dim, mid_dim),
            nn.LayerNorm(mid_dim) if use_layer_norm else None,
            nn.ReLU(),
            # One weight per (frequency, input channel, output channel) triple, without bias.
            nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
        ]

        self.net = nn.Sequential(*[m for m in modules if m is not None])

    def forward(self, features: Tensor) -> Tensor:
        """Apply the MLP to the invariant edge features and return the flat radial weights."""
        return self.net(features)
119
+
120
+
121
class VersatileConvSE3(nn.Module):
    """
    Generic TFN convolution kernel application.
    The same module serves fully fused, partially fused and pairwise convolutions; the caller
    decides which by the features and basis it passes in.
    """

    def __init__(self,
                 freq_sum: int,
                 channels_in: int,
                 channels_out: int,
                 edge_dim: int,
                 use_layer_norm: bool,
                 fuse_level: ConvSE3FuseLevel):
        """
        :param freq_sum:        Number of frequencies handled by the radial profile
        :param channels_in:     Number of input channels
        :param channels_out:    Number of output channels
        :param edge_dim:        Number of invariant edge features fed to the radial profile
        :param use_layer_norm:  Apply layer normalization inside the radial profile MLP
        :param fuse_level:      Fuse level this kernel is used with (controls output trimming in forward)
        """
        super().__init__()
        self.freq_sum = freq_sum
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.fuse_level = fuse_level
        self.radial_func = RadialProfile(num_freq=freq_sum,
                                         channels_in=channels_in,
                                         channels_out=channels_out,
                                         edge_dim=edge_dim,
                                         use_layer_norm=use_layer_norm)

    def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
        """
        :param features:             Input features; dim 0 is used as the edge count and dim 2 as the input dim
        :param invariant_edge_feats: Invariant edge features passed to the radial profile
        :param basis:                Basis tensor, or None for the non-fused k = l = 0 case
        :return: Radial weights applied to the (basis-projected) features
        """
        with nvtx_range('VersatileConvSE3'):
            edge_count, input_dim = features.shape[0], features.shape[2]
            with nvtx_range('RadialProfile'):
                flat_weights = self.radial_func(invariant_edge_feats)
                radial_weights = flat_weights.view(-1, self.channels_out, self.channels_in * self.freq_sum)

            if basis is None:
                # k = l = 0 non-fused case: no basis to contract with
                return radial_weights @ features

            # The remainder performs the einsum n i l, n o i f, n l f k -> n o k
            last_basis_dim = basis.shape[-1]
            out_dim = last_basis_dim
            if self.fuse_level != ConvSE3FuseLevel.FULL:
                # Account for padded basis: an even last dimension loses one trailing element
                out_dim += out_dim % 2 - 1
            projected = (features @ basis.view(edge_count, input_dim, -1)).view(edge_count, -1, last_basis_dim)
            return (radial_weights @ projected)[:, :, :out_dim]
164
+
165
+
166
class ConvSE3(nn.Module):
    """
    SE(3)-equivariant graph convolution (Tensor Field Network convolution).
    This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
    Features of different degrees interact together to produce output features.

    Note 1:
        The option is given to not pool the output. This means that the convolution sum over neighbors will not be
        done, and the returned features will be edge features instead of node features.

    Note 2:
        Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
        Input edge features are concatenated with input source node features before the kernel is applied.
    """

    def __init__(
            self,
            fiber_in: Fiber,
            fiber_out: Fiber,
            fiber_edge: Fiber,
            pool: bool = True,
            use_layer_norm: bool = False,
            self_interaction: bool = False,
            max_degree: int = 4,
            fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
            allow_fused_output: bool = False
    ):
        """
        :param fiber_in:           Fiber describing the input features
        :param fiber_out:          Fiber describing the output features
        :param fiber_edge:         Fiber describing the edge features (node distances excluded)
        :param pool:               If True, compute final node features by summing incoming edge features
        :param use_layer_norm:     Apply layer normalization between MLP layers
        :param self_interaction:   Apply self-interaction of nodes
        :param max_degree:         Maximum degree used in the bases computation
        :param fuse_level:         Maximum fuse level to use in TFN convolutions
        :param allow_fused_output: Allow the module to output a fused representation of features
        """
        super().__init__()
        self.pool = pool
        self.fiber_in = fiber_in
        self.fiber_out = fiber_out
        self.self_interaction = self_interaction
        self.max_degree = max_degree
        self.allow_fused_output = allow_fused_output

        # channels_in: account for the concatenation of edge features (only for degrees > 0)
        channels_in_set = {f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in}
        channels_out_set = {f.channels for f in self.fiber_out}
        unique_channels_in = (len(channels_in_set) == 1)
        unique_channels_out = (len(channels_out_set) == 1)
        degrees_up_to_max = list(range(max_degree + 1))
        # The radial profile sees the degree-0 edge features plus one extra invariant feature.
        common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)

        if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
                unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
                unique_channels_out and fiber_out.degrees == degrees_up_to_max:
            # Single fused convolution
            self.used_fuse_level = ConvSE3FuseLevel.FULL

            sum_freq = sum([
                degree_to_dim(min(d_in, d_out))
                for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
            ])

            self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
                                         fuse_level=self.used_fuse_level, **common_args)

        elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
                unique_channels_in and fiber_in.degrees == degrees_up_to_max:
            # Convolutions fused per output degree
            self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
            self.conv_out = nn.ModuleDict()
            for d_out, c_out in fiber_out:
                sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
                self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
                                                             fuse_level=self.used_fuse_level, **common_args)

        elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
                unique_channels_out and fiber_out.degrees == degrees_up_to_max:
            # Convolutions fused per input degree
            self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
            self.conv_in = nn.ModuleDict()
            for d_in, c_in in fiber_in:
                sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
                # NOTE(review): FULL is passed here instead of self.used_fuse_level (PARTIAL), which
                # disables the padded-basis trimming in VersatileConvSE3.forward — confirm this is intended.
                self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
                                                           fuse_level=ConvSE3FuseLevel.FULL, **common_args)
        else:
            # Use pairwise TFN convolutions
            self.used_fuse_level = ConvSE3FuseLevel.NONE
            self.conv = nn.ModuleDict()
            for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
                dict_key = f'{degree_in},{degree_out}'
                channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
                sum_freq = degree_to_dim(min(degree_in, degree_out))
                self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
                                                       fuse_level=self.used_fuse_level, **common_args)

        if self_interaction:
            # One learnable channel-mixing matrix per output degree that also exists in the input fiber
            self.to_kernel_self = nn.ParameterDict()
            for degree_out, channels_out in fiber_out:
                if fiber_in[degree_out]:
                    self.to_kernel_self[str(degree_out)] = nn.Parameter(
                        torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))

    def forward(
            self,
            node_feats: Dict[str, Tensor],
            edge_feats: Dict[str, Tensor],
            graph: DGLGraph,
            basis: Dict[str, Tensor]
    ):
        """
        :param node_feats: Node features keyed by degree (as a string)
        :param edge_feats: Edge features keyed by degree (as a string); key '0' must be present
        :param graph:      Graph providing the edges
        :param basis:      Basis tensors; the keys read depend on the fuse level in use
        :return: Dict of output features keyed by degree, or a single fused Tensor when a fused
                 output is allowed and neither self-interaction nor pooling is enabled
        """
        with nvtx_range('ConvSE3'):
            invariant_edge_feats = edge_feats['0'].squeeze(-1)
            src, dst = graph.edges()
            out = {}
            in_features = []

            # Fetch all input features from edge and node features
            for degree_in in self.fiber_in.degrees:
                src_node_features = node_feats[str(degree_in)][src]
                if degree_in > 0 and str(degree_in) in edge_feats:
                    # Handle edge features of any type by concatenating them to node features
                    src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
                in_features.append(src_node_features)

            if self.used_fuse_level == ConvSE3FuseLevel.FULL:
                in_features_fused = torch.cat(in_features, dim=-1)
                out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])

                # Self-interaction and pooling operate per degree, so they need unfused features
                if not self.allow_fused_output or self.self_interaction or self.pool:
                    out = unfuse_features(out, self.fiber_out.degrees)

            elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
                in_features_fused = torch.cat(in_features, dim=-1)
                for degree_out in self.fiber_out.degrees:
                    out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats,
                                                                          basis[f'out{degree_out}_fused'])

            elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
                out = 0
                for degree_in, feature in zip(self.fiber_in.degrees, in_features):
                    out += self.conv_in[str(degree_in)](feature, invariant_edge_feats,
                                                        basis[f'in{degree_in}_fused'])
                if not self.allow_fused_output or self.self_interaction or self.pool:
                    out = unfuse_features(out, self.fiber_out.degrees)
            else:
                # Fallback to pairwise TFN convolutions
                for degree_out in self.fiber_out.degrees:
                    out_feature = 0
                    for degree_in, feature in zip(self.fiber_in.degrees, in_features):
                        dict_key = f'{degree_in},{degree_out}'
                        out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats,
                                                                        basis.get(dict_key, None))
                    out[str(degree_out)] = out_feature

            for degree_out in self.fiber_out.degrees:
                if self.self_interaction and str(degree_out) in self.to_kernel_self:
                    with nvtx_range('self interaction'):
                        dst_features = node_feats[str(degree_out)][dst]
                        kernel_self = self.to_kernel_self[str(degree_out)]
                        # Out-of-place add: the stored tensor may be a view (presumably produced by
                        # unfuse_features), and autograd can reject in-place updates of such views.
                        out[str(degree_out)] = out[str(degree_out)] + kernel_self @ dst_features

                if self.pool:
                    with nvtx_range('pooling'):
                        if isinstance(out, dict):
                            # Sum the features of incoming edges onto their destination node
                            out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
                        else:
                            # Not expected to be reached: every branch above yields a dict when pooling is on
                            out = dgl.ops.copy_e_sum(graph, out)
            return out
se3_transformer/model/layers/linear.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+
25
+ from typing import Dict
26
+
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+ from torch import Tensor
31
+
32
+ from se3_transformer.model.fiber import Fiber
33
+
34
+
35
class LinearSE3(nn.Module):
    """
    Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
    Every degree of the output fiber gets its own bias-free weight matrix that mixes the channels
    of the same degree in the input; degrees never interact with each other.

    type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
    type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
                                                 :
    type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
    """

    def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
        """
        :param fiber_in:  Fiber describing the input features
        :param fiber_out: Fiber describing the output features
        """
        super().__init__()
        per_degree_weights = {}
        for degree_out, channels_out in fiber_out:
            channels_in = fiber_in[degree_out]
            # Random normal init scaled by 1 / sqrt(number of input channels)
            per_degree_weights[str(degree_out)] = nn.Parameter(
                torch.randn(channels_out, channels_in) / np.sqrt(channels_in))
        self.weights = nn.ParameterDict(per_degree_weights)

    def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        """Left-multiply the features of each degree by that degree's weight matrix."""
        mixed = {}
        for degree, weight in self.weights.items():
            mixed[degree] = weight @ features[degree]
        return mixed
se3_transformer/model/layers/norm.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+
25
+ from typing import Dict
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ from torch import Tensor
30
+ from torch.cuda.nvtx import range as nvtx_range
31
+
32
+ from se3_transformer.model.fiber import Fiber
33
+
34
+
35
class NormSE3(nn.Module):
    """
    Norm-based SE(3)-equivariant nonlinearity.
    Only the norm of each feature is transformed (normalization followed by the nonlinearity);
    the feature is then rescaled from its old norm to the new one.

                 ┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
    feature_in ──┤                                              * ──> feature_out
                 └──> feature_phase ────────────────────────────┘
    """

    NORM_CLAMP = 2 ** -24  # Minimum positive subnormal for FP16

    def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
        """
        :param fiber:        Fiber describing the features to normalize
        :param nonlinearity: Module applied to the normalized norms
        """
        super().__init__()
        self.fiber = fiber
        self.nonlinearity = nonlinearity

        same_channels_everywhere = len(set(fiber.channels)) == 1
        if same_channels_everywhere:
            # One group per degree: a single group normalization replaces the per-degree layer norms
            self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
        else:
            # Channel counts differ between degrees: keep one layer normalization per degree
            self.layer_norms = nn.ModuleDict({
                str(degree): nn.LayerNorm(channels)
                for degree, channels in fiber
            })

    def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
        """
        :param features: Features keyed by degree (as a string)
        :return: Features with the same keys, rescaled to their transformed norms
        """
        with nvtx_range('NormSE3'):
            if not hasattr(self, 'group_norm'):
                rescaled = {}
                for degree, feat in features.items():
                    # Clamp so the division below never sees a zero norm
                    norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                    new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
                    rescaled[degree] = new_norm * feat / norm
                return rescaled

            degrees = self.fiber.degrees
            # Per-degree norms of the features, clamped away from zero
            norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
                     for d in degrees]
            fused_norms = torch.cat(norms, dim=-2)

            # Only the norms go through normalization and the nonlinearity
            transformed = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
            new_norms = torch.chunk(transformed, chunks=len(degrees), dim=-2)

            # Swap each feature's old norm for the new one
            return {
                str(d): features[str(d)] / norm * new_norm
                for norm, new_norm, d in zip(norms, new_norms, degrees)
            }
se3_transformer/model/layers/pooling.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ from typing import Dict, Literal
25
+
26
+ import torch.nn as nn
27
+ from dgl import DGLGraph
28
+ from dgl.nn.pytorch import AvgPooling, MaxPooling
29
+ from torch import Tensor
30
+
31
+
32
class GPooling(nn.Module):
    """
    Graph-level max or average pooling of a single feature type.
    Average pooling keeps equivariance for any feature type.
    Max pooling is only allowed on invariant features (type 0).
    For max-pooling on type > 0 features, look into Vector Neurons.
    """

    def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
        """
        :param feat_type: Feature type to pool
        :param pool:      Type of pooling: max or avg
        """
        super().__init__()
        assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
        assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
        self.feat_type = feat_type
        if pool == 'max':
            self.pool = MaxPooling()
        else:
            self.pool = AvgPooling()

    def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
        """Pool the features of the configured type over the graph and drop the trailing dimension."""
        selected = features[str(self.feat_type)]
        return self.pool(graph, selected).squeeze(dim=-1)
se3_transformer/model/transformer.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import logging
25
+ from typing import Optional, Literal, Dict
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ from dgl import DGLGraph
30
+ from torch import Tensor
31
+
32
+ from se3_transformer.model.basis import get_basis, update_basis_with_fused
33
+ from se3_transformer.model.layers.attention import AttentionBlockSE3
34
+ from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
35
+ from se3_transformer.model.layers.norm import NormSE3
36
+ from se3_transformer.model.layers.pooling import GPooling
37
+ from se3_transformer.runtime.utils import str2bool
38
+ from se3_transformer.model.fiber import Fiber
39
+
40
+
41
class Sequential(nn.Sequential):
    """
    Sequential container whose forward accepts extra positional and keyword arguments.
    The extras are handed unchanged to every module; used to pass graph, basis and edge features.
    """

    def forward(self, input, *args, **kwargs):
        """Feed ``input`` through each module in order, passing the same extra arguments to all."""
        output = input
        for layer in self:
            output = layer(output, *args, **kwargs)
        return output
48
+
49
+
50
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
    """
    Return a copy of the edge features with the norm of the relative positions added under key '0'.

    :param relative_pos:  Relative positions; the norm is taken over the last dimension
    :param edge_features: Optional existing edge features keyed by degree; the dict itself is not modified
    :return: New dict where '0' holds the norms, concatenated (dim 1) after any existing '0' features
    """
    populated = edge_features.copy() if edge_features else {}
    # keepdim plus one extra trailing axis: the norm gains two singleton dimensions
    distances = relative_pos.norm(dim=-1, keepdim=True).unsqueeze(-1)

    if '0' not in populated:
        populated['0'] = distances
        return populated

    populated['0'] = torch.cat([populated['0'], distances], dim=1)
    return populated
60
+
61
+
62
class SE3Transformer(nn.Module):
    """
    Stack of SE(3) attention blocks (each optionally followed by a norm layer), closed by a
    self-interacting convolution, with optional graph pooling of one feature type.
    """

    def __init__(self,
                 num_layers: int,
                 fiber_in: Fiber,
                 fiber_hidden: Fiber,
                 fiber_out: Fiber,
                 num_heads: int,
                 channels_div: int,
                 fiber_edge: Fiber = Fiber({}),
                 return_type: Optional[int] = None,
                 pooling: Optional[Literal['avg', 'max']] = None,
                 norm: bool = True,
                 use_layer_norm: bool = True,
                 tensor_cores: bool = False,
                 low_memory: bool = False,
                 **kwargs):
        """
        :param num_layers:          Number of attention layers
        :param fiber_in:            Input fiber description
        :param fiber_hidden:        Hidden fiber description
        :param fiber_out:           Output fiber description
        :param fiber_edge:          Input edge fiber description
        :param num_heads:           Number of attention heads
        :param channels_div:        Channels division before feeding to attention layer
        :param return_type:         Return only features of this type
        :param pooling:             'avg' or 'max' graph pooling before MLP layers
        :param norm:                Apply a normalization layer after each attention block
        :param use_layer_norm:      Apply layer normalization between MLP layers
        :param tensor_cores:        True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
        :param low_memory:          If True, will use slower ops that use less memory
        """
        super().__init__()
        self.num_layers = num_layers
        self.fiber_edge = fiber_edge
        self.num_heads = num_heads
        self.channels_div = channels_div
        self.return_type = return_type
        self.pooling = pooling
        self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
        self.tensor_cores = tensor_cores
        self.low_memory = low_memory

        if low_memory and not tensor_cores:
            logging.warning('Low memory mode will have no effect with no Tensor Cores')

        # Fully fused convolutions only with Tensor Cores and outside low memory mode
        if tensor_cores and not low_memory:
            fuse_level = ConvSE3FuseLevel.FULL
        else:
            fuse_level = ConvSE3FuseLevel.PARTIAL

        layers = []
        current_fiber = fiber_in
        for _ in range(num_layers):
            attention_block = AttentionBlockSE3(fiber_in=current_fiber,
                                                fiber_out=fiber_hidden,
                                                fiber_edge=fiber_edge,
                                                num_heads=num_heads,
                                                channels_div=channels_div,
                                                use_layer_norm=use_layer_norm,
                                                max_degree=self.max_degree,
                                                fuse_level=fuse_level)
            layers.append(attention_block)
            if norm:
                layers.append(NormSE3(fiber_hidden))
            # Every block after the first consumes hidden features
            current_fiber = fiber_hidden

        # Final convolution maps to the output fiber
        layers.append(ConvSE3(fiber_in=current_fiber,
                              fiber_out=fiber_out,
                              fiber_edge=fiber_edge,
                              self_interaction=True,
                              use_layer_norm=use_layer_norm,
                              max_degree=self.max_degree))
        self.graph_modules = Sequential(*layers)

        if pooling is not None:
            assert return_type is not None, 'return_type must be specified when pooling'
            self.pooling_module = GPooling(pool=pooling, feat_type=return_type)

    def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
                edge_feats: Optional[Dict[str, Tensor]] = None,
                basis: Optional[Dict[str, Tensor]] = None):
        """
        :param graph:      Input graph; ``graph.edata['rel_pos']`` is read
        :param node_feats: Node features keyed by degree (as a string)
        :param edge_feats: Optional edge features keyed by degree (as a string)
        :param basis:      Optional precomputed bases; computed here when missing or empty
        :return: Pooled features when pooling is configured, else the features of ``return_type``
                 when it is set, else the full dict of node features
        """
        # Padded bases and fully fused bases share the same on/off condition
        fused_and_padded = self.tensor_cores and not self.low_memory

        # Compute bases in case they weren't precomputed as part of the data loading
        if not basis:
            basis = get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
                              use_pad_trick=fused_and_padded,
                              amp=torch.is_autocast_enabled())

        # Add fused bases (per output degree, per input degree, and fully fused) to the dict
        basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=fused_and_padded,
                                        fully_fused=fused_and_padded)

        edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)

        node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)

        if self.pooling is not None:
            return self.pooling_module(node_feats, graph=graph)

        if self.return_type is None:
            return node_feats
        return node_feats[str(self.return_type)]

    @staticmethod
    def add_argparse_args(parser):
        """Register the model hyper-parameter options on ``parser`` and return it."""
        add = parser.add_argument
        add('--num_layers', type=int, default=7,
            help='Number of stacked Transformer layers')
        add('--num_heads', type=int, default=8,
            help='Number of heads in self-attention')
        add('--channels_div', type=int, default=2,
            help='Channels division before feeding to attention layer')
        add('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
            help='Type of graph pooling')
        add('--norm', type=str2bool, nargs='?', const=True, default=False,
            help='Apply a normalization layer after each attention block')
        add('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
            help='Apply layer normalization between MLP layers')
        add('--low_memory', type=str2bool, nargs='?', const=True, default=False,
            help='If true, will use fused ops that are slower but that use less memory '
                 '(expect 25 percent less memory). '
                 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')

        return parser
180
+
181
+
182
class SE3TransformerPooled(nn.Module):
    """
    SE3Transformer returning pooled type-0 features, followed by a two-layer MLP head.
    """

    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Fiber,
                 num_degrees: int,
                 num_channels: int,
                 output_dim: int,
                 **kwargs):
        """
        :param fiber_in:     Input fiber description
        :param fiber_out:    Output fiber description of the transformer
        :param fiber_edge:   Input edge fiber description
        :param num_degrees:  Number of degrees of the hidden fiber
        :param num_channels: Number of channels of the hidden fiber
        :param output_dim:   Size of the MLP output
        :param kwargs:       Remaining SE3Transformer arguments; 'pooling' defaults to 'max'
        """
        super().__init__()
        # Default to max pooling when 'pooling' is missing or None (a plain lookup raised KeyError
        # when the caller did not pass the key at all).
        kwargs['pooling'] = kwargs.get('pooling') or 'max'
        self.transformer = SE3Transformer(
            fiber_in=fiber_in,
            fiber_hidden=Fiber.create(num_degrees, num_channels),
            fiber_out=fiber_out,
            fiber_edge=fiber_edge,
            return_type=0,
            **kwargs
        )

        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(n_out_features, n_out_features),
            nn.ReLU(),
            nn.Linear(n_out_features, output_dim)
        )

    def forward(self, graph, node_feats, edge_feats, basis=None):
        """
        Run the transformer, then the MLP head; a trailing singleton dimension is squeezed
        after each of the two stages.
        """
        feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
        y = self.mlp(feats).squeeze(-1)
        return y

    @staticmethod
    def add_argparse_args(parent_parser):
        """Add a "Model architecture" argument group to ``parent_parser`` and return the parent parser."""
        parser = parent_parser.add_argument_group("Model architecture")
        SE3Transformer.add_argparse_args(parser)
        parser.add_argument('--num_degrees',
                            help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
                            type=int, default=4)
        parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
        return parent_parser
se3_transformer/runtime/__init__.py ADDED
File without changes
se3_transformer/runtime/arguments.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import argparse
25
+ import pathlib
26
+
27
+ from se3_transformer.data_loading import QM9DataModule
28
+ from se3_transformer.model import SE3TransformerPooled
29
+ from se3_transformer.runtime.utils import str2bool
30
+
31
# Command-line interface of the SE(3)-Transformer runtime.
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')

# ---- Filesystem locations -------------------------------------------------
paths = PARSER.add_argument_group('Paths')
paths.add_argument(
    '--data_dir',
    type=pathlib.Path,
    default=pathlib.Path('./data'),
    help='Directory where the data is located or should be downloaded',
)
paths.add_argument(
    '--log_dir',
    type=pathlib.Path,
    default=pathlib.Path('/results'),
    help='Directory where the results logs should be saved',
)
paths.add_argument(
    '--dllogger_name',
    type=str,
    default='dllogger_results.json',
    help='Name for the resulting DLLogger JSON file',
)
paths.add_argument(
    '--save_ckpt_path',
    type=pathlib.Path,
    default=None,
    help='File where the checkpoint should be saved',
)
paths.add_argument(
    '--load_ckpt_path',
    type=pathlib.Path,
    default=None,
    help='File of the checkpoint to be loaded',
)

# ---- Optimizer settings ---------------------------------------------------
optimizer = PARSER.add_argument_group('Optimizer')
optimizer.add_argument('--optimizer', default='adam', choices=['adam', 'sgd', 'lamb'])
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', default=0.002, type=float)
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', default=None, type=float)
optimizer.add_argument('--momentum', default=0.9, type=float)
optimizer.add_argument('--weight_decay', default=0.1, type=float)

# ---- General training / evaluation options --------------------------------
PARSER.add_argument('--epochs', default=100, type=int, help='Number of training epochs')
PARSER.add_argument('--batch_size', default=240, type=int, help='Batch size')
PARSER.add_argument('--seed', default=None, type=int, help='Set a seed globally')
PARSER.add_argument('--num_workers', default=8, type=int, help='Number of dataloading workers')

# Boolean switches accept an optional value: a bare flag means True (const=True).
PARSER.add_argument(
    '--amp', type=str2bool, nargs='?', const=True, default=False,
    help='Use Automatic Mixed Precision',
)
PARSER.add_argument('--gradient_clip', default=None, type=float, help='Clipping of the gradient norms')
PARSER.add_argument('--accumulate_grad_batches', default=1, type=int, help='Gradient accumulation')
PARSER.add_argument('--ckpt_interval', default=-1, type=int, help='Save a checkpoint every N epochs')
PARSER.add_argument(
    '--eval_interval', dest='eval_interval', default=1, type=int,
    help='Do an evaluation round every N epochs',
)
PARSER.add_argument(
    '--silent', type=str2bool, nargs='?', const=True, default=False,
    help='Minimize stdout output',
)

PARSER.add_argument(
    '--benchmark', type=str2bool, nargs='?', const=True, default=False,
    help='Benchmark mode',
)

# Let the data module and the model register their own arguments.
QM9DataModule.add_argparse_args(PARSER)
SE3TransformerPooled.add_argparse_args(PARSER)
se3_transformer/runtime/callbacks.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import logging
25
+ import time
26
+ from abc import ABC, abstractmethod
27
+ from typing import Optional
28
+
29
+ import numpy as np
30
+ import torch
31
+
32
+ from se3_transformer.runtime.loggers import Logger
33
+ from se3_transformer.runtime.metrics import MeanAbsoluteError
34
+
35
+
36
class BaseCallback(ABC):
    """ Base class of all callbacks: every hook is a no-op that subclasses may override """

    def on_fit_start(self, optimizer, args):
        """ Hook taking the optimizer and the parsed arguments; does nothing by default """

    def on_fit_end(self):
        """ Hook for the end of fitting; does nothing by default """

    def on_epoch_end(self):
        """ Hook for the end of an epoch; does nothing by default """

    def on_batch_start(self):
        """ Hook for the start of a batch; does nothing by default """

    def on_validation_step(self, input, target, pred):
        """ Hook taking a validation batch's input, target and prediction; does nothing by default """

    def on_validation_end(self, epoch=None):
        """ Hook for the end of validation; does nothing by default """

    def on_checkpoint_load(self, checkpoint):
        """ Hook taking a loaded checkpoint; does nothing by default """

    def on_checkpoint_save(self, checkpoint):
        """ Hook taking a checkpoint about to be saved; does nothing by default """
60
+
61
+
62
class LRSchedulerCallback(BaseCallback):
    """ Callback owning a learning-rate scheduler built by the subclass via ``get_scheduler`` """

    def __init__(self, logger: Optional[Logger] = None):
        # The scheduler is only created in on_fit_start, once the optimizer is known.
        self.scheduler = None
        self.logger = logger

    @abstractmethod
    def get_scheduler(self, optimizer, args):
        """ Build and return the scheduler for ``optimizer`` (implemented by subclasses) """

    def on_fit_start(self, optimizer, args):
        """ Create the scheduler from the optimizer and the parsed arguments """
        self.scheduler = self.get_scheduler(optimizer, args)

    def on_checkpoint_load(self, checkpoint):
        """ Restore the scheduler state stored under ``'scheduler_state_dict'`` """
        state = checkpoint['scheduler_state_dict']
        self.scheduler.load_state_dict(state)

    def on_checkpoint_save(self, checkpoint):
        """ Store the scheduler state under ``'scheduler_state_dict'`` """
        checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

    def on_epoch_end(self):
        """ Log the current learning rate (when a logger is set), then step the scheduler """
        scheduler = self.scheduler
        if self.logger is not None:
            current_lr = scheduler.get_last_lr()[0]
            self.logger.log_metrics({'learning rate': current_lr}, step=scheduler.last_epoch)
        scheduler.step()
84
+
85
+
86
class QM9MetricCallback(BaseCallback):
    """ Tracks and logs the mean absolute error, rescaled by ``targets_std``, for QM9 regression """

    def __init__(self, logger, targets_std, prefix=''):
        self.logger = logger
        self.prefix = prefix
        self.targets_std = targets_std
        self.mae = MeanAbsoluteError()
        # Stays at +inf until the first validation round has completed.
        self.best_mae = float('inf')

    def on_validation_step(self, input, target, pred):
        """ Feed one batch of detached predictions and targets to the metric """
        self.mae(pred.detach(), target.detach())

    def on_validation_end(self, epoch=None):
        """ Compute the rescaled MAE, log it, and update the best value seen so far """
        mae = self.mae.compute() * self.targets_std
        logging.info(f'{self.prefix} MAE: {mae}')
        self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
        if mae < self.best_mae:
            self.best_mae = mae

    def on_fit_end(self):
        """ Log the best MAE, unless no validation round ever produced one """
        if self.best_mae == float('inf'):
            return
        self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})
108
+
109
+
110
class QM9LRSchedulerCallback(LRSchedulerCallback):
    """ LR scheduler callback using cosine annealing with warm restarts """

    def __init__(self, logger, epochs):
        """
        Args:
            logger: logger forwarded to ``LRSchedulerCallback``
            epochs: value passed as the second positional argument of
                ``CosineAnnealingWarmRestarts``
        """
        super().__init__(logger)
        self.epochs = epochs

    def get_scheduler(self, optimizer, args):
        """
        Build the scheduler for ``optimizer``.

        The minimum learning rate is ``args.min_learning_rate`` when it was
        given, otherwise one tenth of ``args.learning_rate``.
        """
        # Compare against None instead of relying on truthiness: an explicit
        # minimum learning rate of 0.0 is a valid request and must not be
        # silently replaced by learning_rate / 10.
        if args.min_learning_rate is not None:
            min_lr = args.min_learning_rate
        else:
            min_lr = args.learning_rate / 10.0
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)
118
+
119
+
120
class PerformanceCallback(BaseCallback):
    """ Records batch start times after the warmup epochs and logs throughput/latency statistics """

    def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
        """
        Args:
            logger: logger receiving the statistics through ``log_metrics``
            batch_size: number of samples per batch, used for the throughput
            warmup_epochs: number of initial epochs whose batches are not timed
            mode: label inserted into the names of the logged statistics
        """
        self.batch_size = batch_size
        self.warmup_epochs = warmup_epochs
        self.epoch = 0
        self.timestamps = []
        self.mode = mode
        self.logger = logger

    def on_batch_start(self):
        """ Record the current time in milliseconds once the warmup epochs are over """
        if self.epoch >= self.warmup_epochs:
            self.timestamps.append(time.time() * 1000.0)

    def _log_perf(self):
        """ Compute the statistics and send them to ``logging`` and to the logger """
        stats = self.process_performance_stats()
        if not stats:
            # Fewer than two timed batches: there is no interval to report.
            logging.warning('performance: not enough timed batches to compute statistics')
            return

        for k, v in stats.items():
            logging.info(f'performance {k}: {v}')

        self.logger.log_metrics(stats)

    def on_epoch_end(self):
        """ Count one finished epoch """
        self.epoch += 1

    def on_fit_end(self):
        """ Log the statistics and reset the timestamps if at least one epoch was timed """
        if self.epoch > self.warmup_epochs:
            self._log_perf()
            self.timestamps = []

    def process_performance_stats(self):
        """
        Turn the recorded timestamps into a dict of statistics.

        Timestamps are in milliseconds, so latencies and the total time are in
        milliseconds and the throughput is in samples per millisecond.

        Returns:
            A dict with throughput, mean latency, total time and the 90th/95th/99th
            latency percentiles, or an empty dict when fewer than two timestamps
            were recorded (no interval between batches can be measured).
        """
        timestamps = np.asarray(self.timestamps)
        # Guard the degenerate case: with 0 timestamps timestamps[-1] raises
        # IndexError, and with 1 timestamp every statistic would be NaN.
        if timestamps.size < 2:
            return {}

        deltas = np.diff(timestamps)
        throughput = (self.batch_size / deltas).mean()
        stats = {
            f"throughput_{self.mode}": throughput,
            f"latency_{self.mode}_mean": deltas.mean(),
            f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
        }
        for level in [90, 95, 99]:
            stats[f"latency_{self.mode}_{level}"] = np.percentile(deltas, level)

        return stats
se3_transformer/runtime/gpu_affinity.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import collections
25
+ import itertools
26
+ import math
27
+ import os
28
+ import pathlib
29
+ import re
30
+
31
+ import pynvml
32
+
33
+
34
class Device:
    """ Wrapper around the NVML handle of the GPU with a given index """

    # assumes nvml returns list of 64 bit ints
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to a single 64-bit element instead of failing with TypeError.
    _nvml_affinity_elements = math.ceil((os.cpu_count() or 1) / 64)

    def __init__(self, device_idx):
        super().__init__()
        self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)

    def get_name(self):
        """ Return the device name reported by NVML """
        return pynvml.nvmlDeviceGetName(self.handle)

    def get_uuid(self):
        """ Return the device UUID reported by NVML """
        return pynvml.nvmlDeviceGetUUID(self.handle)

    @staticmethod
    def _decode_affinity_mask(words):
        """
        Decode a CPU affinity bitmask into a list of core indices.

        Args:
            words: iterable of non-negative integers, each treated as a 64-bit
                word; bit ``b`` of word ``w`` stands for core ``64 * w + b``

        Returns:
            Ascending list of the core indices whose bit is set.
        """
        cores = []
        for word_idx, word in enumerate(words):
            for bit in range(64):
                if (word >> bit) & 1:
                    cores.append(64 * word_idx + bit)
        return cores

    def get_cpu_affinity(self):
        """ Return the ascending list of CPU core indices in the NVML affinity mask of this device """
        mask = pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements)
        return self._decode_affinity_mask(mask)
59
+
60
+
61
def get_thread_siblings_list():
    """
    Collect pairs of hyperthreading sibling cores from sysfs.

    Every ``thread_siblings_list`` file is read and the first two numbers
    separated by a single non-digit character are taken as a pair; files
    without such a pair are skipped.

    Returns a list of 2-element integer tuples.
    """
    sysfs_glob = "/sys/devices/system/cpu/cpu*/topology/thread_siblings_list"
    pair_regex = re.compile(r"(\d+)\D(\d+)")
    # Split the absolute pattern into the root ("/") and a relative glob.
    root, relative_glob = sysfs_glob[0], sysfs_glob[1:]

    sibling_pairs = []
    for sibling_file in pathlib.Path(root).glob(relative_glob):
        with open(sibling_file) as handle:
            text = handle.read().strip()
        matches = pair_regex.findall(text)
        if not matches:
            continue
        first, second = matches[0]
        sibling_pairs.append((int(first), int(second)))
    return sibling_pairs
77
+
78
+
79
def check_socket_affinities(socket_affinities):
    """
    Verify that any two core collections are either identical or disjoint.

    Raises:
        RuntimeError: if two collections overlap only partially.
    """
    for first, second in itertools.product(socket_affinities, socket_affinities):
        first_cores = set(first)
        second_cores = set(second)
        if first_cores == second_cores:
            continue
        if first_cores.isdisjoint(second_cores):
            continue
        raise RuntimeError(f"Sets of cores should be either identical or disjoint, " f"but got {first} and {second}.")
84
+
85
+
86
def get_socket_affinities(nproc_per_node, exclude_unavailable_cores=True):
    """
    Return one list of CPU cores per GPU index in ``range(nproc_per_node)``.

    Args:
        nproc_per_node: number of GPU indices to query
        exclude_unavailable_cores: keep only the cores the current process is
            allowed to run on

    The result is validated with ``check_socket_affinities`` before it is returned.
    """
    socket_affinities = []
    for device_idx in range(nproc_per_node):
        socket_affinities.append(Device(device_idx).get_cpu_affinity())

    if exclude_unavailable_cores:
        usable = os.sched_getaffinity(0)
        socket_affinities = [list(set(cores) & usable) for cores in socket_affinities]

    check_socket_affinities(socket_affinities)
    return socket_affinities
97
+
98
+
99
def set_socket_affinity(gpu_id):
    """
    Pin the current process to every logical CPU core reported in the
    affinity mask of the GPU with the given id.

    Args:
        gpu_id: index of a GPU
    """
    socket_cores = Device(gpu_id).get_cpu_affinity()
    os.sched_setaffinity(0, socket_cores)
110
+
111
+
112
def set_single_affinity(gpu_id):
    """
    The process is assigned with the first available logical CPU core from the
    list of all CPU cores from the CPU socket connected to the GPU with a given
    id.

    Args:
        gpu_id: index of a GPU
    """
    dev = Device(gpu_id)
    affinity = dev.get_cpu_affinity()

    # exclude unavailable cores
    available_cores = os.sched_getaffinity(0)
    # Sort the intersection: set iteration order is not guaranteed, and the
    # contract is to pick the *first* (lowest-numbered) available core.
    affinity = sorted(set(affinity) & available_cores)
    os.sched_setaffinity(0, affinity[:1])
128
+
129
+
130
def set_single_unique_affinity(gpu_id, nproc_per_node):
    """
    The process is assigned with a single unique available physical CPU core
    from the list of all CPU cores from the CPU socket connected to the GPU with
    a given id.

    Args:
        gpu_id: index of a GPU
        nproc_per_node: number of processes (GPUs) per node

    Raises:
        RuntimeError: if no unassigned physical core is left for ``gpu_id``.
    """
    socket_affinities = get_socket_affinities(nproc_per_node)

    siblings_list = get_thread_siblings_list()
    # Second element of every sibling pair; computed once instead of per socket.
    sibling_cores = set(dict(siblings_list).values())

    # Map each device index to its core. Keying by device index (rather than
    # appending to a positional list) keeps later GPUs aligned when an earlier
    # GPU finds no free core.
    core_per_device = {}
    assigned = set()

    for device_id, socket_affinity in enumerate(socket_affinities):
        # remove siblings; sorted so that the choice is deterministic
        for core in sorted(set(socket_affinity) - sibling_cores):
            if core not in assigned:
                core_per_device[device_id] = core
                assigned.add(core)
                break

    if gpu_id not in core_per_device:
        raise RuntimeError(f"No unique physical CPU core is available for GPU {gpu_id}.")
    os.sched_setaffinity(0, [core_per_device[gpu_id]])
158
+
159
+
160
def set_socket_unique_affinity(gpu_id, nproc_per_node, mode, balanced=True):
    """
    The process is assigned with an unique subset of available physical CPU
    cores from the CPU socket connected to a GPU with a given id.
    Assignment automatically includes hyperthreading siblings (if siblings are
    available).

    Args:
        gpu_id: index of a GPU
        nproc_per_node: total number of processes per node
        mode: core indexing pattern, either "interleaved" or "continuous"
        balanced: assign an equal number of physical cores to each process

    Raises:
        RuntimeError: if ``mode`` is unknown (checked when the group containing
            ``gpu_id`` is processed).
    """
    socket_affinities = get_socket_affinities(nproc_per_node)

    siblings_list = get_thread_siblings_list()
    siblings_dict = dict(siblings_list)
    # Computed once instead of once per socket.
    sibling_cores = set(siblings_dict.values())

    # remove hyperthreading siblings; sort the remaining cores so that
    # (a) GPUs sharing a socket always produce the same grouping key and
    # (b) slices taken below really follow the core numbering, instead of
    #     depending on set iteration order
    socket_affinities = [sorted(set(socket_affinity) - sibling_cores) for socket_affinity in socket_affinities]

    socket_affinities_to_device_ids = collections.defaultdict(list)

    for idx, socket_affinity in enumerate(socket_affinities):
        socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)

    # compute minimal number of physical cores per GPU across all GPUs and
    # sockets, code assigns this number of cores per GPU if balanced == True
    min_physical_cores_per_gpu = min(
        len(cores) // len(gpus) for cores, gpus in socket_affinities_to_device_ids.items()
    )

    for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
        devices_per_group = len(device_ids)
        if balanced:
            cores_per_device = min_physical_cores_per_gpu
            socket_affinity = socket_affinity[: devices_per_group * min_physical_cores_per_gpu]
        else:
            cores_per_device = len(socket_affinity) // devices_per_group

        for group_id, device_id in enumerate(device_ids):
            if device_id == gpu_id:

                # In theory there should be no difference in performance between
                # 'interleaved' and 'continuous' pattern on Intel-based DGX-1,
                # but 'continuous' should be better for DGX A100 because on AMD
                # Rome 4 consecutive cores are sharing L3 cache.
                # TODO: code doesn't attempt to automatically detect layout of
                # L3 cache, also external environment may already exclude some
                # cores, this code makes no attempt to detect it and to align
                # mapping to multiples of 4.

                if mode == "interleaved":
                    affinity = list(socket_affinity[group_id::devices_per_group])
                elif mode == "continuous":
                    affinity = list(socket_affinity[group_id * cores_per_device: (group_id + 1) * cores_per_device])
                else:
                    raise RuntimeError("Unknown set_socket_unique_affinity mode")

                # unconditionally reintroduce hyperthreading siblings, this step
                # may result in a different numbers of logical cores assigned to
                # each GPU even if balanced == True (if hyperthreading siblings
                # aren't available for a subset of cores due to some external
                # constraints, siblings are re-added unconditionally, in the
                # worst case unavailable logical core will be ignored by
                # os.sched_setaffinity().
                affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
                os.sched_setaffinity(0, affinity)
229
+
230
+
231
def set_affinity(gpu_id, nproc_per_node, mode="socket_unique_continuous", balanced=True):
    """
    Assign the current process a CPU affinity matching the hardware
    architecture of the platform. This usually improves and stabilizes the
    performance of deep learning training workloads.

    The function assumes a multi-process, single-device setup (several
    training processes, each running on one GPU), which is typical for
    multi-GPU training with `torch.nn.parallel.DistributedDataParallel`.

    Available affinity modes:
    * 'socket' - all available logical CPU cores from the CPU socket connected
    to the GPU with a given id.
    * 'single' - the first available logical CPU core from the list of all CPU
    cores from the CPU socket connected to the GPU with a given id (multiple
    GPUs could be assigned with the same CPU core).
    * 'single_unique' - a single unique available physical CPU core from the
    list of all CPU cores from the CPU socket connected to the GPU with a
    given id.
    * 'socket_unique_interleaved' - an unique subset of available physical CPU
    cores from the CPU socket connected to a GPU with a given id,
    hyperthreading siblings are included automatically, cores are assigned
    with interleaved indexing pattern
    * 'socket_unique_continuous' - (the default) an unique subset of available
    physical CPU cores from the CPU socket connected to a GPU with a given id,
    hyperthreading siblings are included automatically, cores are assigned
    with continuous indexing pattern

    'socket_unique_continuous' is the recommended mode for deep learning
    training workloads on NVIDIA DGX machines.

    Args:
        gpu_id: integer index of a GPU
        nproc_per_node: number of processes per node
        mode: affinity mode
        balanced: assign an equal number of physical cores to each process,
            affects only 'socket_unique_interleaved' and
            'socket_unique_continuous' affinity modes

    Returns a set of logical CPU cores on which the process is eligible to run.

    Raises:
        RuntimeError: if ``mode`` is not one of the modes listed above.

    Example:

    import argparse
    import os

    import gpu_affinity
    import torch


    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--local_rank',
            type=int,
            default=os.getenv('LOCAL_RANK', 0),
        )
        args = parser.parse_args()

        nproc_per_node = torch.cuda.device_count()

        affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node)
        print(f'{args.local_rank}: core affinity: {affinity}')


    if __name__ == "__main__":
        main()

    Launch the example with:
    python -m torch.distributed.launch --nproc_per_node <#GPUs> example.py


    WARNING: On DGX A100 only a half of CPU cores have direct access to GPUs.
    This function restricts execution only to the CPU cores directly connected
    to GPUs, so on DGX A100 it will limit the code to half of CPU cores and half
    of CPU memory bandwidth (which may be fine for many DL models).
    """
    # NVML must be initialized before any Device handle is created.
    pynvml.nvmlInit()

    if mode == "socket":
        set_socket_affinity(gpu_id)
    elif mode == "single":
        set_single_affinity(gpu_id)
    elif mode == "single_unique":
        set_single_unique_affinity(gpu_id, nproc_per_node)
    elif mode == "socket_unique_interleaved":
        set_socket_unique_affinity(gpu_id, nproc_per_node, "interleaved", balanced)
    elif mode == "socket_unique_continuous":
        set_socket_unique_affinity(gpu_id, nproc_per_node, "continuous", balanced)
    else:
        raise RuntimeError("Unknown affinity mode")

    # Report the affinity that is actually in effect after the call above.
    return os.sched_getaffinity(0)
se3_transformer/runtime/inference.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ from typing import List
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ from torch.nn.parallel import DistributedDataParallel
29
+ from torch.utils.data import DataLoader
30
+ from tqdm import tqdm
31
+
32
+ from se3_transformer.runtime import gpu_affinity
33
+ from se3_transformer.runtime.arguments import PARSER
34
+ from se3_transformer.runtime.callbacks import BaseCallback
35
+ from se3_transformer.runtime.loggers import DLLogger
36
+ from se3_transformer.runtime.utils import to_cuda, get_local_rank
37
+
38
+
39
@torch.inference_mode()
def evaluate(model: nn.Module,
             dataloader: DataLoader,
             callbacks: List[BaseCallback],
             args):
    """
    Run ``model`` in eval mode over every batch of ``dataloader``.

    For each batch the callbacks receive ``on_batch_start()`` before the
    forward pass and ``on_validation_step(input, target, pred)`` after it.
    The last element of a batch is the target, everything before it is passed
    to the model. The progress bar is disabled when ``args.silent`` is set or
    on ranks other than local rank 0; autocast follows ``args.amp``.
    """
    model.eval()

    hide_progress = (args.silent or get_local_rank() != 0)
    progress = tqdm(dataloader, total=len(dataloader), unit='batch', desc='Evaluation',
                    leave=False, disable=hide_progress)
    for batch in progress:
        *inputs, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*inputs)

        for callback in callbacks:
            callback.on_validation_step(inputs, target, pred)
57
+
58
+
59
if __name__ == '__main__':
    from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
    from se3_transformer.runtime.utils import init_distributed, seed_everything
    from se3_transformer.model import SE3TransformerPooled, Fiber
    from se3_transformer.data_loading import QM9DataModule
    import torch.distributed as dist
    import logging
    import sys

    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    # Only local rank 0 logs at INFO level, and only when not silenced.
    if local_rank != 0 or args.silent:
        log_level = logging.CRITICAL
    else:
        log_level = logging.INFO
    logging.getLogger().setLevel(log_level)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('|  Inference on the test set  |')
    logging.info('===============================')

    # A checkpoint is mandatory unless we only benchmark.
    if args.load_ckpt_path is None and not args.benchmark:
        logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
        sys.exit(1)

    if args.benchmark:
        logging.info('Running benchmark mode with one warmup pass')

    if args.seed is not None:
        seed_everything(args.seed)

    major_cc, _ = torch.cuda.get_device_capability()
    # use Tensor Cores more effectively
    use_tensor_cores = (args.amp and major_cc >= 7) or major_cc >= 8

    logger = DLLogger(args.log_dir, filename=args.dllogger_name)
    datamodule = QM9DataModule(**vars(args))
    model = SE3TransformerPooled(
        fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
        fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=use_tensor_cores,
        **vars(args)
    )
    callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='test')]

    model.to(device=torch.cuda.current_device())
    if args.load_ckpt_path is not None:
        # Remap tensors saved from cuda:0 onto this process' device.
        checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
        model.load_state_dict(checkpoint['state_dict'])

    if is_distributed:
        gpu_affinity.set_affinity(local_rank, torch.cuda.device_count())
        model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    # Benchmark mode iterates over the training dataloader instead of the test one.
    if args.benchmark:
        test_dataloader = datamodule.train_dataloader()
    else:
        test_dataloader = datamodule.test_dataloader()
    evaluate(model, test_dataloader, callbacks, args)

    for callback in callbacks:
        callback.on_validation_end()

    if args.benchmark:
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        perf_callback = PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')
        callbacks = [perf_callback]
        for _ in range(6):
            evaluate(model, test_dataloader, callbacks, args)
            perf_callback.on_epoch_end()

        perf_callback.on_fit_end()
se3_transformer/runtime/loggers.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import pathlib
25
+ from abc import ABC, abstractmethod
26
+ from enum import Enum
27
+ from typing import Dict, Any, Callable, Optional
28
+
29
+ import dllogger
30
+ import torch.distributed as dist
31
+ import wandb
32
+ from dllogger import Verbosity
33
+
34
+ from se3_transformer.runtime.utils import rank_zero_only
35
+
36
+
37
class Logger(ABC):
    """ Abstract logger interface, plus a helper that makes hyperparameters loggable """

    @rank_zero_only
    @abstractmethod
    def log_hyperparams(self, params):
        """ Log a mapping of hyperparameters (implemented by subclasses) """

    @rank_zero_only
    @abstractmethod
    def log_metrics(self, metrics, step=None):
        """ Log a mapping of metrics, optionally for a given step (implemented by subclasses) """

    @staticmethod
    def _sanitize_params(params):
        """
        Return a copy of ``params`` with the values converted as follows:
        callables are called and replaced by their result (or by their name if
        the result is itself callable or the call fails), paths and enum
        members are converted with ``str``, anything else is kept as is.
        """

        def _sanitize(val):
            if not isinstance(val, Callable):
                if isinstance(val, (pathlib.Path, Enum)):
                    return str(val)
                return val

            try:
                outcome = val()
                # A callable returning a callable is represented by its own name.
                return val.__name__ if isinstance(outcome, Callable) else outcome
            except Exception:
                return getattr(val, "__name__", None)

        sanitized = {}
        for key, val in params.items():
            sanitized[key] = _sanitize(val)
        return sanitized
64
+
65
+
66
class LoggerCollection(Logger):
    """ Logger that forwards every call to each logger of a collection """

    def __init__(self, loggers):
        super().__init__()
        self.loggers = loggers

    def __getitem__(self, index):
        """ Index (or slice) into a list copy of the wrapped loggers """
        return list(self.loggers)[index]

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """ Forward the metrics and the step to every wrapped logger """
        for child in self.loggers:
            child.log_metrics(metrics, step)

    @rank_zero_only
    def log_hyperparams(self, params):
        """ Forward the hyperparameters to every wrapped logger """
        for child in self.loggers:
            child.log_hyperparams(params)
83
+
84
+
85
class DLLogger(Logger):
    """ Logger writing to a DLLogger JSON stream file """

    def __init__(self, save_dir: pathlib.Path, filename: str):
        super().__init__()
        # Only the first rank (or a non-distributed run) creates the output file.
        if dist.is_initialized() and dist.get_rank() != 0:
            return
        save_dir.mkdir(parents=True, exist_ok=True)
        json_backend = dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))
        dllogger.init(backends=[json_backend])

    @rank_zero_only
    def log_hyperparams(self, params):
        """ Log the sanitized hyperparameters under the "PARAMETER" step """
        dllogger.log(step="PARAMETER", data=self._sanitize_params(params))

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """ Log the metrics; an empty tuple is used as the step when none is given """
        effective_step = tuple() if step is None else step
        dllogger.log(step=effective_step, data=metrics)
104
+
105
+
106
class WandbLogger(Logger):
    """ Logger sending hyperparameters and metrics to a Weights & Biases run """

    def __init__(
            self,
            name: str,
            save_dir: pathlib.Path,
            id: Optional[str] = None,
            project: Optional[str] = None
    ):
        super().__init__()
        # Only the first rank (or a non-distributed run) opens a run;
        # on other ranks `self.experiment` is never set.
        if dist.is_initialized() and dist.get_rank() != 0:
            return
        save_dir.mkdir(parents=True, exist_ok=True)
        self.experiment = wandb.init(
            name=name,
            project=project,
            id=id,
            dir=str(save_dir),
            resume='allow',
            anonymous='must',
        )

    @rank_zero_only
    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        """ Update the run config with the sanitized hyperparameters """
        sanitized = self._sanitize_params(params)
        self.experiment.config.update(sanitized, allow_val_change=True)

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """ Log the metrics, adding an ``'epoch'`` entry when a step is given """
        if step is None:
            self.experiment.log(metrics)
        else:
            self.experiment.log({**metrics, 'epoch': step})
se3_transformer/runtime/metrics.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ from abc import ABC, abstractmethod
25
+
26
+ import torch
27
+ import torch.distributed as dist
28
+ from torch import Tensor
29
+
30
+
31
class Metric(ABC):
    """ Metric class with synchronization capabilities similar to TorchMetrics """

    def __init__(self):
        # Maps each state name to a pristine copy of its default value (used by reset())
        self.states = {}

    def add_state(self, name: str, default: Tensor):
        """
        Register a state tensor, exposed as the attribute ``name``.

        :param name: Attribute name of the state; must not be registered already
        :param default: Initial value, also restored by :meth:`reset`
        """
        assert name not in self.states
        self.states[name] = default.clone()
        # The live attribute gets its own copy: subclasses update states in place
        # (e.g. ``self.total += n``), which must not write into the caller's tensor
        # nor alias two states registered from the same default.
        setattr(self, name, default.clone())

    def synchronize(self):
        """ Sum every state across all processes (no-op when torch.distributed is not initialized) """
        if dist.is_initialized():
            for state in self.states:
                dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)

    def __call__(self, *args, **kwargs):
        """ Shortcut for :meth:`update` """
        self.update(*args, **kwargs)

    def reset(self):
        """ Restore every state to a fresh copy of its registered default """
        for name, default in self.states.items():
            setattr(self, name, default.clone())

    def compute(self):
        """ Synchronize the states, return the metric as a Python number, then reset the states """
        self.synchronize()
        value = self._compute().item()
        self.reset()
        return value

    @abstractmethod
    def _compute(self):
        """ Return the metric value as a tensor, computed from the current states """
        pass

    @abstractmethod
    def update(self, preds: Tensor, targets: Tensor):
        """ Accumulate one batch of predictions and targets into the states """
        pass
67
+
68
+
69
class MeanAbsoluteError(Metric):
    """ Accumulates absolute errors; the result is the summed error divided by the number of samples """

    def __init__(self):
        super().__init__()
        # Fall back to the CPU when CUDA is unavailable instead of failing at construction time
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.add_state('error', torch.tensor(0, dtype=torch.float32, device=device))
        self.add_state('total', torch.tensor(0, dtype=torch.int32, device=device))

    def update(self, preds: Tensor, targets: Tensor):
        """
        Add one batch to the running sums.

        :param preds: Predictions; the first dimension is used as the number of samples
        :param targets: Targets with the same number of elements as ``preds``
        """
        preds = preds.detach()
        n = preds.shape[0]
        # reshape (unlike view) also accepts non-contiguous inputs
        error = torch.abs(preds.reshape(n, -1) - targets.reshape(n, -1)).sum()
        self.total += n
        self.error += error

    def _compute(self):
        return self.error / self.total
se3_transformer/runtime/training.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import logging
25
+ import pathlib
26
+ from typing import List
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.distributed as dist
31
+ import torch.nn as nn
32
+ from apex.optimizers import FusedAdam, FusedLAMB
33
+ from torch.nn.modules.loss import _Loss
34
+ from torch.nn.parallel import DistributedDataParallel
35
+ from torch.optim import Optimizer
36
+ from torch.utils.data import DataLoader, DistributedSampler
37
+ from tqdm import tqdm
38
+
39
+ from se3_transformer.data_loading import QM9DataModule
40
+ from se3_transformer.model import SE3TransformerPooled
41
+ from se3_transformer.model.fiber import Fiber
42
+ from se3_transformer.runtime import gpu_affinity
43
+ from se3_transformer.runtime.arguments import PARSER
44
+ from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
45
+ PerformanceCallback
46
+ from se3_transformer.runtime.inference import evaluate
47
+ from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
48
+ from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
49
+ using_tensor_cores, increase_l2_fetch_granularity
50
+
51
+
52
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Saves model, optimizer and epoch states to path (only on processes whose local rank is 0) """
    if get_local_rank() != 0:
        return

    # Store the weights of the wrapped module so the checkpoint does not depend on DDP
    unwrapped = model.module if isinstance(model, DistributedDataParallel) else model
    checkpoint = {
        'state_dict': unwrapped.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch
    }
    # Give every callback the chance to add its own entries to the checkpoint
    for cb in callbacks:
        cb.on_checkpoint_save(checkpoint)

    torch.save(checkpoint, str(path))
    logging.info(f'Saved checkpoint to {str(path)}')
66
+
67
+
68
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Restores model, optimizer and callback states from path and returns the stored epoch """
    # Tensors saved from 'cuda:0' are remapped onto this process' own GPU
    device_map = {'cuda:0': f'cuda:{get_local_rank()}'}
    checkpoint = torch.load(str(path), map_location=device_map)

    target = model.module if isinstance(model, DistributedDataParallel) else model
    target.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for cb in callbacks:
        cb.on_checkpoint_load(checkpoint)

    logging.info(f'Loaded checkpoint from {str(path)}')
    return checkpoint['epoch']
82
+
83
+
84
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
    """
    Run one training epoch and return the mean per-batch loss.

    :param model: Model called as ``model(*inputs)``
    :param train_dataloader: Yields batches whose last element is the target
    :param loss_fn: Called as ``loss_fn(pred, target)``
    :param epoch_idx: Epoch index, only used for the progress bar description
    :param grad_scaler: AMP gradient scaler used for backward and the optimizer step
    :param optimizer: Optimizer stepped once per accumulation window
    :param local_rank: The progress bar is only shown when this is 0
    :param callbacks: ``on_batch_start`` is called on each of them for every batch
    :param args: Reads ``silent``, ``amp``, ``accumulate_grad_batches`` and ``gradient_clip``
    :return: Mean of the per-batch losses, independent of the gradient accumulation factor
    """
    losses = []
    num_batches = len(train_dataloader)
    for i, batch in tqdm(enumerate(train_dataloader), total=num_batches, unit='batch',
                         desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
        *inputs, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*inputs)
            batch_loss = loss_fn(pred, target)
            # Only the value used for backward is divided, so that the gradients summed
            # over an accumulation window average out
            loss = batch_loss / args.accumulate_grad_batches

        grad_scaler.scale(loss).backward()

        # gradient accumulation: step at the end of each window and on the very last batch
        if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == num_batches:
            if args.gradient_clip:
                # Gradients must be unscaled before their norm is clipped
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)

            grad_scaler.step(optimizer)
            grad_scaler.update()
            optimizer.zero_grad()

        # Record the undivided loss: the reported value must not shrink with the accumulation factor
        losses.append(batch_loss.item())

    return np.mean(losses)
112
+
113
+
114
def train(model: nn.Module,
          loss_fn: _Loss,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          callbacks: List[BaseCallback],
          logger: Logger,
          args):
    """
    Full training procedure: optimizer setup, optional resume, epoch loop, checkpointing and evaluation.

    :param model: Model to train; moved to the current CUDA device and wrapped in DDP when distributed
    :param loss_fn: Loss passed on to ``train_epoch``
    :param train_dataloader: Training batches
    :param val_dataloader: Batches passed to ``evaluate``
    :param callbacks: Notified on fit start/end, epoch end, validation end and checkpoint save/load
    :param logger: Receives the 'train loss' metric after every epoch
    :param args: Parsed arguments (optimizer settings, epochs, checkpoint paths, intervals, benchmark flag)
    """
    device = torch.cuda.current_device()
    model.to(device=device)
    local_rank = get_local_rank()
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    model.train()
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        # Any other value falls back to plain SGD
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # The epoch stored in a checkpoint is the index of the next epoch to run
    epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        if isinstance(train_dataloader.sampler, DistributedSampler):
            # Required so that the sampler reshuffles differently at every epoch
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks,
                           args)
        if dist.is_initialized():
            # Average the per-process epoch loss over all processes
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            torch.distributed.all_reduce(loss)
            loss = (loss / world_size).item()

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)

        for callback in callbacks:
            callback.on_epoch_end()

        if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
                and (epoch_idx + 1) % args.ckpt_interval == 0:
            # Store the number of completed epochs (epoch_idx + 1): load_state() feeds this value
            # straight into range(epoch_start, ...), so storing epoch_idx would make a resumed run
            # repeat an epoch that is already finished. This also matches the final checkpoint,
            # which stores args.epochs.
            save_state(model, optimizer, epoch_idx + 1, args.save_ckpt_path, callbacks)

        if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
            evaluate(model, val_dataloader, callbacks, args)
            # Switch back to training mode after the evaluation
            model.train()

            for callback in callbacks:
                callback.on_validation_end(epoch_idx)

    if args.save_ckpt_path is not None and not args.benchmark:
        save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)

    for callback in callbacks:
        callback.on_fit_end()
178
+
179
+
180
def print_parameters_count(model):
    """ Log (at INFO level) how many scalar parameters of ``model`` require gradients """
    num_params_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            num_params_trainable += param.numel()
    logging.info(f'Number of trainable parameters: {num_params_trainable}')
183
+
184
+
185
if __name__ == '__main__':
    # Entry point: set up distributed state, logging, data, model and callbacks, then train
    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    # Only the process with local rank 0 logs at INFO level, and only when not silenced
    quiet = local_rank != 0 or args.silent
    logging.getLogger().setLevel(logging.CRITICAL if quiet else logging.INFO)

    for banner_line in ('====== SE(3)-Transformer ======',
                        '|      Training procedure     |',
                        '==============================='):
        logging.info(banner_line)

    if args.seed is not None:
        logging.info(f'Using seed {args.seed}')
        seed_everything(args.seed)

    dl_logger = DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)
    wandb_logger = WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
    logger = LoggerCollection([dl_logger, wandb_logger])

    datamodule = QM9DataModule(**vars(args))
    model = SE3TransformerPooled(
        fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
        fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=using_tensor_cores(args.amp),  # use Tensor Cores more effectively
        **vars(args)
    )
    loss_fn = nn.L1Loss()

    if not args.benchmark:
        callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
                     QM9LRSchedulerCallback(logger, epochs=args.epochs)]
    else:
        logging.info('Running benchmark mode')
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        # The performance callback is given the global batch size (per-process batch size * world size)
        callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]

    if is_distributed:
        gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())

    print_parameters_count(model)
    logger.log_hyperparams(vars(args))
    increase_l2_fetch_granularity()
    train(model,
          loss_fn,
          datamodule.train_dataloader(),
          datamodule.val_dataloader(),
          callbacks,
          logger,
          args)

    logging.info('Training finished successfully')
se3_transformer/runtime/utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a
4
+ # copy of this software and associated documentation files (the "Software"),
5
+ # to deal in the Software without restriction, including without limitation
6
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense,
7
+ # and/or sell copies of the Software, and to permit persons to whom the
8
+ # Software is furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in
11
+ # all copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
16
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18
+ # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19
+ # DEALINGS IN THE SOFTWARE.
20
+ #
21
+ # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
22
+ # SPDX-License-Identifier: MIT
23
+
24
+ import argparse
25
+ import ctypes
26
+ import logging
27
+ import os
28
+ import random
29
+ from functools import wraps
30
+ from typing import Union, List, Dict
31
+
32
+ import numpy as np
33
+ import torch
34
+ import torch.distributed as dist
35
+ from torch import Tensor
36
+
37
+
38
def aggregate_residual(feats1, feats2, method: str):
    """
    Merge two fiber feature dicts by addition or by concatenation along dim 1.
    Only the keys of feats2 are kept; entries missing from feats1 are passed through untouched.
    """
    if method in ('add', 'sum'):
        def combine(new, old):
            return new + old
    elif method in ('cat', 'concat'):
        def combine(new, old):
            return torch.cat([new, old], dim=1)
    else:
        raise ValueError('Method must be add/sum or cat/concat')

    merged = {}
    for degree, feat in feats2.items():
        merged[degree] = combine(feat, feats1[degree]) if degree in feats1 else feat
    return merged
46
+
47
+
48
def degree_to_dim(degree: int) -> int:
    """ Map a degree d to the value 2d + 1 """
    doubled = 2 * degree
    return doubled + 1
50
+
51
+
52
def unfuse_features(features: Tensor, degrees: List[int]) -> Dict[str, Tensor]:
    """ Split the last dimension of ``features`` into one chunk per degree, keyed by the degree as a string """
    widths = [degree_to_dim(deg) for deg in degrees]
    chunks = features.split(widths, dim=-1)
    return {str(deg): chunk for deg, chunk in zip(degrees, chunks)}
54
+
55
+
56
def str2bool(v: Union[bool, str]) -> bool:
    """
    Interpret a command-line style value as a boolean (case-insensitive for strings).

    :raises argparse.ArgumentTypeError: If the string is not a recognized boolean spelling
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
65
+
66
+
67
def to_cuda(x):
    """
    Try to convert a Tensor, a collection of Tensors or a DGLGraph to CUDA.

    Tuples, lists and dicts are converted recursively and come back as the same kind of container.
    Any other object is moved with its own ``.to(device=...)`` method.
    """
    if isinstance(x, Tensor):
        return x.cuda(non_blocking=True)
    elif isinstance(x, tuple):
        # Build a real tuple: a bare generator expression would be single-use
        # and would support neither len() nor indexing
        return tuple(to_cuda(v) for v in x)
    elif isinstance(x, list):
        return [to_cuda(v) for v in x]
    elif isinstance(x, dict):
        return {k: to_cuda(v) for k, v in x.items()}
    else:
        # DGLGraph or other objects
        return x.to(device=torch.cuda.current_device())
80
+
81
+
82
def get_local_rank() -> int:
    """ Read the LOCAL_RANK environment variable as an int (0 when it is not set) """
    raw_rank = os.getenv('LOCAL_RANK', 0)
    return int(raw_rank)
84
+
85
+
86
def init_distributed() -> bool:
    """
    Initialize torch.distributed when the WORLD_SIZE environment variable is greater than 1.

    :return: True if the process group was initialized, False for a single-process run
    """
    distributed = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if not distributed:
        return distributed

    # NCCL when CUDA is available, gloo otherwise
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend, init_method='env://')
    if backend != 'nccl':
        logging.warning('Running on CPU only!')
    else:
        torch.cuda.set_device(get_local_rank())
    assert torch.distributed.is_initialized()
    return distributed
98
+
99
+
100
def increase_l2_fetch_granularity():
    """ Set the L2 fetch granularity limit of the current device to 128 bytes and assert it took effect """
    cudart = ctypes.CDLL('libcudart.so')
    limit_id = 0x05  # cudaLimitMaxL2FetchGranularity
    granularity = 128  # maximum fetch granularity of L2: 128 bytes
    # Output slot filled by cudaDeviceGetLimit
    readback = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    # set device limit on the current device, then read it back
    cudart.cudaDeviceSetLimit(ctypes.c_int(limit_id), ctypes.c_int(granularity))
    cudart.cudaDeviceGetLimit(readback, ctypes.c_int(limit_id))
    assert readback.contents.value == granularity
109
+
110
+
111
def seed_everything(seed):
    """ Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) random generators with int(seed) """
    seed = int(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
117
+
118
+
119
def rank_zero_only(fn):
    """
    Decorator running ``fn`` only when torch.distributed is not initialized or on global rank 0.
    On every other rank the call does nothing and returns None.
    """
    @wraps(fn)
    def guarded(*args, **kwargs):
        if dist.is_initialized() and dist.get_rank() != 0:
            return None
        return fn(*args, **kwargs)

    return guarded
126
+
127
+
128
def using_tensor_cores(amp: bool) -> bool:
    """ True when the device's major compute capability is at least 8, or at least 7 with ``amp`` enabled """
    major_cc, _minor_cc = torch.cuda.get_device_capability()
    # With AMP a major capability of 7 is enough, otherwise 8 is required
    required_major = 7 if amp else 8
    return major_cc >= required_major