import os
import cv2
import time
import random
import re
import string
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import albumentations as A
from albumentations.pytorch import ToTensorV2

from .indigo import Indigo
from .indigo.renderer import IndigoRenderer
from .augment import SafeRotate, CropWhite, PadWhite, SaltAndPepperNoise
from .utils import FORMAT_INFO
from .tokenizer import PAD_ID
from .chemistry import get_num_atoms, normalize_nodes
from .constants import RGROUP_SYMBOLS, SUBSTITUTIONS, ELEMENTS, COLORS

# Keep OpenCV single-threaded (presumably to avoid CPU oversubscription when
# several DataLoader workers each call cv2 -- confirm).
cv2.setNumThreads(1)

# Probabilities of the random molecule/rendering augmentations used by
# generate_indigo_image. The misspelled names (HYGROGEN, DEARMOTIZE) are kept
# unchanged because other modules may import them.
INDIGO_HYGROGEN_PROB = 0.2
INDIGO_FUNCTIONAL_GROUP_PROB = 0.8
INDIGO_CONDENSED_PROB = 0.5
INDIGO_RGROUP_PROB = 0.5
INDIGO_COMMENT_PROB = 0.3
INDIGO_DEARMOTIZE_PROB = 0.8
INDIGO_COLOR_PROB = 0.2


def get_transforms(input_size, augment=True, rotate=True, debug=False):
    """Build the albumentations pipeline applied to every image.

    Args:
        input_size: side length of the square output produced by ``A.Resize``.
        augment: enable the random augmentation stages.
        rotate: when ``augment`` is also true, prepend ``SafeRotate`` (limit 90,
            constant white border).
        debug: skip grayscale conversion, normalization and ``ToTensorV2`` so the
            result is left as the resized image array.

    Returns:
        An ``A.Compose`` that also transforms ``xy`` keypoints and keeps keypoints
        that fall outside the image (``remove_invisible=False``).
    """
    trans_list = []
    if augment and rotate:
        trans_list.append(SafeRotate(limit=90, border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255)))
    trans_list.append(CropWhite(pad=5))
    if augment:
        trans_list += [
            A.CropAndPad(percent=[-0.01, 0.00], keep_size=False, p=0.5),
            PadWhite(pad_ratio=0.4, p=0.2),
            A.Downscale(scale_min=0.2, scale_max=0.5, interpolation=3),
            A.Blur(),
            A.GaussNoise(),
            SaltAndPepperNoise(num_dots=20, p=0.5),
        ]
    trans_list.append(A.Resize(input_size, input_size))
    if not debug:
        # Standard ImageNet channel mean/std.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        trans_list += [
            A.ToGray(p=1),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ]
    return A.Compose(trans_list, keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))


def add_functional_group(indigo, mol, debug=False):
    """Randomly collapse matched functional groups of ``mol`` into abbreviation atoms.

    With probability ``INDIGO_FUNCTIONAL_GROUP_PROB`` the substitutions in
    ``SUBSTITUTIONS`` are tried in random order. Each SMARTS match is accepted
    with probability ``sub.probability`` (always when ``debug``). An accepted
    match is replaced by one new atom labelled with a random entry of
    ``sub.abbrvs``, bonded to the match's outside (non-hydrogen) neighbours.
    Explicit hydrogens attached to the match are deleted with it. Matches that
    overlap an already-replaced match of the same substitution are skipped.

    Args:
        indigo: Indigo session used to load the SMARTS queries.
        mol: Indigo molecule; modified in place.
        debug: accept every match regardless of ``sub.probability``.

    Returns:
        ``mol`` (the same object, mutated).
    """
    if random.random() > INDIGO_FUNCTIONAL_GROUP_PROB:
        return mol
    substitutions = list(SUBSTITUTIONS)
    random.shuffle(substitutions)
    for sub in substitutions:
        query = indigo.loadSmarts(sub.smarts)
        matcher = indigo.substructureMatcher(mol)
        matched_atoms_ids = set()
        for match in matcher.iterateMatches(query):
            if not (random.random() < sub.probability or debug):
                continue
            atoms = []
            atoms_ids = set()
            for item in query.iterateAtoms():
                atom = match.mapAtom(item)
                atoms.append(atom)
                atoms_ids.add(atom.index())
            if matched_atoms_ids & atoms_ids:
                continue
            abbrv = random.choice(sub.abbrvs)
            superatom = mol.addAtom(abbrv)
            for atom in atoms:
                for nei in atom.iterateNeighbors():
                    if nei.index() in atoms_ids:
                        continue
                    if nei.symbol() == 'H':
                        # indigo won't match explicit hydrogen, so remove them explicitly
                        atoms_ids.add(nei.index())
                    else:
                        superatom.addBond(nei, nei.bond().bondOrder())
            for atom_id in atoms_ids:
                mol.getAtom(atom_id).remove()
            matched_atoms_ids |= atoms_ids
    return mol


def _atoms_with_implicit_hs(mol):
    """Return ``(atom, n_implicit_h)`` pairs for atoms of ``mol`` with at least one implicit H.

    Atoms for which ``countImplicitHydrogens()`` raises are skipped (presumably
    pseudo/query atoms -- not verified here).
    """
    candidates = []
    for atom in mol.iterateAtoms():
        try:
            hs = atom.countImplicitHydrogens()
        except Exception:
            continue
        if hs > 0:
            candidates.append((atom, hs))
    return candidates


def add_explicit_hydrogen(indigo, mol):
    """With probability ``INDIGO_HYGROGEN_PROB``, make all implicit Hs of one random atom explicit.

    Returns ``mol`` (mutated in place).
    """
    candidates = _atoms_with_implicit_hs(mol)
    if candidates and random.random() < INDIGO_HYGROGEN_PROB:
        atom, hs = random.choice(candidates)
        for _ in range(hs):
            h = mol.addAtom('H')
            h.addBond(atom, 1)
    return mol


def add_rgroup(indigo, mol, smiles):
    """With probability ``INDIGO_RGROUP_PROB``, attach a random R-group symbol to an atom with implicit H.

    Skipped when ``smiles`` already contains a ``*`` atom. Returns ``mol``
    (mutated in place).
    """
    candidates = _atoms_with_implicit_hs(mol)
    if candidates and '*' not in smiles and random.random() < INDIGO_RGROUP_PROB:
        atom, _ = random.choice(candidates)
        symbol = random.choice(RGROUP_SYMBOLS)
        r = mol.addAtom(symbol)
        r.addBond(atom, 1)
    return mol


def get_rand_symb():
    """Return a random atom-like token: an element, optionally with extra random letters.

    With probability 0.1 the token is instead a parenthesised random condensed
    formula (recursing through ``gen_rand_condensed``).
    """
    symb = random.choice(ELEMENTS)
    if random.random() < 0.1:
        symb += random.choice(string.ascii_lowercase)
    if random.random() < 0.1:
        symb += random.choice(string.ascii_uppercase)
    if random.random() < 0.1:
        symb = f'({gen_rand_condensed()})'
    return symb


def get_rand_num():
    """Return a random subscript: '' (p=0.72), '2'-'9' (p=0.18) or '12'-'19' (p=0.1)."""
    if random.random() < 0.9:
        if random.random() < 0.8:
            return ''
        return str(random.randint(2, 9))
    return '1' + str(random.randint(2, 9))


def gen_rand_condensed():
    """Return a random condensed-formula string of 1 to 5 (symbol, subscript) pairs.

    After the first pair, each further pair is added with probability 0.2.
    """
    tokens = []
    for i in range(5):
        if i >= 1 and random.random() < 0.8:
            break
        tokens.append(get_rand_symb())
        tokens.append(get_rand_num())
    return ''.join(tokens)


def add_rand_condensed(indigo, mol):
    """With probability ``INDIGO_CONDENSED_PROB``, attach a random condensed-formula atom.

    The new atom is bonded to a random atom that has implicit hydrogens.
    Returns ``mol`` (mutated in place).
    """
    candidates = _atoms_with_implicit_hs(mol)
    if candidates and random.random() < INDIGO_CONDENSED_PROB:
        atom, _ = random.choice(candidates)
        symbol = gen_rand_condensed()
        r = mol.addAtom(symbol)
        r.addBond(atom, 1)
    return mol


def generate_output_smiles(indigo, mol):
    """Round-trip ``mol`` through ``mol.smiles()`` and build the target SMILES string.

    If the SMILES contains ``*`` atoms, it is expected to be followed (after a
    space) by a ``$...$`` label block. This appears to be the CXSMILES
    atom-label extension (confirm). Each ``*`` is replaced, in order, by
    ``[label]`` using the non-empty ``;``-separated labels. Otherwise anything
    after the first space is dropped.

    Returns:
        ``(reloaded_mol, output_smiles)``.

    Raises:
        ValueError / AttributeError / IndexError when the extension block is
        missing or malformed; ``generate_indigo_image`` catches these.
    """
    # TODO: if using mol.canonicalSmiles(), explicit H will be removed
    smiles = mol.smiles()
    mol = indigo.loadMolecule(smiles)
    if '*' not in smiles:
        # special cases with extension
        return mol, smiles.split(' ')[0]
    part_a, part_b = smiles.split(' ', maxsplit=1)
    labels = re.search(r'\$.*\$', part_b).group(0)[1:-1]
    symbols = [t for t in labels.split(';') if len(t) > 0]
    pieces = []
    cnt = 0
    for c in part_a:
        if c == '*':
            pieces.append(f'[{symbols[cnt]}]')
            cnt += 1
        else:
            pieces.append(c)
    return mol, ''.join(pieces)


def add_comment(indigo):
    """With probability ``INDIGO_COMMENT_PROB``, set a random short comment to be rendered with the molecule."""
    if random.random() < INDIGO_COMMENT_PROB:
        indigo.setOption('render-comment', str(random.randint(1, 20)) + random.choice(string.ascii_letters))
        indigo.setOption('render-comment-font-size', random.randint(40, 60))
        indigo.setOption('render-comment-alignment', random.choice([0, 0.5, 1]))
        indigo.setOption('render-comment-position', random.choice(['top', 'bottom']))
        indigo.setOption('render-comment-offset', random.randint(2, 30))


def add_color(indigo, mol):
    """Randomly enable coloring, a random base color and/or highlighting of random atoms.

    Returns ``mol`` (atoms may be highlighted in place).
    """
    if random.random() < INDIGO_COLOR_PROB:
        indigo.setOption('render-coloring', True)
    if random.random() < INDIGO_COLOR_PROB:
        indigo.setOption('render-base-color', random.choice(list(COLORS.values())))
    if random.random() < INDIGO_COLOR_PROB:
        if random.random() < 0.5:
            indigo.setOption('render-highlight-color-enabled', True)
            indigo.setOption('render-highlight-color', random.choice(list(COLORS.values())))
        if random.random() < 0.5:
            indigo.setOption('render-highlight-thickness-enabled', True)
        for atom in mol.iterateAtoms():
            if random.random() < 0.1:
                atom.highlight()
    return mol


def get_graph(mol, image, shuffle_nodes=False, pseudo_coords=False):
    """Call ``mol.layout()`` and extract the molecule as a node/edge graph.

    Args:
        mol: Indigo molecule.
        image: rendered image; only its height/width are used, to scale pseudo
            coordinates to pixels.
        shuffle_nodes: randomly permute the node order.
        pseudo_coords: take ``atom.xyz()``, normalize with ``normalize_nodes`` and
            scale by the image width/height. Otherwise use ``atom.coords()`` as
            reported by Indigo (assumed to already be in image pixels since they
            are fed to the keypoint transform as-is -- confirm).

    Returns:
        dict with ``coords`` (list of ``[x, y]``, or an ndarray when
        ``pseudo_coords``), ``symbols``, ``edges`` (n x n int array) and
        ``num_atoms``. ``edges`` holds the bond order (1/2/3/4:
        single/double/triple/aromatic). For bonds whose ``bondStereo()`` is 5
        or 6, ``[s, t]`` holds that code and ``[t, s]`` holds ``11 - code``
        (i.e. the other of 5/6).
    """
    mol.layout()
    atoms = list(mol.iterateAtoms())
    if shuffle_nodes:
        random.shuffle(atoms)
    coords, symbols = [], []
    index_map = {}
    for i, atom in enumerate(atoms):
        if pseudo_coords:
            x, y, _ = atom.xyz()
        else:
            x, y = atom.coords()
        coords.append([x, y])
        symbols.append(atom.symbol())
        index_map[atom.index()] = i
    if pseudo_coords:
        coords = normalize_nodes(np.array(coords))
        h, w, _ = image.shape
        coords[:, 0] = coords[:, 0] * w
        coords[:, 1] = coords[:, 1] * h
    n = len(symbols)
    edges = np.zeros((n, n), dtype=int)
    for bond in mol.iterateBonds():
        s = index_map[bond.source().index()]
        t = index_map[bond.destination().index()]
        order = bond.bondOrder()
        stereo = bond.bondStereo()
        if stereo in (5, 6):
            edges[s, t] = stereo
            edges[t, s] = 11 - stereo
        else:
            edges[s, t] = order
            edges[t, s] = order
    return {
        'coords': coords,
        'symbols': symbols,
        'edges': edges,
        'num_atoms': len(symbols),
    }


def _set_render_options(indigo, default_option):
    """Apply the base Indigo render options and, unless ``default_option``, randomize the drawing style."""
    indigo.setOption('render-output-format', 'png')
    indigo.setOption('render-background-color', '1,1,1')
    indigo.setOption('render-stereo-style', 'none')
    indigo.setOption('render-label-mode', 'hetero')
    indigo.setOption('render-font-family', 'Arial')
    if default_option:
        return
    # Limit the sum of the following two parameters to be smaller than 4.
    thickness = random.uniform(0.5, 2)
    indigo.setOption('render-relative-thickness', thickness)
    indigo.setOption('render-bond-line-width', random.uniform(1, 4 - thickness))
    if random.random() < 0.5:
        indigo.setOption('render-font-family', random.choice(['Arial', 'Times', 'Courier', 'Helvetica']))
    indigo.setOption('render-label-mode', random.choice(['hetero', 'terminal-hetero']))
    indigo.setOption('render-implicit-hydrogens-visible', random.choice([True, False]))
    if random.random() < 0.1:
        indigo.setOption('render-stereo-style', 'old')
    if random.random() < 0.2:
        indigo.setOption('render-atom-ids-visible', True)


def generate_indigo_image(smiles, mol_augment=True, default_option=False, shuffle_nodes=False, pseudo_coords=False,
                          include_condensed=True, debug=False):
    """Render ``smiles`` with Indigo, optionally after random molecule augmentation.

    Args:
        smiles: input SMILES.
        mol_augment: randomly (de)aromatize and apply the ``add_*`` augmentations;
            the returned SMILES is then rebuilt from the augmented molecule.
        default_option: keep the base render style instead of randomizing it.
        shuffle_nodes, pseudo_coords: forwarded to ``get_graph``.
        include_condensed: allow ``add_rand_condensed``.
        debug: re-raise any rendering/augmentation error instead of returning a
            placeholder.

    Returns:
        ``(img, smiles, graph, success)``. On success, ``img`` is the PNG decoded
        by ``cv2.imdecode(..., 1)`` (3-channel BGR) and ``graph`` comes from
        ``get_graph``. On failure (when not ``debug``), ``img`` is a 10x10 white
        float32 image, ``graph`` is ``{}`` and ``success`` is False; ``smiles``
        is whatever value it held when the error occurred.
    """
    indigo = Indigo()
    renderer = IndigoRenderer(indigo)
    _set_render_options(indigo, default_option)

    try:
        mol = indigo.loadMolecule(smiles)
        if mol_augment:
            if random.random() < INDIGO_DEARMOTIZE_PROB:
                mol.dearomatize()
            else:
                mol.aromatize()
            smiles = mol.canonicalSmiles()
            add_comment(indigo)
            mol = add_explicit_hydrogen(indigo, mol)
            mol = add_rgroup(indigo, mol, smiles)
            if include_condensed:
                mol = add_rand_condensed(indigo, mol)
            mol = add_functional_group(indigo, mol, debug)
            mol = add_color(indigo, mol)
            mol, smiles = generate_output_smiles(indigo, mol)

        buf = renderer.renderToBuffer(mol)
        img = cv2.imdecode(np.asarray(bytearray(buf), dtype=np.uint8), 1)  # decode buffer to image
        graph = get_graph(mol, img, shuffle_nodes, pseudo_coords)
        success = True
    except Exception:
        if debug:
            # Re-raise the original error (type, message and traceback intact).
            raise
        img = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32)
        graph = {}
        success = False
    return img, smiles, graph, success


class TrainDataset(Dataset):
    """Dataset yielding ``(idx, image, ref)`` examples.

    Images are either rendered on the fly from ``df['SMILES']`` with Indigo
    (when ``dynamic_indigo`` and ``split == 'train'``) or read from
    ``df['file_path']``. ``ref`` maps format names from ``args.formats`` (plus
    ``atom_indices``/``coords``/``edges``/``time`` where applicable) to label
    tensors. When rendering fails, the example is ``(idx, None, {})``;
    ``bms_collate`` drops such examples.
    """

    def __init__(self, args, df, tokenizer, split='train', dynamic_indigo=False):
        super().__init__()
        self.df = df
        self.args = args
        self.tokenizer = tokenizer
        if 'file_path' in df.columns:
            self.file_paths = df['file_path'].values
            # Prefix data_path unless the stored paths already include it
            # (judged from the first entry only).
            if len(self.file_paths) > 0 and not self.file_paths[0].startswith(args.data_path):
                self.file_paths = [os.path.join(args.data_path, path) for path in df['file_path']]
        self.smiles = df['SMILES'].values if 'SMILES' in df.columns else None
        self.formats = args.formats
        self.labelled = (split == 'train')
        if self.labelled:
            self.labels = {}
            for format_ in self.formats:
                if format_ in ['atomtok', 'inchi']:
                    field = FORMAT_INFO[format_]['name']
                    if field in df.columns:
                        self.labels[format_] = df[field].values
        self.transform = get_transforms(args.input_size, augment=(self.labelled and args.augment))
        self.dynamic_indigo = (dynamic_indigo and split == 'train')
        if self.labelled and not dynamic_indigo and args.coords_file is not None:
            if args.coords_file == 'aux_file':
                # Coordinates live in df itself and are treated as pseudo coordinates.
                self.coords_df = df
                self.pseudo_coords = True
            else:
                self.coords_df = pd.read_csv(args.coords_file)
                self.pseudo_coords = False
        else:
            self.coords_df = None
            self.pseudo_coords = args.pseudo_coords

    def __len__(self):
        return len(self.df)

    def image_transform(self, image, coords=None, renormalize=False):
        """Convert BGR to RGB and apply ``self.transform`` to the image and optional keypoints.

        Args:
            image: BGR image array.
            coords: optional ``(n, 2)`` pixel ``xy`` keypoints.
            renormalize: re-normalize transformed keypoints with ``normalize_nodes``
                instead of dividing by the output width/height.

        Returns:
            The transformed image alone when ``coords`` is omitted or empty;
            otherwise ``(image, coords)`` with coords clipped to ``[0, 1]``.
        """
        if coords is None:
            coords = []
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        augmented = self.transform(image=image, keypoints=coords)
        image = augmented['image']
        if len(coords) == 0:
            return image
        coords = np.array(augmented['keypoints'])
        if renormalize:
            coords = normalize_nodes(coords, flip_y=False)
        else:
            # After ToTensorV2 the image is (C, H, W).
            _, height, width = image.shape
            coords[:, 0] = coords[:, 0] / width
            coords[:, 1] = coords[:, 1] / height
        return image, np.array(coords).clip(0, 1)

    def __getitem__(self, idx):
        try:
            return self.getitem(idx)
        except Exception as e:
            # Best-effort dump of the error message before propagating it.
            log_path = os.path.join(self.args.save_path, f'error_dataset_{int(time.time())}.log')
            with open(log_path, 'w') as f:
                f.write(str(e))
            raise

    def getitem(self, idx):
        """Build example ``idx`` as ``(idx, image, ref)``; image is None if rendering failed."""
        if self.dynamic_indigo:
            return self._getitem_dynamic(idx)
        return self._getitem_from_file(idx)

    def _getitem_dynamic(self, idx):
        """Render ``self.smiles[idx]`` with Indigo and build labels from the rendered graph."""
        ref = {}
        begin = time.time()
        image, smiles, graph, success = generate_indigo_image(
            self.smiles[idx], mol_augment=self.args.mol_augment, default_option=self.args.default_option,
            shuffle_nodes=self.args.shuffle_nodes, pseudo_coords=self.pseudo_coords,
            include_condensed=self.args.include_condensed)
        end = time.time()
        if idx < 30 and self.args.save_image:
            # Save the first few raw renders for visual inspection.
            path = os.path.join(self.args.save_path, 'images')
            os.makedirs(path, exist_ok=True)
            cv2.imwrite(os.path.join(path, f'{idx}.png'), image)
        if not success:
            return idx, None, {}
        image, coords = self.image_transform(image, graph['coords'], renormalize=self.pseudo_coords)
        graph['coords'] = coords
        ref['time'] = end - begin
        if 'atomtok' in self.formats:
            max_len = FORMAT_INFO['atomtok']['max_len']
            label = self.tokenizer['atomtok'].text_to_sequence(smiles, tokenized=False)
            ref['atomtok'] = torch.LongTensor(label[:max_len])
        if 'edges' in self.formats and 'atomtok_coords' not in self.formats and 'chartok_coords' not in self.formats:
            ref['edges'] = torch.tensor(graph['edges'])
        if 'atomtok_coords' in self.formats:
            self._process_atomtok_coords(idx, ref, smiles, graph['coords'], graph['edges'],
                                         mask_ratio=self.args.mask_ratio)
        if 'chartok_coords' in self.formats:
            self._process_chartok_coords(idx, ref, smiles, graph['coords'], graph['edges'],
                                         mask_ratio=self.args.mask_ratio)
        return idx, image, ref

    def _getitem_from_file(self, idx):
        """Load ``self.file_paths[idx]`` and build labels from the dataframe columns."""
        ref = {}
        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        if image is None:
            image = np.array([[[255., 255., 255.]] * 10] * 10).astype(np.float32)
            print(file_path, 'not found!')
        coords = None
        if self.coords_df is not None:
            h, w, _ = image.shape
            # NOTE(review): eval() executes arbitrary code from the CSV cell; if the
            # column only holds Python literals, ast.literal_eval would be safer.
            coords = np.array(eval(self.coords_df.loc[idx, 'node_coords']))
            if self.pseudo_coords:
                coords = normalize_nodes(coords)
            # Scale to pixels so the keypoint transform can track them (assumes the
            # stored/normalized coords are in [0, 1] -- confirm for coords_file).
            coords[:, 0] = coords[:, 0] * w
            coords[:, 1] = coords[:, 1] * h
            image, coords = self.image_transform(image, coords, renormalize=self.pseudo_coords)
        else:
            image = self.image_transform(image)
        if self.labelled:
            smiles = self.smiles[idx]
            if 'atomtok' in self.formats:
                max_len = FORMAT_INFO['atomtok']['max_len']
                label = self.tokenizer['atomtok'].text_to_sequence(smiles, False)
                ref['atomtok'] = torch.LongTensor(label[:max_len])
            # mask_ratio=1 when no coordinates are known (presumably masks every
            # coordinate token in the tokenizer -- confirm).
            mask_ratio = 0 if coords is not None else 1
            if 'atomtok_coords' in self.formats:
                self._process_atomtok_coords(idx, ref, smiles, coords, mask_ratio=mask_ratio)
            if 'chartok_coords' in self.formats:
                self._process_chartok_coords(idx, ref, smiles, coords, mask_ratio=mask_ratio)
        if self.args.predict_coords and ('atomtok_coords' in self.formats or 'chartok_coords' in self.formats):
            # NOTE(review): this overwrites any coordinate labels built above with
            # coordinate-free (mask_ratio=1) ones.
            smiles = self.smiles[idx]
            if 'atomtok_coords' in self.formats:
                self._process_atomtok_coords(idx, ref, smiles, mask_ratio=1)
            if 'chartok_coords' in self.formats:
                self._process_chartok_coords(idx, ref, smiles, mask_ratio=1)
        return idx, image, ref

    @staticmethod
    def _edges_from_list(edge_list, n):
        """Build an ``n x n`` long adjacency matrix from ``(u, v, t)`` triples.

        Triples touching nodes ``>= n`` are ignored. Codes ``t <= 4`` are stored
        symmetrically; larger codes store ``t`` at ``[u, v]`` and ``11 - t`` at
        ``[v, u]``, matching the stereo encoding in ``get_graph``.
        """
        edges = torch.zeros((n, n), dtype=torch.long)
        for u, v, t in edge_list:
            if u < n and v < n:
                edges[u, v] = t
                edges[v, u] = t if t <= 4 else 11 - t
        return edges

    def _process_coords(self, format_, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
        """Fill ``ref`` with the sequence, atom-index, coordinate and edge labels for ``format_``.

        Writes ``ref[format_]`` (token ids truncated to ``max_len``),
        ``ref['atom_indices']`` (indices ``< max_len``), ``ref['coords']`` (only
        when the tokenizer has ``continuous_coords``; ``-1`` when unknown) and
        ``ref['edges']``. Edges come from ``edges`` if given, else from
        ``self.df['edges']``, else a ``-100``-filled matrix.
        """
        max_len = FORMAT_INFO[format_]['max_len']
        tokenizer = self.tokenizer[format_]
        if not isinstance(smiles, str):
            smiles = ""
        label, indices = tokenizer.smiles_to_sequence(smiles, coords, mask_ratio=mask_ratio)
        ref[format_] = torch.LongTensor(label[:max_len])
        indices = [i for i in indices if i < max_len]
        ref['atom_indices'] = torch.LongTensor(indices)
        n = len(indices)
        if tokenizer.continuous_coords:
            if coords is not None:
                # NOTE(review): unlike edges, coords are not truncated to n atoms.
                ref['coords'] = torch.tensor(coords)
            else:
                ref['coords'] = torch.ones(n, 2) * -1.
        if edges is not None:
            ref['edges'] = torch.tensor(edges)[:n, :n]
        elif 'edges' in self.df.columns:
            # NOTE(review): eval() on dataframe content; see _getitem_from_file.
            ref['edges'] = self._edges_from_list(eval(self.df.loc[idx, 'edges']), n)
        else:
            ref['edges'] = torch.ones(n, n, dtype=torch.long) * (-100)

    def _process_atomtok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
        """Fill ``ref`` with ``atomtok_coords`` labels (see ``_process_coords``)."""
        self._process_coords('atomtok_coords', idx, ref, smiles, coords, edges, mask_ratio=mask_ratio)

    def _process_chartok_coords(self, idx, ref, smiles, coords=None, edges=None, mask_ratio=0):
        """Fill ``ref`` with ``chartok_coords`` labels (see ``_process_coords``)."""
        self._process_coords('chartok_coords', idx, ref, smiles, coords, edges, mask_ratio=mask_ratio)


class AuxTrainDataset(Dataset):
    """Concatenation of a main ``TrainDataset`` and an auxiliary one (never rendered dynamically)."""

    def __init__(self, args, train_df, aux_df, tokenizer):
        super().__init__()
        self.train_dataset = TrainDataset(args, train_df, tokenizer, dynamic_indigo=args.dynamic_indigo)
        self.aux_dataset = TrainDataset(args, aux_df, tokenizer, dynamic_indigo=False)

    def __len__(self):
        return len(self.train_dataset) + len(self.aux_dataset)

    def __getitem__(self, idx):
        n_train = len(self.train_dataset)
        if idx < n_train:
            return self.train_dataset[idx]
        return self.aux_dataset[idx - n_train]


def pad_images(imgs):
    """Zero-pad images (``(..., H, W)`` tensors) on the bottom/right to a common size and stack them."""
    max_h = max((img.shape[-2] for img in imgs), default=0)
    max_w = max((img.shape[-1] for img in imgs), default=0)
    # F.pad takes (left, right, top, bottom) for the last two dims.
    padded = [F.pad(img, [0, max_w - img.shape[-1], 0, max_h - img.shape[-2]], value=0) for img in imgs]
    return torch.stack(padded)


def bms_collate(batch):
    """Collate ``(idx, image, ref)`` examples into ``(ids, images, refs)``.

    Examples whose image is None (failed renders) are dropped. The label keys
    are taken from the first remaining example. Each sequence-type key becomes
    ``[padded_tensor, lengths]``, padded with ``PAD_ID``, with lengths shaped
    ``(B, 1)``. ``coords`` is padded with -1 and ``edges`` with -100. Images are
    padded with ``pad_images``.

    Raises:
        ValueError: if every example in ``batch`` was dropped.
    """
    batch = [ex for ex in batch if ex[1] is not None]
    if not batch:
        raise ValueError('bms_collate: every example in the batch failed to load/render')
    ids = []
    imgs = []
    formats = list(batch[0][2].keys())
    seq_formats = [k for k in formats if
                   k in ['atomtok', 'inchi', 'nodes', 'atomtok_coords', 'chartok_coords', 'atom_indices']]
    refs = {key: [[], []] for key in seq_formats}
    for ex in batch:
        ids.append(ex[0])
        imgs.append(ex[1])
        ref = ex[2]
        for key in seq_formats:
            refs[key][0].append(ref[key])
            refs[key][1].append(torch.LongTensor([len(ref[key])]))
    # Sequences: this padding also works for multi-dimensional sequence tensors.
    for key in seq_formats:
        refs[key][0] = pad_sequence(refs[key][0], batch_first=True, padding_value=PAD_ID)
        refs[key][1] = torch.stack(refs[key][1]).reshape(-1, 1)
    if 'coords' in formats:
        refs['coords'] = pad_sequence([ex[2]['coords'] for ex in batch], batch_first=True, padding_value=-1.)
    if 'edges' in formats:
        edges_list = [ex[2]['edges'] for ex in batch]
        max_len = max(len(edges) for edges in edges_list)
        refs['edges'] = torch.stack(
            [F.pad(edges, (0, max_len - len(edges), 0, max_len - len(edges)), value=-100) for edges in edges_list],
            dim=0)
    return ids, pad_images(imgs), refs