File size: 16,460 Bytes
9439b9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
import logging
import os
from typing import Tuple, Union, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from e3nn import o3
from e3nn.nn import BatchNorm
from e3nn.o3 import TensorProduct, Linear
from torch_scatter import scatter, scatter_mean

from models.layers import FCBlock


def get_irrep_seq(ns, nv, use_second_order_repr, reduce_pseudoscalars):
    """Return the progression of irreps strings used by successive conv layers.

    Each entry extends the previous one with additional terms; the final entry
    appends pseudoscalars (0o) with multiplicity nv when reduce_pseudoscalars
    is set, otherwise ns.
    """
    pseudoscalar_mul = nv if reduce_pseudoscalars else ns
    if use_second_order_repr:
        increments = [
            f'{ns}x0e',
            f'{nv}x1o + {nv}x2e',
            f'{nv}x1e + {nv}x2o',
            f'{pseudoscalar_mul}x0o',
        ]
    else:
        increments = [
            f'{ns}x0e',
            f'{nv}x1o',
            f'{nv}x1e',
            f'{pseudoscalar_mul}x0o',
        ]
    # Build the sequence cumulatively: entry k is entry k-1 plus one new term.
    irrep_seq = []
    for term in increments:
        irrep_seq.append(term if not irrep_seq else f'{irrep_seq[-1]} + {term}')
    return irrep_seq


def irrep_to_size(irrep):
    """Compute the flat dimensionality of an irreps string like '4x0e + 2x1o'.

    Each term 'MxLp' (multiplicity M, degree L, parity p) contributes
    M * (2L + 1); parity does not affect the size.

    The previous implementation unpacked the degree/parity with
    `(l, p) = ir.split('x')[1]`, which only works when the degree is a single
    digit — '2x10e' would raise ValueError. Parsing the degree as everything
    up to the trailing parity character handles any degree.
    """
    size = 0
    for term in irrep.split(' + '):
        mul, lp = term.split('x')
        degree = int(lp[:-1])  # strip the trailing parity character ('e'/'o')
        size += int(mul) * (2 * degree + 1)
    return size


class FasterTensorProduct(torch.nn.Module):
    """Specialized weighted tensor product with 1st-order spherical harmonics.

    Drop-in replacement for o3.FullyConnectedTensorProduct restricted to
    sh_irreps == '1x0e + 1x1o' and input/output irreps composed only of
    0e, 1o, 1e and 0o terms. Each output type is computed explicitly with
    scalar/dot/cross products instead of e3nn's generic instruction machinery.

    Exposes `weight_numel` so callers can size a per-edge weight MLP, matching
    the e3nn tensor-product contract (shared_weights=False style).
    """
    # Implemented by Bowen Jing
    def __init__(self, in_irreps, sh_irreps, out_irreps, **kwargs):
        super().__init__()
        #for ir in in_irreps:
        #    m, (l, p) = ir
        #    assert l in [0, 1], "Higher order in irreps are not supported"
        #for ir in out_irreps:
        #    m, (l, p) = ir
        #    assert l in [0, 1], "Higher order out irreps are not supported"
        assert o3.Irreps(sh_irreps) == o3.Irreps('1x0e+1x1o'), "sh_irreps don't look like 1st order spherical harmonics"
        self.in_irreps = o3.Irreps(in_irreps)
        self.out_irreps = o3.Irreps(out_irreps)

        # Multiplicity of each supported irrep type in the input/output.
        in_muls = {'0e': 0, '1o': 0, '1e': 0, '0o': 0}
        out_muls = {'0e': 0, '1o': 0, '1e': 0, '0o': 0}
        for (m, ir) in self.in_irreps: in_muls[str(ir)] = m
        for (m, ir) in self.out_irreps: out_muls[str(ir)] = m

        # (fan_in, fan_out) of the learned mixing matrix per output type.
        # fan_in counts the input paths that can produce each output type with
        # 0e/1o spherical harmonics (matches the products formed in forward):
        #   0e <- 0e*0e, 1o.1o      1o <- 0e*1o, 1o*0e, 1e x 1o
        #   1e <- 1o x 1o(sh), 1e*0e, 0o*1o      0o <- 1e.1o, 0o*0e
        self.weight_shapes = {
            '0e': (in_muls['0e'] + in_muls['1o'], out_muls['0e']),
            '1o': (in_muls['0e'] + in_muls['1o'] + in_muls['1e'], out_muls['1o']),
            '1e': (in_muls['1o'] + in_muls['1e'] + in_muls['0o'], out_muls['1e']),
            '0o': (in_muls['1e'] + in_muls['0o'], out_muls['0o'])
        }
        self.weight_numel = sum(a * b for (a, b) in self.weight_shapes.values())

    def forward(self, in_, sh, weight):
        """Apply the weighted tensor product.

        Parameters
        ----------
        in_ : torch.Tensor
            Flattened input irreps, last dim == self.in_irreps.dim.
        sh : torch.Tensor
            Spherical harmonics, last dim 4 (scalar 0e, then the 1o vector).
        weight : torch.Tensor
            Flat weights, last dim == self.weight_numel.
        """
        # Split the flat input per irrep type; degree-1 parts are reshaped to
        # (..., mul, 3) so vector algebra can operate on the last axis.
        in_dict, out_dict = {}, {'0e': [], '1o': [], '1e': [], '0o': []}
        for (m, ir), sl in zip(self.in_irreps, self.in_irreps.slices()):
            in_dict[str(ir)] = in_[..., sl]
            if ir[0] == 1: in_dict[str(ir)] = in_dict[str(ir)].reshape(list(in_dict[str(ir)].shape)[:-1] + [-1, 3])
        sh_0e, sh_1o = sh[..., 0], sh[..., 1:]
        # Enumerate all symmetry-allowed products. The 1/sqrt(3) and 1/sqrt(2)
        # factors normalize the dot- and cross-product paths — presumably to
        # match e3nn's component normalization (TODO confirm against e3nn).
        if '0e' in in_dict:
            out_dict['0e'].append(in_dict['0e'] * sh_0e.unsqueeze(-1))
            out_dict['1o'].append(in_dict['0e'].unsqueeze(-1) * sh_1o.unsqueeze(-2))
        if '1o' in in_dict:
            out_dict['0e'].append((in_dict['1o'] * sh_1o.unsqueeze(-2)).sum(-1) / np.sqrt(3))
            out_dict['1o'].append(in_dict['1o'] * sh_0e.unsqueeze(-1).unsqueeze(-1))
            out_dict['1e'].append(torch.linalg.cross(in_dict['1o'], sh_1o.unsqueeze(-2), dim=-1) / np.sqrt(2))
        if '1e' in in_dict:
            out_dict['1o'].append(torch.linalg.cross(in_dict['1e'], sh_1o.unsqueeze(-2), dim=-1) / np.sqrt(2))
            out_dict['1e'].append(in_dict['1e'] * sh_0e.unsqueeze(-1).unsqueeze(-1))
            out_dict['0o'].append((in_dict['1e'] * sh_1o.unsqueeze(-2)).sum(-1) / np.sqrt(3))
        if '0o' in in_dict:
            out_dict['1e'].append(in_dict['0o'].unsqueeze(-1) * sh_1o.unsqueeze(-2))
            out_dict['0o'].append(in_dict['0o'] * sh_0e.unsqueeze(-1))

        # Slice the flat weight vector into one mixing matrix per output type.
        # The 1/sqrt(fan_in) keeps output variance independent of fan-in.
        # (NOTE: `in_` is rebound here, shadowing the input argument — safe
        # because the input is no longer needed at this point.)
        weight_dict = {}
        start = 0
        for key in self.weight_shapes:
            in_, out = self.weight_shapes[key]
            weight_dict[key] = weight[..., start:start + in_ * out].reshape(
                list(weight.shape)[:-1] + [in_, out]) / np.sqrt(in_)
            start += in_ * out

        # Mix the collected paths of each output type with its weight matrix.
        if out_dict['0e']:
            out_dict['0e'] = torch.cat(out_dict['0e'], dim=-1)
            out_dict['0e'] = torch.matmul(out_dict['0e'].unsqueeze(-2), weight_dict['0e']).squeeze(-2)

        if out_dict['1o']:
            out_dict['1o'] = torch.cat(out_dict['1o'], dim=-2)
            out_dict['1o'] = (out_dict['1o'].unsqueeze(-2) * weight_dict['1o'].unsqueeze(-1)).sum(-3)
            out_dict['1o'] = out_dict['1o'].reshape(list(out_dict['1o'].shape)[:-2] + [-1])

        if out_dict['1e']:
            out_dict['1e'] = torch.cat(out_dict['1e'], dim=-2)
            out_dict['1e'] = (out_dict['1e'].unsqueeze(-2) * weight_dict['1e'].unsqueeze(-1)).sum(-3)
            out_dict['1e'] = out_dict['1e'].reshape(list(out_dict['1e'].shape)[:-2] + [-1])

        if out_dict['0o']:
            out_dict['0o'] = torch.cat(out_dict['0o'], dim=-1)
            # out_dict['0o'] = (out_dict['0o'].unsqueeze(-1) * weight_dict['0o']).sum(-2)
            out_dict['0o'] = torch.matmul(out_dict['0o'].unsqueeze(-2), weight_dict['0o']).squeeze(-2)

        # Concatenate in the order the output irreps were declared.
        out = []
        for _, ir in self.out_irreps:
            out.append(out_dict[str(ir)])
        return torch.cat(out, dim=-1)


def tp_scatter_simple(tp, fc_layer, node_attr, edge_index, edge_attr, edge_sh,
                      out_nodes=None, reduce='mean', edge_weight=1.0):
    """
    Graph convolution for a single edge group: compute per-edge tensor-product
    messages and scatter-reduce them onto the source nodes.

    For multiple edge groups, and for larger graphs, use tp_scatter_multigroup
    instead.
    """

    assert isinstance(edge_attr, torch.Tensor), \
        "This function is only for a single edge group, so edge_attr must be a tensor and not a list."

    src, dst = edge_index
    # Per-edge tensor-product weights from the edge features, moved to the
    # node features' device/dtype, then scaled in place by the edge weight.
    tp_weights = fc_layer(edge_attr).to(node_attr.device).to(node_attr.dtype)
    tp_weights.mul_(edge_weight)
    messages = tp(node_attr[dst], edge_sh, tp_weights)
    dim_size = out_nodes or node_attr.shape[0]
    return scatter(messages, src, dim=0, dim_size=dim_size, reduce=reduce)


def tp_scatter_multigroup(tp: o3.TensorProduct, fc_layer: Union[nn.Module, nn.ModuleList],
                          node_attr: torch.Tensor, edge_index: torch.Tensor,
                          edge_attr_groups: List[torch.Tensor], edge_sh: torch.Tensor,
                          out_nodes=None, reduce='mean', edge_weight=1.0):
    """
    Perform TensorProduct + scatter operation, aka graph convolution.

    To keep the peak memory usage reasonably low, this function does not concatenate the edge_attr_groups.
    Rather, we sum the output of the tensor product for each edge group, and then divide by the number of
    edges that contributed to each node (when reduce == 'mean').

    Parameters
    ----------
    tp: o3.TensorProduct
    fc_layer: nn.Module, or nn.ModuleList
        If a list, must be the same length as edge_attr_groups
    node_attr: torch.Tensor
    edge_index: torch.Tensor of shape (2, num_edges)
        Indicates the source and destination nodes of each edge
    edge_attr_groups: List[torch.Tensor]
        List of tensors, with shape (X_i, num_edge_attributes). Each tensor is a different group of edge attributes
        X may be different for each tensor, although sum(X_i) must be equal to edge_index.shape[1]
    edge_sh: torch.Tensor
        Spherical harmonics for the edges (see o3.spherical_harmonics)
    out_nodes:
        Number of output nodes
    reduce: str
        'mean' or 'sum'. Reduce function for scatter.
    edge_weight : float or torch.Tensor
        Edge weights. If a tensor, it is multiplied in-place into the per-edge
        tp weights, so it must index/broadcast against rows of shape
        (num_edges_in_group, tp.weight_numel) — presumably (num_edges, 1);
        TODO confirm against callers.

    Returns
    -------
    torch.Tensor
        Result of the graph convolution
    """

    assert isinstance(edge_attr_groups, list), "This function is only for a list of edge groups"
    assert reduce in {"mean", "sum"}, "Only 'mean' and 'sum' are supported for reduce"
    # It would be possible to support mul/min/max but that would require more work and more code,
    # so only going to do it if it's needed.

    _device = node_attr.device
    _dtype = node_attr.dtype
    edge_src, edge_dst = edge_index
    edge_attr_lengths = [_edge_attr.shape[0] for _edge_attr in edge_attr_groups]
    total_rows = sum(edge_attr_lengths)
    assert total_rows == edge_index.shape[1], "Sum of edge_attr_groups must be equal to edge_index.shape[1]"
    num_edge_groups = len(edge_attr_groups)
    # Duck-typed check: a tensor edge_weight is sliced per group below, while a
    # scalar is applied uniformly to every group.
    edge_weight_is_indexable = hasattr(edge_weight, '__getitem__')

    out_nodes = out_nodes or node_attr.shape[0]
    total_output_dim = sum([x.dim for x in tp.irreps_out])
    # Running sum of messages per node, plus a per-node count of contributing
    # edges so the mean can be taken after all groups are processed.
    final_out = torch.zeros((out_nodes, total_output_dim), device=_device, dtype=_dtype)
    div_factors = torch.zeros(out_nodes, device=_device, dtype=_dtype)

    cur_start = 0
    for ii in range(num_edge_groups):
        # Groups correspond to contiguous slices of the edge list.
        cur_length = edge_attr_lengths[ii]
        cur_end = cur_start + cur_length
        cur_edge_range = slice(cur_start, cur_end)
        cur_edge_src, cur_edge_dst = edge_src[cur_edge_range], edge_dst[cur_edge_range]

        # One FC layer per group when a ModuleList is provided, else shared.
        cur_fc = fc_layer[ii] if isinstance(fc_layer, nn.ModuleList) else fc_layer
        cur_out_irreps = cur_fc(edge_attr_groups[ii])
        if edge_weight_is_indexable:
            cur_out_irreps.mul_(edge_weight[cur_edge_range])
        else:
            cur_out_irreps.mul_(edge_weight)

        summand = tp(node_attr[cur_edge_dst, :], edge_sh[cur_edge_range, :], cur_out_irreps)
        # We take a simple sum, and then add up the count of edges which contribute,
        # so that we can take the mean later.
        final_out += scatter(summand, cur_edge_src, dim=0, dim_size=out_nodes, reduce="sum")
        div_factors += torch.bincount(cur_edge_src, minlength=out_nodes)

        cur_start = cur_end

        # Drop per-group intermediates promptly to keep peak memory low.
        del cur_out_irreps, summand

    if reduce == 'mean':
        # Clamp to eps so nodes with no incoming edges divide by a tiny value
        # instead of 0/0; their accumulated sum is 0, so the result stays 0.
        div_factors = torch.clamp(div_factors, torch.finfo(_dtype).eps)
        final_out = final_out / div_factors[:, None]

    return final_out


class TensorProductConvLayer(torch.nn.Module):
    """Equivariant graph-convolution layer.

    Computes per-edge tensor products of node features with edge spherical
    harmonics, with tensor-product weights produced by an MLP on the edge
    features, then scatter-reduces the messages onto nodes. Supports three
    tensor-product backends (fully connected, FasterTensorProduct, or a
    depthwise 'uvu' product followed by an o3.Linear channel mixer) and
    optionally multiple edge groups, each with its own weight MLP.
    """

    def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True, batch_norm=True, dropout=0.0,
                 hidden_features=None, faster=False, edge_groups=1, tp_weights_layers=2, activation='relu', depthwise=False):
        super(TensorProductConvLayer, self).__init__()
        self.in_irreps = in_irreps
        self.out_irreps = out_irreps
        self.sh_irreps = sh_irreps
        self.residual = residual
        self.edge_groups = edge_groups
        # Flat width of the output irreps; used to size the zero output on the
        # no-edges fast path in forward().
        self.out_size = irrep_to_size(out_irreps)
        self.depthwise = depthwise
        if hidden_features is None:
            hidden_features = n_edge_features

        if depthwise:
            in_irreps = o3.Irreps(in_irreps)
            sh_irreps = o3.Irreps(sh_irreps)
            out_irreps = o3.Irreps(out_irreps)

            # Build one 'uvu' instruction per symmetry-allowed
            # (input irrep x sh irrep) -> output irrep path. 'uvu' keeps input
            # channels separate (depthwise); the o3.Linear below mixes them.
            irreps_mid = []
            instructions = []
            for i, (mul, ir_in) in enumerate(in_irreps):
                for j, (_, ir_edge) in enumerate(sh_irreps):
                    for ir_out in ir_in * ir_edge:
                        if ir_out in out_irreps:
                            k = len(irreps_mid)
                            irreps_mid.append((mul, ir_out))
                            instructions.append((i, j, k, "uvu", True))

            # We sort the output irreps of the tensor product so that we can simplify them
            # when they are provided to the second o3.Linear
            irreps_mid = o3.Irreps(irreps_mid)
            irreps_mid, p, _ = irreps_mid.sort()

            # Permute the output indexes of the instructions to match the sorted irreps:
            instructions = [
                (i_in1, i_in2, p[i_out], mode, train)
                for i_in1, i_in2, i_out, mode, train in instructions
            ]

            self.tp = TensorProduct(
                in_irreps,
                sh_irreps,
                irreps_mid,
                instructions,
                shared_weights=False,
                internal_weights=False,
            )

            self.linear_2 = Linear(
                # irreps_mid has uncoallesed irreps because of the uvu instructions,
                # but there's no reason to treat them seperately for the Linear
                # Note that normalization of o3.Linear changes if irreps are coallesed
                # (likely for the better)
                irreps_in=irreps_mid.simplify(),
                irreps_out=out_irreps,
                internal_weights=True,
                shared_weights=True,
            )

        else:
            if faster:
                print("Faster Tensor Product")
                self.tp = FasterTensorProduct(in_irreps, sh_irreps, out_irreps)
            else:
                self.tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False)

        # One weight MLP per edge group; a single group keeps a plain module.
        if edge_groups == 1:
            self.fc = FCBlock(n_edge_features, hidden_features, self.tp.weight_numel, tp_weights_layers, dropout, activation)
        else:
            self.fc = [FCBlock(n_edge_features, hidden_features, self.tp.weight_numel, tp_weights_layers, dropout, activation) for _ in range(edge_groups)]
            self.fc = nn.ModuleList(self.fc)

        self.batch_norm = BatchNorm(out_irreps) if batch_norm else None

    def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean', edge_weight=1.0):
        """Run the convolution; returns node features of width self.out_size
        (plus the zero-padded residual when enabled).

        edge_attr is a single tensor when edge_groups == 1, otherwise a list
        of tensors, one per edge group.
        """
        if edge_index.shape[1] == 0 and node_attr.shape[0] == 0:
            raise ValueError("No edges and no nodes")

        _dtype = node_attr.dtype
        if edge_index.shape[1] == 0:
            # No edges: messages are all zero. NOTE(review): this path skips
            # linear_2 and batch_norm; only the residual (below) contributes.
            out = torch.zeros((node_attr.shape[0], self.out_size), dtype=_dtype, device=node_attr.device)
        else:
            if self.edge_groups == 1:
                out = tp_scatter_simple(self.tp, self.fc, node_attr, edge_index, edge_attr, edge_sh,
                                        out_nodes, reduce, edge_weight)
            else:
                out = tp_scatter_multigroup(self.tp, self.fc, node_attr, edge_index, edge_attr, edge_sh,
                                            out_nodes, reduce, edge_weight)

            if self.depthwise:
                # Channel mixing after the depthwise ('uvu') tensor product.
                out = self.linear_2(out)

            if self.batch_norm:
                out = self.batch_norm(out)

        if self.residual:
            # Zero-pad the input features up to the output width, then add.
            padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
            out = out + padded

        out = out.to(_dtype)
        return out


class OldTensorProductConvLayer(torch.nn.Module):
    """Legacy tensor-product graph-convolution layer.

    Superseded by TensorProductConvLayer but kept for backward compatibility.
    Splits the edges into chunks of at most 100k rows to bound peak memory and
    delegates the convolution to tp_scatter_multigroup.
    """

    def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True, batch_norm=True, dropout=0.0,
                 hidden_features=None):
        super(OldTensorProductConvLayer, self).__init__()
        self.in_irreps = in_irreps
        self.out_irreps = out_irreps
        self.sh_irreps = sh_irreps
        self.residual = residual
        if hidden_features is None:
            hidden_features = n_edge_features

        # shared_weights=False: tensor-product weights come per-edge from the
        # fully-connected network below.
        self.tp = tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False)

        self.fc = nn.Sequential(
            nn.Linear(n_edge_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, tp.weight_numel)
        )
        self.batch_norm = BatchNorm(out_irreps) if batch_norm else None

    def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean', edge_weight=1.0):
        """Graph convolution with optional residual connection and batch norm.

        NOTE(review): unlike TensorProductConvLayer, batch norm here runs
        *after* the residual addition — preserved as-is since callers may
        depend on it.
        """

        # Break up the edge_attr into chunks to limit the maximum memory usage.
        edge_chunk_size = 100_000
        num_edges = edge_attr.shape[0]
        # Ceiling division, clamped to at least one chunk: the previous
        # formula yielded num_chunks == 0 for an empty edge set, which made
        # np.array_split raise ValueError.
        num_chunks = max(1, -(-num_edges // edge_chunk_size))
        edge_ranges = np.array_split(np.arange(num_edges), num_chunks)
        edge_attr_groups = [edge_attr[cur_range] for cur_range in edge_ranges]

        out = tp_scatter_multigroup(self.tp, self.fc, node_attr, edge_index, edge_attr_groups, edge_sh,
                                    out_nodes, reduce, edge_weight)

        if self.residual:
            # Zero-pad the input features up to the output width, then add.
            padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
            out = out + padded

        if self.batch_norm:
            out = self.batch_norm(out)

        out = out.to(node_attr.dtype)
        return out