"""Dataset loading and SMILES/FASTA-to-graph conversion utilities for the AD dataset."""
import pandas as pd | |
import shutil, os | |
import os.path as osp | |
import numpy as np | |
from tqdm import tqdm | |
import torch | |
from torch_geometric.data import Data | |
from torch.autograd import Variable | |
from rdkit import Chem | |
from data.features import ( | |
allowable_features, | |
atom_to_feature_vector, | |
bond_to_feature_vector, | |
atom_feature_vector_to_dict, | |
bond_feature_vector_to_dict, | |
) | |
from utils.data_util import one_hot_vector_sm, one_hot_vector_am, get_atom_feature_dims | |
def load_dataset(
    cross_val, binary_task, target, args, use_prot=False, advs=False, test=False
):
    """
    Load data and return data in dataframes format for each split and the loader of each split.
    Args:
        cross_val (int): Data partition being used [1-4].
        binary_task (boolean): Whether to perform binary classification or multiclass classification.
        target (string): Name of the protein target for binary classification.
        args (parser): Complete arguments (configuration) of the model.
        use_prot (boolean): Whether to use the PM module.
        advs (boolean): Whether to train the LM module with adversarial augmentations.
        test (boolean): Whether the model is being tested or trained.
    Return:
        train (loader): Training loader
        valid (loader): Validation loader
        test (loader): Test loader
        data_train (dataframe): Training data dataframe
        data_valid (dataframe): Validation data dataframe
        data_test (dataframe): Test data dataframe
    """
    # TODO: don't we want the multiclass partition here? (translated from Spanish)
    # NOTE(review): when binary_task is False, no branch below assigns
    # data_train/data_val/data_test, so this function raises NameError —
    # the multiclass path appears unfinished; confirm before relying on it.
    # Read all data files
    if not test:
        # Verify cross validation partition is defined
        assert cross_val in [1, 2, 3, 4], "{} data partition is not defined".format(
            cross_val
        )
        print("Loading data...")
        if binary_task:
            path = "data/datasets/AD/"
            # Four pre-computed folds of (SMILES, target name, label) triples.
            A = pd.read_csv(
                path + "Smiles_AD_1.csv", names=["Smiles", "Target", "Label"]
            )
            B = pd.read_csv(
                path + "Smiles_AD_2.csv", names=["Smiles", "Target", "Label"]
            )
            C = pd.read_csv(
                path + "Smiles_AD_3.csv", names=["Smiles", "Target", "Label"]
            )
            D = pd.read_csv(
                path + "Smiles_AD_4.csv", names=["Smiles", "Target", "Label"]
            )
            data_test = pd.read_csv(
                path + "AD_Test.csv", names=["Smiles", "Target", "Label"]
            )
            if use_prot:
                # Protein sequences (FASTA) per target, for the PM module.
                data_target = pd.read_csv(
                    path + "Targets_Fasta.csv", names=["Fasta", "Target", "Label"]
                )
            else:
                data_target = []
        # Generate train and validation splits according to cross validation number
        if cross_val == 1:
            data_train = pd.concat([A, B, C], ignore_index=True)
            data_val = D
        elif cross_val == 2:
            data_train = pd.concat([A, C, D], ignore_index=True)
            data_val = B
        elif cross_val == 3:
            data_train = pd.concat([A, B, D], ignore_index=True)
            data_val = C
        elif cross_val == 4:
            data_train = pd.concat([B, C, D], ignore_index=True)
            data_val = A
        # If in binary classification select data for the specific target being train
        if binary_task:
            data_train = data_train[data_train.Target == target]
            data_val = data_val[data_val.Target == target]
            data_test = data_test[data_test.Target == target]
            if use_prot:
                data_target = data_target[data_target.Target == target]
        # Get dataset for each split.
        # NOTE(review): only the training split receives advs — adversarial
        # augmentation is presumably train-only; confirm this is intended.
        train = get_dataset(data_train, use_prot, data_target, args, advs)
        valid = get_dataset(data_val, use_prot, data_target, args)
        test = get_dataset(data_test, use_prot, data_target, args)
    else:
        # Read test data file
        if binary_task:
            path = "data/datasets/AD/"
            data_test = pd.read_csv(
                path + "Smiles_AD_Test.csv", names=["Smiles", "Target", "Label"]
            )
            data_test = data_test[data_test.Target == target]
            if use_prot:
                data_target = pd.read_csv(
                    path + "Targets_Fasta.csv", names=["Fasta", "Target", "Label"]
                )
                data_target = data_target[data_target.Target == target]
            else:
                data_target = []
        # Test-only mode: the `test` parameter is shadowed here by the loader.
        test = get_dataset(data_test,target=data_target, use_prot=use_prot, args=args, advs=advs, saliency=args.saliency)
        train = []
        valid = []
        data_train = []
        data_val = []
    print("Done.")
    return train, valid, test, data_train, data_val, data_test
def reload_dataset(cross_val, binary_task, target, args, advs=False):
    """
    Re-read the training CSVs and rebuild only the training set.

    Also resets args.edge_dict, which smiles_to_graph_advs repopulates with
    per-molecule trainable edge weights for adversarial augmentation.

    Args:
        cross_val (int): Data partition being used [1-4].
        binary_task (boolean): Whether binary classification data is used.
        target (string): Protein target name used to filter rows.
        args (parser): Complete arguments (configuration) of the model.
        advs (boolean): Whether to build adversarial (trainable-edge) graphs.
    Return:
        train (list): Rebuilt training examples.
        data_train (dataframe): Training data dataframe.
    """
    # NOTE(review): when binary_task is False, A..D are never read and the
    # concat below raises NameError — confirm multiclass reload is unsupported.
    print("Reloading data")
    args.edge_dict = {}
    if binary_task:
        path = "data/datasets/AD/"
        A = pd.read_csv(path + "Smiles_AD_1.csv", names=["Smiles", "Target", "Label"])
        B = pd.read_csv(path + "Smiles_AD_2.csv", names=["Smiles", "Target", "Label"])
        C = pd.read_csv(path + "Smiles_AD_3.csv", names=["Smiles", "Target", "Label"])
        D = pd.read_csv(path + "Smiles_AD_4.csv", names=["Smiles", "Target", "Label"])
        # data_test is read but never returned; kept to mirror load_dataset.
        data_test = pd.read_csv(
            path + "AD_Test.csv", names=["Smiles", "Target", "Label"]
        )
    # Same fold assignment as load_dataset (fold left out becomes validation).
    if cross_val == 1:
        data_train = pd.concat([A, B, C], ignore_index=True)
    elif cross_val == 2:
        data_train = pd.concat([A, C, D], ignore_index=True)
    elif cross_val == 3:
        data_train = pd.concat([A, B, D], ignore_index=True)
    else:
        data_train = pd.concat([B, C, D], ignore_index=True)
    if binary_task:
        data_train = data_train[data_train.Target == target]
    train = get_dataset(data_train, args=args, advs=advs)
    print("Done.")
    return train, data_train
def smiles_to_graph(smiles_string, is_prot=False, received_mol=False, saliency=False):
    """
    Convert a SMILES (or FASTA) string into the raw arrays of a molecular graph.

    Args:
        smiles_string (str or rdkit.Chem.Mol): SMILES string, FASTA string when
            is_prot is True, or an already-built RDKit Mol when received_mol is
            True.
        is_prot (bool): Parse the input as a FASTA protein sequence.
        received_mol (bool): Input is already an RDKit Mol; skip parsing.
            BUG FIX: this parameter was previously accepted but ignored, unlike
            in smiles_to_graph_advs; it is now honored.
        saliency (bool): Build one-hot atom features as torch tensors (so
            gradients can flow to them) instead of int64 index arrays.

    Returns:
        tuple: (edge_attr, edge_index, x) where edge_index is [2, num_edges]
        COO connectivity, edge_attr is [num_edges, 3] bond features, and x is
        the per-atom feature matrix (torch.Tensor if saliency else np.ndarray).
    """
    if received_mol:
        mol = smiles_string
    elif is_prot:
        mol = Chem.MolFromFASTA(smiles_string)
    else:
        mol = Chem.MolFromSmiles(smiles_string)
    # atoms
    atom_features_list = []
    atom_feat_dims = get_atom_feature_dims()
    for atom in mol.GetAtoms():
        ftrs = atom_to_feature_vector(atom)
        if saliency:
            # One-hot encode so gradients w.r.t. atom features are meaningful.
            ftrs_oh = one_hot_vector_am(ftrs, atom_feat_dims)
            atom_features_list.append(torch.unsqueeze(ftrs_oh, 0))
        else:
            atom_features_list.append(ftrs)
    if saliency:
        x = torch.cat(atom_features_list)
    else:
        x = np.array(atom_features_list, dtype=np.int64)
    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = bond_to_feature_vector(bond)
            # add edges in both directions (undirected graph as two arcs)
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)
        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T
        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype=np.int64)
    else:  # mol has no bonds
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)
    return edge_attr, edge_index, x
def smiles_to_graph_advs(
    smiles_string, args, advs=False, received_mol=False, saliency=False
):
    """
    Convert a SMILES string (or RDKit Mol) into graph arrays with one-hot edge
    features, optionally attaching trainable per-bond weights for adversarial
    augmentation.

    Args:
        smiles_string (str or rdkit.Chem.Mol): SMILES string, or an RDKit Mol
            when received_mol is True.
        args (parser): Model configuration; args.edge_dict maps each SMILES to
            its dict of trainable edge weights (populated only when advs=True).
        advs (bool): Attach a requires_grad scalar weight to each bond and
            record it in args.edge_dict.
        received_mol (bool): Input is already an RDKit Mol; skip parsing.
        saliency (bool): Build one-hot atom features as torch tensors instead
            of int64 index arrays.

    Returns:
        tuple: (edge_attr, edge_index, x). edge_attr is a torch.Tensor of
        one-hot bond features when the mol has bonds, else an empty np.ndarray.

    BUG FIX: the whole bond section used to be nested under `if advs:`, so
    calling with advs=False (reachable from transform_molecule_pg when
    args.advs is set) returned unbound edge_attr/edge_index -> NameError.
    Bonds are now always built; only the trainable weights are advs-gated.
    """
    if not received_mol:
        mol = Chem.MolFromSmiles(smiles_string)
    else:
        mol = smiles_string
    # atoms
    atom_features_list = []
    atom_feat_dims = get_atom_feature_dims()
    for atom in mol.GetAtoms():
        ftrs = atom_to_feature_vector(atom)
        if saliency:
            ftrs_oh = one_hot_vector_am(ftrs, atom_feat_dims)
            atom_features_list.append(torch.unsqueeze(ftrs_oh, 0))
        else:
            atom_features_list.append(ftrs)
    if saliency:
        x = torch.cat(atom_features_list)
    else:
        x = np.array(atom_features_list, dtype=np.int64)
    # bonds
    mol_edge_dict = {}
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    # One-hot bases for the three bond feature dimensions (5 + 6 + 2 = 13).
    features_dim1 = torch.eye(5)
    features_dim2 = torch.eye(6)
    features_dim3 = torch.eye(2)
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = bond_to_feature_vector(bond)
            # add edges in both directions
            edges_list.append((i, j))
            edges_list.append((j, i))
            edge_feature_oh = one_hot_vector_sm(
                edge_feature, features_dim1, features_dim2, features_dim3
            )
            if advs:
                # One trainable scalar per bond; multiplying the one-hot
                # feature lets gradients perturb this edge adversarially.
                mol_edge_dict[(i, j)] = Variable(
                    torch.tensor([1.0]), requires_grad=True
                )
                edge_features_list.append(
                    torch.unsqueeze(mol_edge_dict[(i, j)] * edge_feature_oh, 0)
                )
                edge_features_list.append(
                    torch.unsqueeze(mol_edge_dict[(i, j)] * edge_feature_oh, 0)
                )
            else:
                edge_features_list.append(torch.unsqueeze(edge_feature_oh, 0))
                edge_features_list.append(torch.unsqueeze(edge_feature_oh, 0))
        if advs:
            # Update edge dict so the training loop can step these weights.
            args.edge_dict[smiles_string] = mol_edge_dict
        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T
        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.cat(edge_features_list)
    else:  # mol has no bonds
        edge_index = np.empty((2, 0), dtype=np.int64)
        # NOTE(review): this empty edge_attr has width 3, not the one-hot
        # width 13 used above — preserved from the original; confirm callers.
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)
        if advs:
            args.edge_dict[smiles_string] = {}
    return edge_attr, edge_index, x
def get_dataset(
    dataset, use_prot=False, target=None, args=None, advs=False, saliency=False
):
    """
    Build the list of graph examples for one data split.

    Args:
        dataset (dataframe): Rows with "Smiles" and "Label" columns.
        use_prot (bool): Pair every molecule graph with the protein graph.
        target (dataframe): Single-target rows with a "Fasta" column
            (required when use_prot is True).
        args (parser): Model configuration forwarded to transform_molecule_pg.
        advs (bool): Build adversarial (trainable-edge) molecule graphs.
        saliency (bool): Build one-hot differentiable atom features.
    Return:
        list: Data objects, or [mol_graph, prot_graph] pairs when use_prot.
    """
    examples = []
    prot_graph = None
    if use_prot:
        # The protein graph is identical for every molecule of this target,
        # so it is built once and reused.
        prot_graph = transform_molecule_pg(
            target["Fasta"].item(), label=None, is_prot=use_prot
        )
    smiles_iter = zip(dataset["Smiles"], dataset["Label"])
    for smiles, lbl in tqdm(smiles_iter, total=len(dataset["Smiles"])):
        mol_graph = transform_molecule_pg(smiles, lbl, args, advs, saliency=saliency)
        if use_prot:
            examples.append([mol_graph, prot_graph])
        else:
            examples.append(mol_graph)
    return examples
def get_perturbed_dataset(mols, labels, args):
    """
    Build graph examples from already-perturbed RDKit Mol objects.

    Args:
        mols (iterable): RDKit Mol objects (received_mol path, no re-parsing).
        labels (iterable): Matching labels, consumed in lockstep with mols.
        args (parser): Model configuration forwarded to transform_molecule_pg.
    Return:
        list: One Data object per (mol, label) pair.
    """
    return [
        transform_molecule_pg(mol_obj, lbl, args, received_mol=True)
        for mol_obj, lbl in zip(mols, labels)
    ]
def transform_molecule_pg(
    smiles,
    label,
    args=None,
    advs=False,
    received_mol=False,
    saliency=False,
    is_prot=False,
):
    """
    Convert one molecule (or protein) into a torch_geometric Data object.

    Args:
        smiles (str or rdkit.Chem.Mol): SMILES string, FASTA string when
            is_prot, or an RDKit Mol when received_mol.
        label: Class label stored in Data.y (ignored for proteins).
        args (parser): Model configuration; args.advs selects the advs graph
            builder. Must not be None on the molecule path (args.advs is read).
        advs (bool): Build trainable adversarial edge weights.
        received_mol (bool): Input is already an RDKit Mol.
        saliency (bool): Atom features are already differentiable tensors.
        is_prot (bool): Input is a protein FASTA sequence.
    Return:
        Data: graph with edge_attr, edge_index, x (and y, mol, smiles for
        molecules).
    """
    if is_prot:
        # Protein path: plain integer-feature graph, no label/mol attached.
        edge_attr_p, edge_index_p, x_p = smiles_to_graph(smiles, is_prot)
        x_p = torch.tensor(x_p)
        edge_index_p = torch.tensor(edge_index_p)
        edge_attr_p = torch.tensor(edge_attr_p)
        return Data(edge_attr=edge_attr_p, edge_index=edge_index_p, x=x_p)
    else:
        if args.advs or received_mol:
            if advs or received_mol:
                edge_attr, edge_index, x = smiles_to_graph_advs(
                    smiles,
                    args,
                    advs=True,
                    received_mol=received_mol,
                    saliency=saliency,
                )
            else:
                # NOTE(review): this branch (args.advs set, advs flag off)
                # calls smiles_to_graph_advs with advs=False — verify it
                # returns bound edge arrays in your version of that function.
                edge_attr, edge_index, x = smiles_to_graph_advs(
                    smiles, args, received_mol=received_mol, saliency=saliency
                )
        else:
            edge_attr, edge_index, x = smiles_to_graph(smiles, saliency=saliency)
        if not saliency:
            # saliency=True already yields torch tensors for x.
            x = torch.tensor(x)
        y = torch.tensor([label])
        edge_index = torch.tensor(edge_index)
        if not args.advs and not received_mol:
            # Advs/received paths already produce tensor edge_attr with grads;
            # re-wrapping would detach them.
            edge_attr = torch.tensor(edge_attr)
        if received_mol:
            mol = smiles
        else:
            mol = Chem.MolFromSmiles(smiles)
        return Data(
            edge_attr=edge_attr, edge_index=edge_index, x=x, y=y, mol=mol, smiles=smiles
        )