# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from functools import partial import logging from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional import torch import torch.nn as nn def add(m1, m2, inplace): # The first operation in a checkpoint can't be in-place, but it's # nice to have in-place addition during inference. Thus... if(not inplace): m1 = m1 + m2 else: m1 += m2 return m1 def permute_final_dims(tensor: torch.Tensor, inds: List[int]): zero_index = -1 * len(inds) first_inds = list(range(len(tensor.shape[:zero_index]))) return tensor.permute(first_inds + [zero_index + i for i in inds]) def flatten_final_dims(t: torch.Tensor, no_dims: int): return t.reshape(t.shape[:-no_dims] + (-1,)) def masked_mean(mask, value, dim, eps=1e-4): mask = mask.expand(*value.shape) return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): boundaries = torch.linspace( min_bin, max_bin, no_bins - 1, device=pts.device ) dists = torch.sqrt( torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1) ) return torch.bucketize(dists, boundaries) def dict_multimap(fn, dicts): first = dicts[0] new_dict = {} for k, v in first.items(): all_v = [d[k] for d in dicts] if type(v) is dict: new_dict[k] = dict_multimap(fn, all_v) else: new_dict[k] = fn(all_v) return new_dict def one_hot(x, v_bins): reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) diffs = x[..., None] - reshaped_bins am = torch.argmin(torch.abs(diffs), dim=-1) return nn.functional.one_hot(am, num_classes=len(v_bins)).float() def batched_gather(data, inds, dim=0, no_batch_dims=0): ranges = [] for i, s in enumerate(data.shape[:no_batch_dims]): r = torch.arange(s) r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) ranges.append(r) remaining_dims = [ slice(None) for _ in range(len(data.shape) - no_batch_dims) ] remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds ranges.extend(remaining_dims) return data[ranges] # With tree_map, a poor man's JAX tree_map def dict_map(fn, dic, leaf_type): new_dict = {} for k, v in dic.items(): # print("dictttt", k,type(v), v) if type(v) is dict: new_dict[k] = dict_map(fn, v, leaf_type) else: new_dict[k] = tree_map(fn, v, leaf_type) return new_dict def tree_map(fn, tree, leaf_type): if isinstance(tree, dict): return dict_map(fn, tree, leaf_type) elif isinstance(tree, list): return [tree_map(fn, x, leaf_type) for x in tree] elif isinstance(tree, tuple): return tuple([tree_map(fn, x, leaf_type) for x in tree]) elif isinstance(tree, leaf_type): return fn(tree) else: print(type(tree)) raise ValueError("Not supported") tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)