from argparse import ArgumentParser, FileType


def parse_train_args(argv=None):
    """Build the training command-line parser, parse the options and validate them.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which is what the
            parameterless call has always done.

    Returns:
        The ``argparse.Namespace`` holding one attribute per option below.
        ``config`` is either ``None`` or a file object that argparse opened
        for reading (``FileType(mode='r')``).

    Raises:
        AssertionError: if ``--dynamic_max_cross`` is set while
            ``tr_sigma_max * 3 + 20`` is not smaller than
            ``cross_max_distance``, or if ``--esm_embeddings_model`` is given
            together with any of the ``*_esm_embeddings_path`` options.
        SystemExit: raised by argparse itself on malformed arguments.
    """
    # General arguments
    parser = ArgumentParser()
    parser.add_argument('--config', type=FileType(mode='r'), default=None)
    parser.add_argument('--log_dir', type=str, default='workdir/test_score',
                        help='Folder in which to save model and logs')
    parser.add_argument('--restart_dir', type=str,
                        help='Folder of previous training model from which to restart')
    parser.add_argument('--restart_ckpt', type=str, default='last_model', help='')
    parser.add_argument('--pretrain_dir', type=str,
                        help='Folder of pretrained model from which to restart')
    parser.add_argument('--pretrain_ckpt', type=str, help='')
    parser.add_argument('--freeze_params', type=int, default=0, help='')
    parser.add_argument('--cache_path', type=str, default='data/cache',
                        help='Folder from where to load/restore cached dataset')
    parser.add_argument('--moad_dir', type=str, default='data/BindingMOAD_2020_processed/',
                        help='Folder containing original structures')
    parser.add_argument('--pdbbind_dir', type=str, default='data/PDBBind_processed/',
                        help='Folder containing original structures')
    # NOTE(review): this help text is identical to the *_dir options above and
    # looks copy-pasted; left untouched because the intended wording is not
    # visible here.
    parser.add_argument('--dataset', type=str, default='pdbbind',
                        help='Folder containing original structures')
    parser.add_argument('--split_train', type=str,
                        default='data/splits/timesplit_no_lig_overlap_train',
                        help='Path of file defining the split')
    parser.add_argument('--split_val', type=str,
                        default='data/splits/timesplit_no_lig_overlap_val',
                        help='Path of file defining the split')
    parser.add_argument('--split_test', type=str, default='data/splits/timesplit_test',
                        help='Path of file defining the split')
    parser.add_argument('--test_sigma_intervals', action='store_true', default=False,
                        help='Whether to log loss per noise interval')
    parser.add_argument('--val_inference_freq', type=int, default=5,
                        help='Frequency of epochs for which to run expensive inference on val data')
    parser.add_argument('--save_model_freq', type=int, default=None, help='')
    parser.add_argument('--inference_samples', type=int, default=1, help='')
    parser.add_argument('--train_inference_freq', type=int, default=None,
                        help='Frequency of epochs for which to run expensive inference on train data')
    parser.add_argument('--inference_steps', type=int, default=20,
                        help='Number of denoising steps for inference on val')
    parser.add_argument('--num_inference_complexes', type=int, default=100,
                        help='Number of complexes for which inference is run every val/train_inference_freq epochs (None will run it on all)')
    parser.add_argument('--inference_earlystop_metric', type=str, default='valinf_min_rmsds_lt2',
                        help='This is the metric that is addionally used when val_inference_freq is not None')
    parser.add_argument('--inference_secondary_metric', type=str, default=None, help='')
    parser.add_argument('--inference_earlystop_goal', type=str, default='max',
                        help='Whether to maximize or minimize metric')
    parser.add_argument('--wandb', action='store_true', default=False, help='')
    parser.add_argument('--project', type=str, default='diffdock', help='')
    parser.add_argument('--run_name', type=str, default='', help='')
    parser.add_argument('--cudnn_benchmark', action='store_true', default=False,
                        help='CUDA optimization parameter for faster training')
    parser.add_argument('--num_dataloader_workers', type=int, default=0,
                        help='Number of workers for dataloader')
    parser.add_argument('--pin_memory', action='store_true', default=False,
                        help='pin_memory arg of dataloader')
    parser.add_argument('--dataloader_drop_last', action='store_true', default=False,
                        help='drop_last arg of dataloader')
    parser.add_argument('--double_val', action='store_true', default=False, help='')
    parser.add_argument('--combined_training', action='store_true', default=False, help='')

    # Training arguments
    parser.add_argument('--n_epochs', type=int, default=400, help='Number of epochs for training')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--scheduler', type=str, default=None, help='LR scheduler')
    parser.add_argument('--scheduler_patience', type=int, default=20,
                        help='Patience of the LR scheduler')
    parser.add_argument('--lr_start_factor', type=float, default=0.001, help='')
    parser.add_argument('--warmup_dur', type=int, default=4, help='')
    parser.add_argument('--lr', type=float, default=1e-3, help='Initial learning rate')
    parser.add_argument('--restart_lr', type=float, default=None,
                        help='If this is not none, the lr of the optimizer will be overwritten with this value when restarting from a checkpoint.')
    parser.add_argument('--w_decay', type=float, default=0.0, help='Weight decay added to loss')
    parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of workers for preprocessing')
    parser.add_argument('--use_ema', action='store_true', default=False,
                        help='Whether or not to use ema for the model weights')
    parser.add_argument('--ema_rate', type=float, default=0.999,
                        help='decay rate for the exponential moving average model parameters ')

    # Dataset
    parser.add_argument('--limit_complexes', type=int, default=5,
                        help='If positive, the number of training and validation complexes is capped')  # TODO change
    parser.add_argument('--all_atoms', action='store_true', default=False,
                        help='Whether to use the all atoms model')
    parser.add_argument('--chain_cutoff', type=float, default=None,
                        help='Cutoff on whether to include non-interacting chains')
    parser.add_argument('--receptor_radius', type=float, default=30,
                        help='Cutoff on distances for receptor edges')
    parser.add_argument('--c_alpha_max_neighbors', type=int, default=10,
                        help='Maximum number of neighbors for each residue')
    parser.add_argument('--atom_radius', type=float, default=5,
                        help='Cutoff on distances for atom connections')
    parser.add_argument('--atom_max_neighbors', type=int, default=8,
                        help='Maximum number of atom neighbours for receptor')
    parser.add_argument('--matching_popsize', type=int, default=20,
                        help='Differential evolution popsize parameter in matching')
    parser.add_argument('--matching_maxiter', type=int, default=20,
                        help='Differential evolution maxiter parameter in matching')
    parser.add_argument('--matching_tries', type=int, default=1, help='')
    parser.add_argument('--max_lig_size', type=int, default=None,
                        help='Maximum number of heavy atoms in ligand')
    # NOTE(review): store_true combined with default=True means this option is
    # True whether or not the flag is passed; it cannot be disabled from the CLI.
    parser.add_argument('--remove_hs', action='store_true', default=True, help='remove Hs')
    parser.add_argument('--num_conformers', type=int, default=1,
                        help='Number of conformers to match to each ligand')
    parser.add_argument('--moad_esm_embeddings_path', type=str, default=None,
                        help='If this is set then the LM embeddings at that path will be used for the receptor features')
    parser.add_argument('--pdbbind_esm_embeddings_path', type=str, default=None,
                        help='If this is set then the LM embeddings at that path will be used for the receptor features')
    parser.add_argument('--moad_esm_embeddings_sequences_path', type=str, default=None, help='')
    parser.add_argument('--esm_embeddings_model', type=str, default=None, help='')
    parser.add_argument('--not_fixed_knn_radius_graph', action='store_true', default=False,
                        help='Use knn graph and radius graph with closest neighbors instead of random ones as with radius_graph')
    parser.add_argument('--not_knn_only_graph', action='store_true', default=False,
                        help='Use knn graph only and not restrict to a specific radius')
    parser.add_argument('--include_miscellaneous_atoms', action='store_true', default=False,
                        help='include non amino acid atoms for the receptor')
    parser.add_argument('--train_multiplicity', type=int, default=1, help='')
    parser.add_argument('--val_multiplicity', type=int, default=1, help='')
    parser.add_argument('--max_receptor_size', type=int, default=None, help='')
    parser.add_argument('--remove_promiscuous_targets', type=int, default=None, help='')
    parser.add_argument('--min_ligand_size', type=int, default=0, help='')
    parser.add_argument('--unroll_clusters', action='store_true', default=False, help='')
    parser.add_argument('--enforce_timesplit', action='store_true', default=False, help='')
    parser.add_argument('--merge_clusters', type=int, default=1, help='')
    parser.add_argument('--triple_training', action='store_true', default=False, help='')
    parser.add_argument('--crop_beyond', type=float, default=20, help='')

    # Diffusion
    parser.add_argument('--tr_weight', type=float, default=0.33, help='Weight of translation loss')
    parser.add_argument('--rot_weight', type=float, default=0.33, help='Weight of rotation loss')
    parser.add_argument('--tor_weight', type=float, default=0.33, help='Weight of torsional loss')
    parser.add_argument('--confidence_weight', type=float, default=0.33,
                        help='Weight of confidence loss')
    parser.add_argument('--rot_sigma_min', type=float, default=0.1,
                        help='Minimum sigma for rotational component')
    parser.add_argument('--rot_sigma_max', type=float, default=1.65,
                        help='Maximum sigma for rotational component')
    parser.add_argument('--tr_sigma_min', type=float, default=0.1,
                        help='Minimum sigma for translational component')
    parser.add_argument('--tr_sigma_max', type=float, default=30,
                        help='Maximum sigma for translational component')
    parser.add_argument('--tor_sigma_min', type=float, default=0.0314,
                        help='Minimum sigma for torsional component')
    parser.add_argument('--tor_sigma_max', type=float, default=3.14,
                        help='Maximum sigma for torsional component')
    parser.add_argument('--no_torsion', action='store_true', default=False,
                        help='If set only rigid matching')
    parser.add_argument('--sampling_alpha', type=float, default=1,
                        help='Alpha parameter of beta distribution for sampling t')
    parser.add_argument('--sampling_beta', type=float, default=1,
                        help='Beta parameter of beta distribution for sampling t')
    parser.add_argument('--bootstrap_alpha', type=float, default=1,
                        help='Alpha parameter of beta distribution for sampling t in bootstrapping')
    parser.add_argument('--bootstrap_beta', type=float, default=1,
                        help='Beta parameter of beta distribution for sampling t in bootstrapping')
    parser.add_argument('--bootstrap_tmin', type=float, default=0, help='')

    # Model
    parser.add_argument('--num_conv_layers', type=int, default=2,
                        help='Number of interaction layers')
    parser.add_argument('--max_radius', type=float, default=5.0,
                        help='Radius cutoff for geometric graph')
    # NOTE(review): store_true with default=True -- always True from the CLI.
    parser.add_argument('--scale_by_sigma', action='store_true', default=True,
                        help='Whether to normalise the score')
    parser.add_argument('--norm_by_sigma', action='store_true', default=False,
                        help='Whether to normalise the score')
    parser.add_argument('--ns', type=int, default=16,
                        help='Number of hidden features per node of order 0')
    parser.add_argument('--nv', type=int, default=4,
                        help='Number of hidden features per node of order >0')
    parser.add_argument('--distance_embed_dim', type=int, default=32,
                        help='Embedding size for the distance')
    parser.add_argument('--cross_distance_embed_dim', type=int, default=32,
                        help='Embeddings size for the cross distance')
    parser.add_argument('--no_batch_norm', action='store_true', default=False,
                        help='If set, it removes the batch norm')
    parser.add_argument('--use_second_order_repr', action='store_true', default=False,
                        help='Whether to use only up to first order representations or also second')
    parser.add_argument('--cross_max_distance', type=float, default=80,
                        help='Maximum cross distance in case not dynamic')
    parser.add_argument('--dynamic_max_cross', action='store_true', default=False,
                        help='Whether to use the dynamic distance cutoff')
    parser.add_argument('--dropout', type=float, default=0.0, help='MLP dropout')
    parser.add_argument('--smooth_edges', action='store_true', default=False,
                        help='Whether to apply additional smoothing weight to edges')
    parser.add_argument('--odd_parity', action='store_true', default=False,
                        help='Whether to impose odd parity in output')
    parser.add_argument('--embedding_type', type=str, default="sinusoidal",
                        help='Type of diffusion time embedding')
    parser.add_argument('--sigma_embed_dim', type=int, default=32,
                        help='Size of the embedding of the diffusion time')
    parser.add_argument('--embedding_scale', type=int, default=1000,
                        help='Parameter of the diffusion time embedding')
    parser.add_argument('--use_old_atom_encoder', action='store_true', default=False,
                        help='option to use old atom encoder for backward compatibility')
    parser.add_argument('--depthwise_convolution', action='store_true', default=False, help='')
    parser.add_argument('--protein_file', type=str, default='protein_processed', help='')
    parser.add_argument('--no_aminoacid_identities', action='store_true', default=False, help='')
    # NOTE(review): help text is the same string as --sigma_embed_dim above and
    # looks copy-pasted; kept byte-identical.
    parser.add_argument('--sh_lmax', type=int, default=2,
                        help='Size of the embedding of the diffusion time')
    parser.add_argument('--no_differentiate_convolutions', action='store_true', default=False,
                        help='')
    parser.add_argument('--tp_weights_layers', type=int, default=2, help='')
    parser.add_argument('--num_prot_emb_layers', type=int, default=0, help='')
    parser.add_argument('--reduce_pseudoscalars', action='store_true', default=False, help='')
    # NOTE(review): store_true with default=True -- always True from the CLI.
    parser.add_argument('--embed_also_ligand', action='store_true', default=True, help='')
    parser.add_argument('--sidechain_loss_weight', type=float, default=0, help='')
    parser.add_argument('--backbone_loss_weight', type=float, default=0, help='')

    # pdb sidechain training
    parser.add_argument('--pdbsidechain_dir', type=str, default='data/pdb_2021aug02_sample',
                        help='')
    parser.add_argument('--pdbsidechain_esm_embeddings_path', type=str, default=None, help='')
    parser.add_argument('--pdbsidechain_esm_embeddings_sequences_path', type=str, default=None,
                        help='')
    parser.add_argument('--vandermers_max_dist', type=int, default=5, help='')
    parser.add_argument('--vandermers_buffer_residue_num', type=int, default=7, help='')
    parser.add_argument('--vandermers_min_contacts', type=int, default=None, help='')
    parser.add_argument('--remove_second_segment', action='store_true', default=False, help='')

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # The checks below raise AssertionError explicitly instead of using the
    # `assert` statement, so they keep the same exception type but are not
    # stripped when Python runs with -O.
    if args.dynamic_max_cross and not (args.tr_sigma_max * 3 + 20 < args.cross_max_distance):
        raise AssertionError(
            '--dynamic_max_cross requires tr_sigma_max * 3 + 20 < cross_max_distance'
        )

    # No `--esm_embeddings_path` option is defined above, so reading
    # `args.esm_embeddings_path` raised AttributeError whenever
    # --esm_embeddings_model was set. Check the embedding-path options that
    # actually exist instead.
    # NOTE(review): this assumes the model option is meant to be exclusive with
    # every per-dataset embeddings path -- confirm against the dataset code.
    if args.esm_embeddings_model is not None:
        embedding_paths = (
            args.moad_esm_embeddings_path,
            args.pdbbind_esm_embeddings_path,
            args.pdbsidechain_esm_embeddings_path,
        )
        if any(path is not None for path in embedding_paths):
            raise AssertionError(
                '--esm_embeddings_model cannot be combined with an *_esm_embeddings_path option'
            )

    return args