"""Molecule I/O helpers (XYZ / SDF) and 3D matplotlib rendering of molecules and sampling chains."""
import torch
from rdkit import Chem
import os
import imageio
import matplotlib.pyplot as plt
import numpy as np
from openbabel import openbabel as ob
from sklearn.decomposition import PCA

from src import const
from src.molecule_builder import get_bond_order

# Set OpenBabel's global error-log output level.
ob.obErrorLog.SetOutputLevel(1)


def save_xyz_files(path, one_hot, positions, node_mask, names, is_geom, suffix=''):
    """Write one XYZ file per molecule in a batch.

    Each file is named ``{names[b]}_{suffix}.xyz`` inside ``path`` and contains only
    the atoms selected by ``node_mask[b]``. Atom symbols are the argmax of
    ``one_hot[b]`` along dim 1, mapped through ``const.GEOM_IDX2ATOM`` (when
    ``is_geom``) or ``const.IDX2ATOM``.

    Args:
        path: Output directory.
        one_hot: Atom-type scores indexed as ``one_hot[b]`` with types along dim 1
            (presumably shape (B, N, n_types) -- confirm against callers).
        positions: Coordinates indexed as ``positions[b, atom, 0:3]``.
        node_mask: Per-molecule mask; nonzero entries mark atoms to write.
        names: Per-molecule file-name prefixes.
        is_geom: Selects the GEOM atom vocabulary.
        suffix: Inserted into each file name before ``.xyz``.
    """
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
    for batch_i in range(one_hot.size(0)):
        # reshape(-1) (unlike squeeze) keeps the mask 1-D even for a single-atom molecule.
        mask = node_mask[batch_i].reshape(-1)
        atom_idx = torch.where(mask)[0]
        atoms = torch.argmax(one_hot[batch_i], dim=1)
        file_path = os.path.join(path, f'{names[batch_i]}_{suffix}.xyz')
        with open(file_path, "w") as f:
            # The XYZ header must equal the number of atom lines written below.
            f.write("%d\n\n" % len(atom_idx))
            for atom_i in atom_idx:
                atom = idx2atom[atoms[atom_i].item()]
                f.write("%s %.9f %.9f %.9f\n" % (
                    atom,
                    positions[batch_i, atom_i, 0],
                    positions[batch_i, atom_i, 1],
                    positions[batch_i, atom_i, 2],
                ))


def save_sdf(path, one_hot, positions, node_mask, is_geom):
    """Build a single molecule with OpenBabel, write it to ``path`` as SDF and return a SMILES.

    Bonds and bond orders are inferred by OpenBabel from the 3D coordinates
    (``ConnectTheDots`` + ``PerceiveBondOrders``).

    Args:
        path: Output SDF file path.
        one_hot: Atom-type scores for one molecule, types along dim 1.
        positions: Coordinates indexed as ``positions[atom, 0:3]``.
        node_mask: Mask for one molecule; nonzero entries mark atoms to include.
        is_geom: Selects the GEOM atom vocabulary.

    Returns:
        RDKit's SMILES of the written SDF block if RDKit can parse it, otherwise
        OpenBabel's canonical SMILES (``can`` format) with surrounding whitespace stripped.
    """
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    # reshape(-1) (unlike squeeze) keeps the mask 1-D even for a single-atom molecule.
    mask = node_mask.reshape(-1)
    atom_indices = torch.where(mask)[0]

    periodic_table = Chem.GetPeriodicTable()
    atoms = torch.argmax(one_hot, dim=1)
    ob_mol = ob.OBMol()
    for atom_i in atom_indices:
        atom_symbol = idx2atom[atoms[atom_i].item()]
        ob_atom = ob_mol.NewAtom()
        ob_atom.SetAtomicNum(periodic_table.GetAtomicNumber(atom_symbol))
        pos = positions[atom_i]
        ob_atom.SetVector(pos[0].item(), pos[1].item(), pos[2].item())

    # Infer connectivity and bond orders from geometry.
    ob_mol.ConnectTheDots()
    ob_mol.PerceiveBondOrders()

    conversion = ob.OBConversion()
    conversion.SetOutFormat("sdf")
    sdf_string = conversion.WriteString(ob_mol)

    with open(path, "w") as f:
        f.write(sdf_string)

    rdkit_mol = Chem.MolFromMolBlock(sdf_string)
    if rdkit_mol is not None:
        return Chem.MolToSmiles(rdkit_mol)

    # RDKit could not parse the block: fall back to OpenBabel's canonical SMILES.
    conversion.SetOutFormat("can")
    return conversion.WriteString(ob_mol).strip()


def load_xyz_files(path, suffix=''):
    """Return the paths of ``*_{suffix}.xyz`` files in ``path``, sorted by trailing index, descending.

    The index is the last ``_``-separated token of the file name once the
    ``_{suffix}.xyz`` ending is removed; it must parse as an integer.
    """
    ending = f'_{suffix}.xyz'
    files = [fname for fname in os.listdir(path) if fname.endswith(ending)]
    # Slice off only the trailing ending (str.replace would also strip matches mid-name).
    # reverse=True keeps sort stability, like the negated key it replaces.
    files = sorted(files, key=lambda fname: int(fname[:-len(ending)].split('_')[-1]), reverse=True)
    return [os.path.join(path, fname) for fname in files]


def load_molecule_xyz(file, is_geom):
    """Read an XYZ file into tensors.

    Returns:
        ``(positions, one_hot, charges)``:
        ``positions`` is ``(n_atoms, 3)``;
        ``one_hot`` is ``(n_atoms, len(idx2atom))`` with a 1 at each atom's type index;
        ``charges`` is ``(n_atoms, 1)`` and is always zeros (the file carries no charges).

    Raises:
        KeyError: If an atom symbol is not in the selected atom vocabulary.
    """
    atom2idx = const.GEOM_ATOM2IDX if is_geom else const.ATOM2IDX
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    with open(file, encoding='utf8') as f:
        n_atoms = int(f.readline())
        one_hot = torch.zeros(n_atoms, len(idx2atom))
        charges = torch.zeros(n_atoms, 1)
        positions = torch.zeros(n_atoms, 3)

        f.readline()  # skip the XYZ comment line
        atom_lines = f.readlines()
        for i in range(n_atoms):
            # split() tolerates runs of spaces/tabs; [1:4] ignores any extra trailing columns.
            fields = atom_lines[i].split()
            one_hot[i, atom2idx[fields[0]]] = 1
            positions[i, :] = torch.tensor([float(e) for e in fields[1:4]])

    return positions, one_hot, charges


def draw_sphere(ax, x, y, z, size, color, alpha):
    """Draw a sphere of radius ``size`` centred at ``(x, y, z)`` as a surface on a 3D axis."""
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)
    xs = size * np.outer(np.cos(u), np.sin(v))
    ys = size * np.outer(np.sin(u), np.sin(v))
    zs = size * np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(x + xs, y + ys, z + zs, rstride=2, cstride=2, color=color, alpha=alpha)


def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None):
    """Draw the atoms and inferred bonds of one molecule onto a 3D matplotlib axis.

    A bond line is drawn for every atom pair for which ``get_bond_order`` returns
    a positive value.

    Args:
        ax: 3D axis to draw on.
        positions: Coordinates indexed as ``positions[:, 0:3]``.
        atom_type: Integer atom-type indices; used to index ``const.COLORS``,
            ``const.RADII`` and the atom vocabulary.
        alpha: Opacity of bond lines; in sphere mode also the opacity of atoms with
            ``fragment_mask != 1``; scatter markers use ``0.9 * alpha``.
        spheres_3d: Draw atoms as shaded spheres instead of scatter markers.
        hex_bg_color: Colour used for bond lines.
        is_geom: Selects the GEOM atom vocabulary.
        fragment_mask: Optional per-atom mask; in sphere mode atoms with value 1
            are drawn fully opaque. Defaults to all ones.
    """
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]
    n_atoms = len(x)

    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
    colors_dic = np.array(const.COLORS)
    radius_dic = np.array(const.RADII)
    area_dic = 1500 * radius_dic ** 2

    areas = area_dic[atom_type]
    radii = radius_dic[atom_type]
    colors = colors_dic[atom_type]

    if fragment_mask is None:
        fragment_mask = torch.ones(n_atoms)

    line_width = 4  # base bond line width (loop-invariant)
    for i in range(n_atoms):
        p1 = np.array([x[i], y[i], z[i]])
        for j in range(i + 1, n_atoms):
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
            bond_order = get_bond_order(atom1, atom2, dist)
            if bond_order <= 0:
                continue
            # Bond order 4 is drawn 1.5x thicker (presumably aromatic in get_bond_order's encoding).
            linewidth_factor = (1.5 if bond_order == 4 else 1) * 0.5
            ax.plot(
                [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                linewidth=line_width * linewidth_factor * 2,
                c=hex_bg_color,
                alpha=alpha,
            )

    if spheres_3d:
        for xi, yi, zi, radius, color, in_fragment in zip(x, y, z, radii, colors, fragment_mask):
            # Per-atom opacity: do not overwrite `alpha`, or every later atom would become opaque too.
            sphere_alpha = 1.0 if in_fragment == 1 else alpha
            draw_sphere(ax, xi.item(), yi.item(), zi.item(), 0.5 * radius, color, sphere_alpha)
    else:
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)


def plot_data3d(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1., fragment_mask=None):
    """Render one molecule in 3D and optionally save the image to ``save_path``.

    Axis limits are symmetric, derived from the largest absolute coordinate and
    clamped to [3.2, 40]. When ``spheres_3d`` is set, the saved image is re-read and
    brightened. The figure is always closed before returning, so repeated calls
    (e.g. one per chain frame) do not accumulate open figures.

    Args:
        positions: Coordinate tensor (uses ``.abs().max().item()``), indexed as
            ``positions[:, 0:3]``.
        atom_type: Integer atom-type indices.
        is_geom: Selects the GEOM atom vocabulary.
        camera_elev, camera_azim: View angles passed to ``view_init``.
        save_path: Image path; when ``None`` nothing is written.
        spheres_3d: Draw atoms as shaded spheres (and save at higher dpi).
        bg: ``'black'`` for a black background; anything else gives white.
        alpha: Opacity forwarded to ``plot_molecule``.
        fragment_mask: Optional per-atom mask forwarded to ``plot_molecule``.
    """
    black = (0, 0, 0)
    white = (1, 1, 1)
    # Bonds are drawn in the colour opposite to the background.
    hex_bg_color = '#FFFFFF' if bg == 'black' else '#000000'

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('auto')
        ax.view_init(elev=camera_elev, azim=camera_azim)
        ax.set_facecolor(black if bg == 'black' else white)

        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.set_alpha(0)
        ax.set_axis_off()
        # `Axes3D.w_xaxis` was removed in Matplotlib 3.8; `xaxis` refers to the same axis object.
        ax.xaxis.line.set_color("black" if bg == 'black' else "white")

        plot_molecule(
            ax, positions, atom_type, alpha, spheres_3d, hex_bg_color,
            is_geom=is_geom, fragment_mask=fragment_mask,
        )

        max_value = positions.abs().max().item()
        axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.set_zlim(-axis_lim, axis_lim)

        if save_path is not None:
            dpi = 120 if spheres_3d else 50
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)
            if spheres_3d:
                # Brighten the rendered image by 40%, clipped to the 8-bit range.
                img = imageio.imread(save_path)
                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
                imageio.imsave(save_path, img_brighter)
    finally:
        plt.close(fig)


def visualize_chain(
        path,
        spheres_3d=False,
        bg="black",
        alpha=1.0,
        wandb=None,
        mode="chain",
        is_geom=False,
        fragment_mask=None
):
    """Render every XYZ frame in ``path`` to PNG and assemble them into ``output.gif``.

    All frames are projected with a PCA fitted on the last file returned by
    ``load_xyz_files`` (treated as the final molecule), so every frame shares one
    orientation. Each PNG is written next to its ``.xyz`` file.

    Args:
        path: Directory containing the chain's ``*_.xyz`` frames.
        spheres_3d, bg, alpha, is_geom, fragment_mask: Forwarded to ``plot_data3d``.
        wandb: Optional wandb module/run; if given, the GIF is logged under ``mode``.
        mode: Key used when logging to wandb.

    Raises:
        IndexError: If ``path`` contains no matching ``.xyz`` frames.
    """
    files = load_xyz_files(path)

    # Fit PCA to the final molecule to obtain the best orientation for visualization.
    final_positions, _, _ = load_molecule_xyz(files[-1], is_geom=is_geom)
    pca = PCA(n_components=3)
    pca.fit(final_positions)

    save_paths = []
    for file in files:
        positions, one_hot, _ = load_molecule_xyz(file, is_geom=is_geom)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # Orient every frame like the final one.
        positions = torch.tensor(pca.transform(positions))

        fn = os.path.splitext(file)[0] + '.png'
        plot_data3d(
            positions,
            atom_type,
            save_path=fn,
            spheres_3d=spheres_3d,
            alpha=alpha,
            bg=bg,
            camera_elev=90,
            camera_azim=90,
            is_geom=is_geom,
            fragment_mask=fragment_mask,
        )
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    dirname = os.path.dirname(save_paths[0])
    gif_path = os.path.join(dirname, 'output.gif')
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})