GenFBDD / src /visualizer.py
libokj's picture
Initial commit GenFBDD
9439b9b
import torch
from rdkit import Chem
import os
import imageio
import matplotlib.pyplot as plt
import numpy as np
from openbabel import openbabel as ob
from sklearn.decomposition import PCA
from src import const
from src.molecule_builder import get_bond_order
# Module-level side effect: sets OpenBabel's global error-log output level to 1
# as soon as this module is imported.
# NOTE(review): presumably this restricts OpenBabel output to errors and
# warnings -- confirm against OpenBabel's obMessageLevel values.
ob.obErrorLog.SetOutputLevel(1)
def save_xyz_files(path, one_hot, positions, node_mask, names, is_geom, suffix=''):
    """Write one XYZ file per batch element into ``path``.

    Each file is named ``{names[i]}_{suffix}.xyz`` and holds an atom-count
    line, an empty comment line, and one ``symbol x y z`` line (9 decimals)
    for every atom selected by ``node_mask``.

    Args:
        path: Directory the files are written into.
        one_hot: Per-atom type scores; ``one_hot[batch]`` is arg-maxed over
            dim 1 to obtain each atom's type index.
        positions: Coordinates, indexed as ``positions[batch, atom, 0..2]``.
        node_mask: Per-batch mask; after ``squeeze()`` its non-zero entries
            mark the atoms that are written.
        names: File-name stems, one per batch element.
        is_geom: Selects ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``
            as the index-to-symbol table.
        suffix: Text placed between the trailing underscore and ``.xyz``.
    """
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    for batch_i in range(one_hot.size(0)):
        mask = node_mask[batch_i].squeeze()
        n_atoms = mask.sum()
        atom_idx = torch.where(mask)[0]
        atoms = torch.argmax(one_hot[batch_i], dim=1)

        xyz_path = os.path.join(path, f'{names[batch_i]}_{suffix}.xyz')
        # The context manager guarantees the handle is closed even when an
        # atom lookup or a write raises part-way through the molecule.
        with open(xyz_path, "w") as f:
            f.write("%d\n\n" % n_atoms)
            for atom_i in atom_idx:
                symbol = idx2atom[atoms[atom_i].item()]
                f.write("%s %.9f %.9f %.9f\n" % (
                    symbol,
                    positions[batch_i, atom_i, 0],
                    positions[batch_i, atom_i, 1],
                    positions[batch_i, atom_i, 2],
                ))
# def coords_to_mol(one_hot, positions, node_mask, is_geom):
# # Select atom mapping based on whether geometry or generic atoms are used
# idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM
#
# # Identify valid atoms based on the mask
# mask = node_mask.squeeze()
# atom_indices = torch.where(mask)[0]
#
# # Create an editable molecule object
# mol = Chem.RWMol()
#
# # Add atoms to the molecule
# atoms = torch.argmax(one_hot, dim=1)
# for idx, atom_i in enumerate(atom_indices):
# atom = atoms[atom_i].item()
# atom_symbol = idx2atom[atom]
# mol.AddAtom(Chem.Atom(atom_symbol))
#
# # Add a conformer to the molecule to set atom positions
# conformer = Chem.Conformer(mol.GetNumAtoms())
# mol.AddConformer(conformer)
#
# # Set atom positions
# for idx, atom_i in enumerate(atom_indices):
# mol.GetConformer().SetAtomPosition(idx, (
# positions[atom_i, 0].item(),
# positions[atom_i, 1].item(),
# positions[atom_i, 2].item()
# ))
#
# # Generate and return SMILES
# smiles = Chem.MolToSmiles(mol)
# return smiles, mol
def save_sdf(path, one_hot, positions, node_mask, is_geom):
    """Build a molecule from atom types and coordinates, save it as SDF, return SMILES.

    Bonds are not supplied by the caller: they are inferred by OpenBabel from
    the 3D coordinates (``ConnectTheDots`` followed by ``PerceiveBondOrders``).

    Args:
        path: Destination file path of the SDF file.
        one_hot: Per-atom type scores; arg-maxed over dim 1 to obtain each
            atom's type index.
        positions: Coordinates, indexed as ``positions[atom][0..2]``.
        node_mask: Mask whose non-zero entries (after ``squeeze()``) select
            the atoms to include.
        is_geom: Selects ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``
            as the index-to-symbol table.

    Returns:
        The SMILES string produced by RDKit from the written SDF text, or, if
        RDKit cannot parse it, OpenBabel's ``can`` output with surrounding
        whitespace stripped.
    """
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    # Only atoms flagged by the mask take part in the molecule
    mask = node_mask.squeeze()
    atom_indices = torch.where(mask)[0]
    atoms = torch.argmax(one_hot, dim=1)

    # Loop-invariant: fetch the periodic table once instead of once per atom
    periodic_table = Chem.GetPeriodicTable()

    obMol = ob.OBMol()
    for atom_i in atom_indices:
        atom_symbol = idx2atom[atoms[atom_i].item()]
        obAtom = obMol.NewAtom()
        obAtom.SetAtomicNum(periodic_table.GetAtomicNumber(atom_symbol))
        pos = positions[atom_i]
        obAtom.SetVector(pos[0].item(), pos[1].item(), pos[2].item())

    # Infer connectivity and bond orders from the geometry
    obMol.ConnectTheDots()
    obMol.PerceiveBondOrders()

    obConversion = ob.OBConversion()
    obConversion.SetOutFormat("sdf")
    sdf_string = obConversion.WriteString(obMol)

    with open(path, "w") as f:
        f.write(sdf_string)

    # Prefer RDKit's SMILES; fall back to OpenBabel when RDKit rejects the block
    rdkit_mol = Chem.MolFromMolBlock(sdf_string)
    if rdkit_mol is not None:
        return Chem.MolToSmiles(rdkit_mol)

    obConversion.SetOutFormat("can")
    return obConversion.WriteString(obMol).strip()
def load_xyz_files(path, suffix=''):
    """List the ``*_{suffix}.xyz`` files in ``path``, highest trailing index first.

    The index of a file is the integer found after the last underscore of its
    name once the ``_{suffix}.xyz`` ending has been removed.

    Args:
        path: Directory to scan (not recursive).
        suffix: Text expected between the final underscore and ``.xyz``.

    Returns:
        Full paths (``path`` joined with each file name), ordered by
        descending index.

    Raises:
        ValueError: If a matching file name has no integer index.
    """
    ending = f'_{suffix}.xyz'

    def negated_index(fname):
        # Negating the index yields a descending order from an ascending sort
        stem = fname.replace(ending, '')
        return -int(stem.split('_')[-1])

    matching = [fname for fname in os.listdir(path) if fname.endswith(ending)]
    matching.sort(key=negated_index)
    return [os.path.join(path, fname) for fname in matching]
def load_molecule_xyz(file, is_geom):
    """Read a single-molecule XYZ file into tensors.

    The file is expected to start with an atom-count line and one comment
    line, followed by one ``symbol x y z`` line per atom.

    Args:
        file: Path of the XYZ file (read as UTF-8).
        is_geom: Selects the ``const.GEOM_*`` lookup tables instead of the
            default ``const.ATOM2IDX`` / ``const.IDX2ATOM``.

    Returns:
        A ``(positions, one_hot, charges)`` tuple of tensors shaped
        ``(n_atoms, 3)``, ``(n_atoms, len(idx2atom))`` and ``(n_atoms, 1)``.
        ``charges`` is never filled in and is always all zeros.

    Raises:
        KeyError: If an atom symbol is missing from the lookup table.
        IndexError: If the file holds fewer atom lines than announced.
    """
    atom2idx = const.GEOM_ATOM2IDX if is_geom else const.ATOM2IDX
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    with open(file, encoding='utf8') as f:
        n_atoms = int(f.readline())
        f.readline()  # skip the comment line
        atom_lines = f.readlines()

    one_hot = torch.zeros(n_atoms, len(idx2atom))
    charges = torch.zeros(n_atoms, 1)
    positions = torch.zeros(n_atoms, 3)

    for i in range(n_atoms):
        # split() without an argument copes with repeated spaces and tabs,
        # which many XYZ writers emit; split(' ') would yield empty fields
        # there and break the float conversion.
        fields = atom_lines[i].split()
        one_hot[i, atom2idx[fields[0]]] = 1
        positions[i, :] = torch.tensor([float(e) for e in fields[1:]])

    return positions, one_hot, charges
def draw_sphere(ax, x, y, z, size, color, alpha):
    """Draw a sphere of radius ``size`` centred at ``(x, y, z)`` on ``ax``.

    The surface is sampled on a 100 x 100 grid of spherical angles and passed
    to ``ax.plot_surface`` with a row/column stride of 2.

    Args:
        ax: Object exposing ``plot_surface`` (a 3D matplotlib axes).
        x, y, z: Centre of the sphere.
        size: Sphere radius.
        color: Surface colour, forwarded to ``plot_surface``.
        alpha: Surface opacity, forwarded to ``plot_surface``.
    """
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)

    # Parametric sphere: one row per azimuth angle, one column per polar angle
    offset_x = size * np.outer(np.cos(azimuth), np.sin(polar))
    offset_y = size * np.outer(np.sin(azimuth), np.sin(polar))
    offset_z = size * np.outer(np.ones(np.size(azimuth)), np.cos(polar))

    ax.plot_surface(
        x + offset_x,
        y + offset_y,
        z + offset_z,
        rstride=2,
        cstride=2,
        color=color,
        alpha=alpha,
    )
def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None):
    """Draw the bonds and atoms of one molecule on a 3D axes.

    Bonds are inferred pairwise: for every atom pair the distance is passed to
    ``get_bond_order`` and a line is drawn when the returned order is
    positive. Atoms are drawn either as 3D spheres or as one scatter plot.

    Args:
        ax: 3D matplotlib axes to draw on.
        positions: Coordinates indexed as ``positions[:, 0..2]``; the
            elements must provide ``.item()`` when ``spheres_3d`` is set.
        atom_type: Integer atom-type index per atom, used to index
            ``const.COLORS`` / ``const.RADII`` and the index-to-symbol table.
        alpha: Opacity of the bonds and of atoms outside the fragment.
        spheres_3d: Draw atoms as spheres (``draw_sphere``) instead of a
            scatter plot.
        hex_bg_color: Colour used for the bond lines.
        is_geom: Selects ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``.
        fragment_mask: Optional per-atom mask; atoms whose entry equals 1 are
            drawn fully opaque in sphere mode. Defaults to all ones.
    """
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]
    n_atoms = len(x)

    # NOTE(review): the original comment listed "Hydrogen, Carbon, Nitrogen,
    # Oxygen, Fluorine" here -- presumably the order of const.COLORS and
    # const.RADII; confirm against src/const.py.
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    colors_dic = np.array(const.COLORS)
    radius_dic = np.array(const.RADII)
    area_dic = 1500 * radius_dic ** 2

    areas = area_dic[atom_type]
    radii = radius_dic[atom_type]
    colors = colors_dic[atom_type]

    if fragment_mask is None:
        fragment_mask = torch.ones(n_atoms)

    # Loop-invariant base width of a bond line
    line_width = (3 - 2) * 2 * 2

    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            p1 = np.array([x[i], y[i], z[i]])
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
            bond_order = get_bond_order(atom1, atom2, dist)

            if bond_order > 0:
                # Bonds of order 4 are drawn 1.5x thicker than the others
                linewidth_factor = 1.5 if bond_order == 4 else 1
                linewidth_factor *= 0.5
                ax.plot(
                    [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                    linewidth=line_width * linewidth_factor * 2,
                    c=hex_bg_color,
                    alpha=alpha,
                )

    if spheres_3d:
        for xi, yi, zi, radius, color, in_fragment in zip(x, y, z, radii, colors, fragment_mask):
            # Per-atom opacity: fragment atoms are opaque, all others keep the
            # caller's alpha. Previously ``alpha`` itself was overwritten, so
            # every atom after the first fragment atom also became opaque.
            sphere_alpha = 1.0 if in_fragment == 1 else alpha
            draw_sphere(ax, xi.item(), yi.item(), zi.item(), 0.5 * radius, color, sphere_alpha)
    else:
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)
def plot_data3d(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1., fragment_mask=None):
    """Render one molecule in a 3D matplotlib figure and save or show it.

    Args:
        positions: Atom coordinates; must support ``.abs().max().item()``
            and the indexing used by ``plot_molecule``.
        atom_type: Integer atom-type index per atom.
        is_geom: Forwarded to ``plot_molecule`` to select the atom table.
        camera_elev: Camera elevation passed to ``ax.view_init``.
        camera_azim: Camera azimuth passed to ``ax.view_init``.
        save_path: If given, the figure is saved there; otherwise it is
            shown with ``plt.show()``.
        spheres_3d: Draw atoms as spheres; also raises the DPI from 50 to 120
            and brightens the saved image by a factor of 1.4.
        bg: ``'black'`` for a black background with white bonds; any other
            value gives a white background with black bonds.
        alpha: Opacity forwarded to ``plot_molecule``.
        fragment_mask: Optional per-atom mask forwarded to ``plot_molecule``.
    """
    black = (0, 0, 0)
    white = (1, 1, 1)
    dark_bg = bg == 'black'
    # Bonds are drawn in the colour that contrasts with the background
    hex_bg_color = '#FFFFFF' if dark_bg else '#000000'

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('auto')
        ax.view_init(elev=camera_elev, azim=camera_azim)
        ax.set_facecolor(black if dark_bg else white)

        # Hide the panes and axes so that only the molecule is visible
        ax.xaxis.pane.set_alpha(0)
        ax.yaxis.pane.set_alpha(0)
        ax.zaxis.pane.set_alpha(0)
        ax._axis3don = False

        # ``ax.w_xaxis`` was removed in Matplotlib 3.8; ``ax.xaxis`` is the
        # attribute it aliased and is available on older versions as well.
        ax.xaxis.line.set_color("black" if dark_bg else "white")

        plot_molecule(
            ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom=is_geom, fragment_mask=fragment_mask
        )

        # Symmetric cubic limits derived from the largest absolute
        # coordinate, clamped to the range [3.2, 40]
        max_value = positions.abs().max().item()
        axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.set_zlim(-axis_lim, axis_lim)

        dpi = 120 if spheres_3d else 50

        if save_path is not None:
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)
            if spheres_3d:
                # Re-read the saved image and brighten it in place
                img = imageio.imread(save_path)
                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
                imageio.imsave(save_path, img_brighter)
        else:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving failed
        plt.close(fig)
def visualize_chain(
        path, spheres_3d=False, bg="black", alpha=1.0, wandb=None, mode="chain", is_geom=False, fragment_mask=None
):
    """Render every XYZ frame in ``path`` to PNG and combine them into a GIF.

    Frames are discovered with ``load_xyz_files(path)`` (default suffix).
    A PCA fitted on the last file of that list fixes one orientation, which
    is then applied to every frame. Each frame is saved next to its XYZ file
    with a ``.png`` extension, and the animation is written to ``output.gif``
    in the same directory.

    Args:
        path: Directory containing the XYZ frames.
        spheres_3d: Forwarded to ``plot_data3d``.
        bg: Background name forwarded to ``plot_data3d``.
        alpha: Opacity forwarded to ``plot_data3d``.
        wandb: Optional object providing ``log`` and ``Video``; when given,
            the GIF is logged under the key ``mode``.
        mode: Key used for the ``wandb.log`` entry.
        is_geom: Forwarded to the loader and the plotting code.
        fragment_mask: Optional per-atom mask forwarded to ``plot_data3d``;
            the same mask is used for every frame.

    Raises:
        IndexError: If ``path`` contains no matching XYZ files.
    """
    files = load_xyz_files(path)
    if not files:
        # Same exception type as the bare ``files[-1]`` lookup would raise,
        # but with a message that names the directory
        raise IndexError(f"no '*_.xyz' files found in {path!r}")

    # Fit PCA on the last listed frame to obtain a stable viewing orientation
    # NOTE(review): the original comment calls this "the final molecule";
    # load_xyz_files sorts by descending index, so it is the lowest-index file.
    positions, one_hot, charges = load_molecule_xyz(files[-1], is_geom=is_geom)
    pca = PCA(n_components=3)
    pca.fit(positions)

    save_paths = []
    for file in files:
        positions, one_hot, charges = load_molecule_xyz(file, is_geom=is_geom)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # Rotate every frame into the orientation fitted above
        positions = torch.tensor(pca.transform(positions))

        fn = os.path.splitext(file)[0] + '.png'
        plot_data3d(
            positions, atom_type,
            save_path=fn,
            spheres_3d=spheres_3d,
            alpha=alpha,
            bg=bg,
            camera_elev=90,
            camera_azim=90,
            is_geom=is_geom,
            fragment_mask=fragment_mask,
        )
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    # os.path.join keeps the GIF next to the frames; plain concatenation with
    # '/output.gif' pointed at the filesystem root when the directory was ''
    gif_path = os.path.join(os.path.dirname(save_paths[0]), 'output.gif')
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})