import copy
import functools
import multiprocessing
import traceback

import numpy as np
import rdkit
import rdkit.Chem as Chem
from rdkit import RDLogger
from SmilesPE.pretokenizer import atomwise_tokenizer

from .constants import RGROUP_SYMBOLS, ABBREVIATIONS, VALENCES, FORMULA_REGEX

# Silence RDKit's 'rdApp.*' log channels (parse warnings/errors are handled in code).
RDLogger.DisableLog('rdApp.*')

# Placeholder InChI (water) written in place of failed conversions by
# `convert_smiles_to_inchi`; `merge_inchi` looks for it to decide what to replace.
_PLACEHOLDER_INCHI = 'InChI=1S/H2O/h1H2'


def is_valid_mol(s, format_='atomtok'):
    """Return True if RDKit can parse `s` as a molecule.

    Args:
        s: Molecule string: SMILES when `format_` is 'atomtok', InChI when it is
            'inchi' (the 'InChI=1S/' prefix is added if missing).
        format_: 'atomtok' or 'inchi'.

    Raises:
        NotImplementedError: for any other `format_`.
    """
    if format_ == 'atomtok':
        mol = Chem.MolFromSmiles(s)
    elif format_ == 'inchi':
        if not s.startswith('InChI=1S'):
            s = f"InChI=1S/{s}"
        mol = Chem.MolFromInchi(s)
    else:
        # `NotImplemented` is not an exception class; raising it yields a TypeError.
        raise NotImplementedError(f"Unsupported molecule format: {format_!r}")
    return mol is not None


def _convert_smiles_to_inchi(smiles):
    """Convert one SMILES string to InChI; return None if RDKit cannot convert it."""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Chem.MolToInchi(mol)
    except Exception:
        return None


def convert_smiles_to_inchi(smiles_list, num_workers=16):
    """Convert a list of SMILES strings to InChI using a process pool.

    Returns:
        (inchi_list, r_success): entries whose conversion returned None or an
        empty string are replaced by the water placeholder InChI; `r_success` is
        the fraction of conversions that returned non-None (0.0 for empty input).
    """
    with multiprocessing.Pool(num_workers) as p:
        inchi_list = p.map(_convert_smiles_to_inchi, smiles_list, chunksize=128)
    n_success = sum(x is not None for x in inchi_list)
    r_success = n_success / len(inchi_list) if inchi_list else 0.0
    inchi_list = [x if x else _PLACEHOLDER_INCHI for x in inchi_list]
    return inchi_list, r_success


def merge_inchi(inchi1, inchi2):
    """Replace placeholder entries of `inchi1` with the same-index entry of `inchi2`.

    `inchi1` itself is not modified; a deep copy is merged and returned.

    Returns:
        (merged, replaced): the merged list and how many entries were replaced.
    """
    replaced = 0
    inchi1 = copy.deepcopy(inchi1)
    for i, inchi in enumerate(inchi1):
        if inchi == _PLACEHOLDER_INCHI:
            inchi1[i] = inchi2[i]
            replaced += 1
    return inchi1, replaced


def _get_num_atoms(smiles):
    """Return the atom count of the RDKit molecule parsed from `smiles`, or 0 on failure."""
    try:
        return Chem.MolFromSmiles(smiles).GetNumAtoms()
    except Exception:
        # e.g. unparsable SMILES: MolFromSmiles returns None and GetNumAtoms fails.
        return 0


def get_num_atoms(smiles, num_workers=16):
    """Atom count for a single SMILES string, or a list of counts for a sequence.

    A sequence is processed with a process pool of `num_workers` workers.
    """
    if isinstance(smiles, str):
        return _get_num_atoms(smiles)
    with multiprocessing.Pool(num_workers) as p:
        num_atoms = p.map(_get_num_atoms, smiles)
    return num_atoms


def normalize_nodes(nodes, flip_y=True):
    """Min-max normalize node coordinates to [0, 1] on each axis.

    Args:
        nodes: array indexed as `nodes[:, 0]` (x) and `nodes[:, 1]` (y) —
            presumably shape (N, 2).
        flip_y: if True, y is mirrored so the largest input y maps to 0.

    Returns:
        Array of shape (N, 2) with the normalized (x, y). A degenerate axis
        (all values equal) is divided by 1e-6 instead of 0.
    """
    x, y = nodes[:, 0], nodes[:, 1]
    minx, maxx = min(x), max(x)
    miny, maxy = min(y), max(y)
    x = (x - minx) / max(maxx - minx, 1e-6)
    if flip_y:
        y = (maxy - y) / max(maxy - miny, 1e-6)
    else:
        y = (y - miny) / max(maxy - miny, 1e-6)
    return np.stack([x, y], axis=1)


def _verify_chirality(mol, coords, symbols, edges, debug=False):
    """Assign E/Z and tetrahedral stereo to `mol` from 2D coordinates and wedge/dash edges.

    Args:
        mol: editable RDKit molecule (RWMol: RemoveBond/AddBond/GetMol are used).
            It is modified in place.
        coords: one (x, y) pair per atom; y is flipped as `1 - y` — presumably
            normalized image coordinates (TODO confirm).
        symbols: unused; kept for interface compatibility.
        edges: n x n matrix of edge labels; 5 means a wedge and 6 a dash
            starting at the row atom.
        debug: re-raise any exception instead of swallowing it.

    Returns:
        A read-only Mol on success. On failure (debug=False) the input `mol` is
        returned, possibly partially modified.
    """
    try:
        n = mol.GetNumAtoms()
        # Make a temp mol to find chiral centers
        mol_tmp = mol.GetMol()
        Chem.SanitizeMol(mol_tmp)

        chiral_centers = Chem.FindMolChiralCenters(
            mol_tmp, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False)
        chiral_center_ids = [idx for idx, _ in chiral_centers]  # List[Tuple[int, any]] -> List[int]

        # correction to clear pre-condition violation (for some corner cases)
        for bond in mol.GetBonds():
            if bond.GetBondType() == Chem.BondType.SINGLE:
                bond.SetBondDir(Chem.BondDir.NONE)

        # Create conformer from 2D coordinate
        conf = Chem.Conformer(n)
        conf.Set3D(True)
        for i, (x, y) in enumerate(coords):
            conf.SetAtomPosition(i, (x, 1 - y, 0))
        mol.AddConformer(conf)
        Chem.SanitizeMol(mol)
        Chem.AssignStereochemistryFrom3D(mol)
        # NOTE: seems that only AssignStereochemistryFrom3D can handle double bond E/Z
        # So we do this first, remove the conformer and add back the 2D conformer for chiral correction

        mol.RemoveAllConformers()
        conf = Chem.Conformer(n)
        conf.Set3D(False)
        for i, (x, y) in enumerate(coords):
            conf.SetAtomPosition(i, (x, 1 - y, 0))
        mol.AddConformer(conf)

        # Magic, inferring chirality from coordinates and BondDir. DO NOT CHANGE.
        Chem.SanitizeMol(mol)
        Chem.AssignChiralTypesFromBondDirs(mol)
        Chem.AssignStereochemistry(mol, force=True)

        # Second loop to reset any wedge/dash bond to be starting from the chiral center
        for i in chiral_center_ids:
            for j in range(n):
                if edges[i][j] == 5:
                    # assert edges[j][i] == 6
                    mol.RemoveBond(i, j)
                    mol.AddBond(i, j, Chem.BondType.SINGLE)
                    mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINWEDGE)
                elif edges[i][j] == 6:
                    # assert edges[j][i] == 5
                    mol.RemoveBond(i, j)
                    mol.AddBond(i, j, Chem.BondType.SINGLE)
                    mol.GetBondBetweenAtoms(i, j).SetBondDir(Chem.BondDir.BEGINDASH)
            Chem.AssignChiralTypesFromBondDirs(mol)
            Chem.AssignStereochemistry(mol, force=True)

        # reset chiral tags for non-carbon atom
        for atom in mol.GetAtoms():
            if atom.GetSymbol() != "C":
                atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)

        mol = mol.GetMol()

    except Exception:
        if debug:
            raise
    return mol


def _parse_tokens(tokens: list):
    """
    Parse tokens of condensed formula into list of pairs `(elt, num)`
    where `num` is the multiplicity of the atom (or nested condensed formula) `elt`.
    A parenthesized group becomes a nested list; nested groups are matched by depth.
    Used by `_parse_formula`, which does the same thing but takes a formula in string form as input
    """
    elements = []
    i = 0
    while i < len(tokens):
        if tokens[i] == '(':
            # Find the ')' that closes this '(' (depth-aware, so nested groups work).
            # If unbalanced, j stops at len(tokens) and the rest becomes the group.
            depth = 0
            j = i
            while j < len(tokens):
                if tokens[j] == '(':
                    depth += 1
                elif tokens[j] == ')':
                    depth -= 1
                    if depth == 0:
                        break
                j += 1
            elt = _parse_tokens(tokens[i + 1:j])
        else:
            elt = tokens[i]
            j = i
        j += 1
        # optional numeric subscript after the atom / group
        if j < len(tokens) and tokens[j].isnumeric():
            num = int(tokens[j])
            j += 1
        else:
            num = 1
        elements.append((elt, num))
        i = j
    return elements


def _parse_formula(formula: str):
    """
    Parse condensed formula into list of pairs `(elt, num)`
    where `num` is the subscript to the atom (or nested condensed formula) `elt`
    Example: "C2H4O" -> [('C', 2), ('H', 4), ('O', 1)]
    """
    tokens = FORMULA_REGEX.findall(formula)
    return _parse_tokens(tokens)


def _expand_carbon(elements: list):
    """
    Given list of pairs `(elt, num)`, output single list of all atoms in order,
    expanding carbon sequences (C_a X_b where a > 1 and X is whatever element follows)
    by distributing the X atoms evenly over the carbons (leftover X atoms go last).
    Nested formulas (lists) are expanded recursively.
    Example: [('C', 2), ('H', 4), ('O', 1)] -> ['C', 'H', 'H', 'C', 'H', 'H', 'O']
    """
    expanded = []
    i = 0
    while i < len(elements):
        elt, num = elements[i]
        # expand carbon sequence
        if elt == 'C' and num > 1 and i + 1 < len(elements):
            next_elt, next_num = elements[i + 1]
            quotient, remainder = next_num // num, next_num % num
            for _ in range(num):
                expanded.append('C')
                for _ in range(quotient):
                    expanded.append(next_elt)
            for _ in range(remainder):
                expanded.append(next_elt)
            i += 2
        # recurse if `elt` itself is a list (nested formula)
        elif isinstance(elt, list):
            new_elt = _expand_carbon(elt)
            for _ in range(num):
                expanded.append(new_elt)
            i += 1
        # simplest case: simply append `elt` `num` times
        else:
            for _ in range(num):
                expanded.append(elt)
            i += 1
    return expanded


def _expand_abbreviation(abbrev):
    """
    Expand abbreviation into its SMILES; also converts [Rn] to [n*]
    Used in `_condensed_formula_list_to_smiles` when encountering abbrev. in condensed formula
    """
    if abbrev in ABBREVIATIONS:
        return ABBREVIATIONS[abbrev].smiles
    if abbrev in RGROUP_SYMBOLS or (abbrev[0] == 'R' and abbrev[1:].isdigit()):
        if abbrev[1:].isdigit():
            return f'[{abbrev[1:]}*]'
        return '*'
    return f'[{abbrev}]'


def _get_bond_symb(bond_num):
    """
    Get SMILES symbol for a bond given bond order
    (0 -> '.', 1 -> '', 2 -> '=', 3 -> '#', anything else -> '')
    Used in `_condensed_formula_list_to_smiles` while writing the SMILES string
    """
    if bond_num == 0:
        return '.'
    if bond_num == 1:
        return ''
    if bond_num == 2:
        return '='
    if bond_num == 3:
        return '#'
    return ''


def _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond=None, direction=None):
    """
    Converts condensed formula (in the form of a list of symbols) to smiles
    Input:
    `formula_list`: e.g. ['C', 'H', 'H', 'N', ['C', 'H', 'H', 'H'], ['C', 'H', 'H', 'H']] for CH2N(CH3)2
    `start_bond`: # bonds attached to beginning of formula
    `end_bond`: # bonds attached to end of formula (deduce automatically if None)
    `direction` (1, -1, or None): direction in which to process the list (1: left to right; -1: right to left; None: deduce automatically)
    Returns:
    `smiles`: smiles corresponding to input condensed formula
    `bonds_left`: bonds remaining at the end of the formula (for connecting back to main molecule); should equal `end_bond` if specified
    `num_trials`: number of search steps taken
    `success` (bool): whether conversion was successful
    Raises:
    ValueError if `direction` is not 1, -1 or None
    """
    # `direction` not specified: try left to right; if fails, try right to left
    if direction is None:
        num_trials = 1
        for dir_choice in [1, -1]:
            smiles, bonds_left, trials, success = _condensed_formula_list_to_smiles(
                formula_list, start_bond, end_bond, dir_choice)
            num_trials += trials
            if success:
                return smiles, bonds_left, num_trials, success
        return None, None, num_trials, False
    if direction not in (1, -1):
        raise ValueError(f"direction must be 1, -1 or None, got {direction!r}")

    def dfs(smiles, bonds_left, cur_idx, add_idx):
        """
        `smiles`: SMILES string so far
        `bonds_left`: bonds remaining on current atom for subsequent atoms to be attached to
        `cur_idx`: index (in list `formula_list`) of current atom (i.e. atom to which subsequent atoms are being attached);
            -1 (left to right) or len(formula_list) (right to left) means no atom has been placed yet
        `add_idx`: index (in list `formula_list`) of atom to be attached to current atom

        Note: "atom" could refer to nested condensed formula (e.g. CH3 in CH2N(CH3)2)
        """
        num_trials = 1
        # end of formula: return result
        if (direction == 1 and add_idx == len(formula_list)) or (direction == -1 and add_idx == -1):
            if end_bond is not None and end_bond != bonds_left:
                return smiles, bonds_left, num_trials, False
            return smiles, bonds_left, num_trials, True

        # no more bonds but there are atoms remaining: conversion failed
        if bonds_left <= 0:
            return smiles, bonds_left, num_trials, False
        to_add = formula_list[add_idx]  # atom to be added to current atom

        if isinstance(to_add, list):  # "atom" added is a list (i.e. nested condensed formula): assume valence of 1
            if bonds_left > 1:
                # "atom" added does not use up remaining bonds of current atom
                # get smiles of "atom" (which is itself a condensed formula)
                add_str, val, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                if val > 0:
                    add_str = _get_bond_symb(val + 1) + add_str
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False
                # put smiles of "atom" in parentheses and append to smiles; go to next atom to add to current atom
                result = dfs(smiles + f'({add_str})', bonds_left - 1, cur_idx, add_idx + direction)
            else:
                # "atom" added uses up remaining bonds of current atom
                # get smiles of "atom" and bonds left on it
                add_str, bonds_left, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False
                # append smiles of "atom" (without parentheses) to smiles; it becomes new current atom
                result = dfs(smiles + add_str, bonds_left, add_idx, add_idx + direction)
            smiles, bonds_left, trials, success = result
            num_trials += trials
            return smiles, bonds_left, num_trials, success

        # Only emit a bond symbol when there is a real current atom to bond to; the
        # start sentinels (-1 and len(formula_list)) mean nothing precedes this atom,
        # and a SMILES string cannot start with a bond symbol.
        has_current_atom = 0 <= cur_idx < len(formula_list)

        # atom added is a single symbol (as opposed to nested condensed formula)
        for val in VALENCES.get(to_add, [1]):  # try all possible valences of atom added
            add_str = _expand_abbreviation(to_add)  # expand to smiles if symbol is abbreviation
            if bonds_left > val:
                # atom added does not use up remaining bonds of current atom; go to next atom to add to current atom
                if has_current_atom:
                    add_str = _get_bond_symb(val) + add_str
                result = dfs(smiles + f'({add_str})', bonds_left - val, cur_idx, add_idx + direction)
            else:
                # atom added uses up remaining bonds of current atom; it becomes new current atom
                if has_current_atom:
                    add_str = _get_bond_symb(bonds_left) + add_str
                result = dfs(smiles + add_str, val - bonds_left, add_idx, add_idx + direction)
            trials, success = result[2:]
            num_trials += trials
            if success:
                return result[0], result[1], num_trials, success
            if num_trials > 10000:  # cap the search effort
                break
        return smiles, bonds_left, num_trials, False

    cur_idx = -1 if direction == 1 else len(formula_list)
    add_idx = 0 if direction == 1 else len(formula_list) - 1
    return dfs('', start_bond, cur_idx, add_idx)


def get_smiles_from_symbol(symbol, mol, atom, bonds):
    """
    Convert symbol (abbrev. or condensed formula) to smiles.

    Args:
        symbol: abbreviation or condensed-formula text of the placeholder atom.
        mol, atom: unused; kept for interface compatibility.
        bonds: bonds of the placeholder atom; the integer-truncated sum of their
            bond orders is the number of bonds the formula must provide.

    Returns:
        The abbreviation's SMILES, a SMILES built from the condensed formula (if it
        parses), `symbol` itself if it is already valid SMILES, else None.
        Symbols longer than 20 characters that are not abbreviations give None.
    """
    if symbol in ABBREVIATIONS:
        return ABBREVIATIONS[symbol].smiles
    if len(symbol) > 20:
        return None

    total_bonds = int(sum(bond.GetBondTypeAsDouble() for bond in bonds))
    formula_list = _expand_carbon(_parse_formula(symbol))
    smiles, _bonds_left, _num_trials, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None)
    if success and Chem.MolFromSmiles(smiles) is not None:
        return smiles

    # Fall back to the raw symbol when it is itself a parsable SMILES string.
    if Chem.MolFromSmiles(symbol) is not None:
        return symbol
    return None


def _replace_functional_group(smiles):
    """Make a predicted SMILES parsable by RDKit.

    - '<unk>' tokens become carbon.
    - R-group tokens become dummy atoms: '[Rn]' -> '[n*]', other R-groups -> '*'.
    - Bracket tokens that are abbreviations or that RDKit cannot parse become
      '[{isotope}*]' placeholders, with isotopes starting at 50 and skipping
      values already present.

    Returns:
        (smiles, mappings) where `mappings` maps placeholder isotope -> original
        token text (without brackets).
    """
    # NOTE(review): '<unk>' is assumed to be the tokenizer's unknown-token marker — confirm.
    # (Replacing the empty string would insert 'C' between every character.)
    smiles = smiles.replace('<unk>', 'C')
    for r in RGROUP_SYMBOLS:
        symbol = f'[{r}]'
        if symbol in smiles:
            if r[0] == 'R' and r[1:].isdigit():
                smiles = smiles.replace(symbol, f'[{int(r[1:])}*]')
            else:
                smiles = smiles.replace(symbol, '*')
    # For unknown tokens (i.e. rdkit cannot parse), replace them with [{isotope}*], where isotope is an identifier.
    tokens = atomwise_tokenizer(smiles)
    new_tokens = []
    mappings = {}  # isotope : symbol
    isotope = 50
    for token in tokens:
        if token[0] == '[':
            if token[1:-1] in ABBREVIATIONS or Chem.AtomFromSmiles(token) is None:
                while f'[{isotope}*]' in smiles or f'[{isotope}*]' in new_tokens:
                    isotope += 1
                placeholder = f'[{isotope}*]'
                mappings[isotope] = token[1:-1]
                new_tokens.append(placeholder)
                continue
        new_tokens.append(token)
    smiles = ''.join(new_tokens)

    return smiles, mappings


def convert_smiles_to_mol(smiles):
    """Parse `smiles` with RDKit; return None for None/empty input or on failure."""
    if smiles is None or smiles == '':
        return None
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        return None
    return mol


BOND_TYPES = {1: Chem.rdchem.BondType.SINGLE, 2: Chem.rdchem.BondType.DOUBLE, 3: Chem.rdchem.BondType.TRIPLE}


def _expand_functional_group(mol, mappings, debug=False):
    """Replace abbreviation / condensed-formula placeholder atoms ('*') by their structure.

    A '*' atom is expanded when it has an atom alias, or when its isotope is a
    key of `mappings` (see `_replace_functional_group`). R-groups are left
    as-is. Atoms whose SMILES cannot be resolved keep '*' but lose their isotope.

    Args:
        mol: RDKit molecule to expand.
        mappings: placeholder isotope -> symbol text.
        debug: unused; kept for interface compatibility.

    Returns:
        (smiles, mol) of the expanded molecule (or of `mol` itself if nothing
        needed expanding).
    """
    def _need_expand(mol, mappings):
        return any(len(Chem.GetAtomAlias(atom)) > 0 for atom in mol.GetAtoms()) or len(mappings) > 0

    if not _need_expand(mol, mappings):
        return Chem.MolToSmiles(mol), mol

    mol_w = Chem.RWMol(mol)
    num_atoms = mol_w.GetNumAtoms()
    # Radical-electron counts are used below as scratch space to carry bond orders.
    for atom in mol_w.GetAtoms():
        atom.SetNumRadicalElectrons(0)

    atoms_to_remove = []
    for i in range(num_atoms):
        atom = mol_w.GetAtomWithIdx(i)
        if atom.GetSymbol() != '*':
            continue
        symbol = Chem.GetAtomAlias(atom)
        isotope = atom.GetIsotope()
        if isotope > 0 and isotope in mappings:
            symbol = mappings[isotope]
        if not (isinstance(symbol, str) and len(symbol) > 0):
            continue
        # rgroups do not need to be expanded
        if symbol in RGROUP_SYMBOLS:
            continue

        bonds = atom.GetBonds()
        sub_smiles = get_smiles_from_symbol(symbol, mol_w, atom, bonds)

        # create mol object for abbreviation/condensed formula from its SMILES
        mol_r = convert_smiles_to_mol(sub_smiles)

        if mol_r is None:
            atom.SetIsotope(0)
            continue

        # Read neighbour indices and bond orders BEFORE removing the bonds: the Bond
        # objects in `bonds` refer into mol_w and may not be valid after RemoveBond.
        adjacent_indices = [bond.GetOtherAtomIdx(i) for bond in bonds]
        bond_orders = [int(bond.GetBondTypeAsDouble()) for bond in bonds]

        # remove bonds connected to abbreviation/condensed formula
        for adjacent_idx in adjacent_indices:
            mol_w.RemoveBond(i, adjacent_idx)

        # Store each removed bond's order on the neighbouring main-body atom.
        for adjacent_idx, bond_order in zip(adjacent_indices, bond_orders):
            mol_w.GetAtomWithIdx(adjacent_idx).SetNumRadicalElectrons(bond_order)

        # get indices of atoms of main body that connect to substituent
        bonding_atoms_w = adjacent_indices
        # assume indices are concatenated after combining mol_w and mol_r
        bonding_atoms_r = [mol_w.GetNumAtoms()]
        for atm in mol_r.GetAtoms():
            if atm.GetNumRadicalElectrons() and atm.GetIdx() > 0:
                bonding_atoms_r.append(mol_w.GetNumAtoms() + atm.GetIdx())

        # combine main body and substituent into a single molecule object
        combo = Chem.CombineMols(mol_w, mol_r)

        # connect substituent (via its first atom) to main body with bonds
        mol_w = Chem.RWMol(combo)
        for atm in bonding_atoms_w:
            bond_order = mol_w.GetAtomWithIdx(atm).GetNumRadicalElectrons()
            mol_w.AddBond(atm, bonding_atoms_r[0], order=BOND_TYPES[bond_order])

        # reset radical electrons
        for atm in bonding_atoms_w:
            mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
        for atm in bonding_atoms_r:
            mol_w.GetAtomWithIdx(atm).SetNumRadicalElectrons(0)
        atoms_to_remove.append(i)

    # Remove atoms at the end, otherwise the indices would shift;
    # remove larger indices first for the same reason.
    for i in sorted(atoms_to_remove, reverse=True):
        mol_w.RemoveAtom(i)
    smiles = Chem.MolToSmiles(mol_w)
    mol = mol_w.GetMol()
    return smiles, mol


def _convert_graph_to_smiles(coords, symbols, edges, image=None, debug=False):
    """Build an RDKit molecule from a predicted atom/bond graph and convert it to SMILES.

    Args:
        coords: one (x, y) pair per atom (used for stereo assignment).
        symbols: one atom token per atom, optionally wrapped in '[...]'.
            R-groups, abbreviations and tokens RDKit cannot parse become '*'
            atoms carrying the symbol as alias / 'molFileAlias'.
        edges: n x n matrix of edge labels for i < j: 1 single, 2 double,
            3 triple, 4 aromatic, 5 wedge, 6 dash.
        image: optional array whose `.shape` is (height, width, channels); if
            given, x is scaled by the aspect ratio (both axes by 10) before
            stereo assignment.
        debug: print tracebacks and also return the molecule.

    Returns:
        (smiles, molblock, success), or (smiles, molblock, mol, success) if debug.
    """
    mol = Chem.RWMol()
    n = len(symbols)
    ids = []
    for i, raw_symbol in enumerate(symbols):
        symbol = raw_symbol[1:-1] if raw_symbol[0] == '[' else raw_symbol
        if symbol in RGROUP_SYMBOLS:
            atom = Chem.Atom("*")
            if symbol[0] == 'R' and symbol[1:].isdigit():
                atom.SetIsotope(int(symbol[1:]))
            Chem.SetAtomAlias(atom, symbol)
        elif symbol in ABBREVIATIONS:
            atom = Chem.Atom("*")
            Chem.SetAtomAlias(atom, symbol)
        else:
            try:  # try to get SMILES of atom
                atom = Chem.AtomFromSmiles(raw_symbol)
                atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
            except Exception:  # otherwise (e.g. None returned), abbreviation or condensed formula
                atom = Chem.Atom("*")
                Chem.SetAtomAlias(atom, symbol)

        if atom.GetSymbol() == '*':
            atom.SetProp('molFileAlias', symbol)

        idx = mol.AddAtom(atom)
        assert idx == i
        ids.append(idx)

    for i in range(n):
        for j in range(i + 1, n):
            if edges[i][j] == 1:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
            elif edges[i][j] == 2:
                mol.AddBond(ids[i], ids[j], Chem.BondType.DOUBLE)
            elif edges[i][j] == 3:
                mol.AddBond(ids[i], ids[j], Chem.BondType.TRIPLE)
            elif edges[i][j] == 4:
                mol.AddBond(ids[i], ids[j], Chem.BondType.AROMATIC)
            elif edges[i][j] == 5:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
                mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINWEDGE)
            elif edges[i][j] == 6:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
                mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINDASH)

    pred_smiles = ''

    try:
        # TODO: move to an util function
        if image is not None:
            height, width, _ = image.shape
            ratio = width / height
            coords = [[x * ratio * 10, y * 10] for x, y in coords]
        mol = _verify_chirality(mol, coords, symbols, edges, debug)
        # molblock is obtained before expanding func groups, otherwise the expanded group won't have coordinates.
        # TODO: make sure molblock has the abbreviation information
        pred_molblock = Chem.MolToMolBlock(mol)
        pred_smiles, mol = _expand_functional_group(mol, {}, debug)
        success = True
    except Exception:
        if debug:
            print(traceback.format_exc())
        pred_molblock = ''
        success = False

    if debug:
        return pred_smiles, pred_molblock, mol, success
    return pred_smiles, pred_molblock, success


def _summarize_results(results):
    """Unzip per-item (smiles, molblock, success) tuples and compute the success rate.

    Returns ((), (), 0.0) for an empty result list.
    """
    if not results:
        return (), (), 0.0
    smiles_list, molblock_list, success = zip(*results)
    r_success = np.mean(success)
    return smiles_list, molblock_list, r_success


def convert_graph_to_smiles(coords, symbols, edges, images=None, num_workers=16):
    """Run `_convert_graph_to_smiles` over a batch with a process pool.

    Returns:
        (smiles_list, molblock_list, r_success) where `r_success` is the mean of
        the per-item success flags.
    """
    with multiprocessing.Pool(num_workers) as p:
        if images is None:
            results = p.starmap(_convert_graph_to_smiles, zip(coords, symbols, edges), chunksize=128)
        else:
            results = p.starmap(_convert_graph_to_smiles, zip(coords, symbols, edges, images), chunksize=128)
    return _summarize_results(results)


def _postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, debug=False):
    """Clean up one predicted SMILES: placeholders, optional stereo, group expansion.

    When coords, symbols and edges are all given, '@', '/' and '\\' are stripped
    and stereo is recomputed by `_verify_chirality`.

    Returns:
        (smiles, molblock, success), or (smiles, molblock, mol, success) if
        debug. On failure the input `smiles` is returned unchanged with
        success=False; non-str or empty input yields ('', '', False).
    """
    if not isinstance(smiles, str) or smiles == '':
        # Keep the same arity as the normal return so batch callers can unzip results.
        if debug:
            return '', '', None, False
        return '', '', False
    mol = None
    pred_molblock = ''
    try:
        pred_smiles, mappings = _replace_functional_group(smiles)
        if coords is not None and symbols is not None and edges is not None:
            pred_smiles = pred_smiles.replace('@', '').replace('/', '').replace('\\', '')
            mol = Chem.RWMol(Chem.MolFromSmiles(pred_smiles, sanitize=False))
            mol = _verify_chirality(mol, coords, symbols, edges, debug)
        else:
            mol = Chem.MolFromSmiles(pred_smiles, sanitize=False)
        if molblock:
            pred_molblock = Chem.MolToMolBlock(mol)
        pred_smiles, mol = _expand_functional_group(mol, mappings)
        success = True
    except Exception:
        if debug:
            print(traceback.format_exc())
        pred_smiles = smiles
        pred_molblock = ''
        success = False
    if debug:
        return pred_smiles, pred_molblock, mol, success
    return pred_smiles, pred_molblock, success


def postprocess_smiles(smiles, coords=None, symbols=None, edges=None, molblock=False, num_workers=16):
    """Run `_postprocess_smiles` over a batch with a process pool.

    Returns:
        (smiles_list, molblock_list, r_success) where `r_success` is the mean of
        the per-item success flags.
    """
    # Bind `molblock` so the flag actually reaches the worker function.
    worker = functools.partial(_postprocess_smiles, molblock=molblock)
    with multiprocessing.Pool(num_workers) as p:
        if coords is not None and symbols is not None and edges is not None:
            results = p.starmap(worker, zip(smiles, coords, symbols, edges), chunksize=128)
        else:
            results = p.map(worker, smiles, chunksize=128)
    return _summarize_results(results)


def _keep_main_molecule(smiles, debug=False):
    """Return the SMILES of the largest fragment (by atom count) of `smiles`.

    The input is returned unchanged if it has a single fragment or cannot be parsed.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        frags = Chem.GetMolFrags(mol, asMols=True)
        if len(frags) > 1:
            num_atoms = [m.GetNumAtoms() for m in frags]
            main_mol = frags[np.argmax(num_atoms)]
            smiles = Chem.MolToSmiles(main_mol)
    except Exception:
        if debug:
            print(traceback.format_exc())
    return smiles


def keep_main_molecule(smiles, num_workers=16):
    """Apply `_keep_main_molecule` to a list of SMILES with a process pool."""
    with multiprocessing.Pool(num_workers) as p:
        results = p.map(_keep_main_molecule, smiles, chunksize=128)
    return results