---
# Training workflow for a 2D pairwise image-registration network on the
# MedNIST hand images, written in MONAI bundle config syntax:
#   "$..."  a Python expression evaluated by the config parser
#   "@key"  a reference to another entry of this file
#   "a#b"   entry `b` nested inside entry `a`
imports:
  - $import glob
  - $import matplotlib.pyplot as plt

# workflow parameters
bundle_root: "./"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
ckpt_dir: "$@bundle_root + '/models'"  # folder to save new checkpoints
ckpt: ""  # path to load an existing checkpoint
val_interval: 1  # every epoch
max_epochs: 300
# Pairing mode. false: both shuffle buffers of a pair are seeded with 2;
# true: the first buffer is seeded with 3 instead, which presumably pairs
# images from different subjects -- TODO confirm against ShuffleBuffer.
cross_subjects: false

# construct the moving and fixed datasets
dataset_dir: "../MedNIST/Hand"
# training with the first 7000 images (sorted by file name)
datalist: "$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))[:7000]"
# validation with the following images, indices 7000-8499 (up to 1500)
val_datalist: "$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))[7000:8500]"

# deterministic part of the pipeline: load, rescale 0-255 -> 0-1, move to device
image_load:
  - _target_: LoadImage
    image_only: true
    ensure_channel_first: true
  - _target_: ScaleIntensityRange
    a_min: 0.0
    a_max: 255.0
    b_min: 0.0
    b_max: 1.0
  - _target_: EnsureType
    device: "@device"

# random part of the pipeline: always an affine, sometimes a grid distortion
image_aug:
  - _target_: RandAffine
    spatial_size: [64, 64]
    translate_range: 5
    scale_range: [-0.15, 0.15]
    prob: 1.0
    rotate_range: "$np.pi / 8"
    mode: bilinear
    padding_mode: border
    cache_grid: true
    device: "@device"
  - _target_: RandGridDistortion
    prob: 0.2
    num_cells: 8
    device: "@device"
    distort_limit: 0.1

preprocessing:
  _target_: Compose
  transforms: "$@image_load + @image_aug"

# Two shuffled views of the same training list; they are zipped into pairs
# by `zip_dataset` below.
# NOTE(review): both entries call set_random_state on the same
# `@preprocessing` reference -- confirm whether that is one shared Compose
# instance, in which case the second call would override the first seed.
cache_datasets:
  - _target_: ShuffleBuffer
    data:
      _target_: CacheDataset
      data: "@datalist"
      transform: "$@preprocessing.set_random_state(123)"
      hash_as_key: true
      runtime_cache: threads
    epochs: "@max_epochs"
    seed: "$int(3) if @cross_subjects else int(2)"
  - _target_: ShuffleBuffer
    data:
      _target_: CacheDataset
      data: "@datalist"
      transform: "$@preprocessing.set_random_state(234)"
      hash_as_key: true
      runtime_cache: threads
    epochs: "@max_epochs"
    seed: 2

# each item: `image` = both views concatenated, `label` = the second view
zip_dataset:
  _target_: IterableDataset
  data: "$map(lambda t: dict(image=monai.transforms.concatenate(t), label=t[1]), zip(*@cache_datasets))"

data_loader:
  _target_: ThreadDataLoader
  dataset: "@zip_dataset"
  batch_size: 64
  num_workers: 0

# components for debugging
first_pair: "$monai.utils.misc.first(@data_loader)"
display:
  - $monai.utils.set_determinism(seed=123)
  - $print(@first_pair.keys(), @first_pair['image'].meta['filename_or_obj'])
  # print loss between channel 0 and channel 1 of the first batch
  - "$print(@trainer#loss_function(@first_pair['image'][:, 0:1], @first_pair['image'][:, 1:2]))"
  # show the two channels of the first sample side by side
  - $plt.subplot(1,2,1)
  - $plt.imshow(@first_pair['image'][0, 0], cmap="gray")
  - $plt.subplot(1,2,2)
  - $plt.imshow(@first_pair['image'][0, 1], cmap="gray")
  - $plt.show()

# network definition
net:
  _target_: scripts.net.RegResNet
  image_size: [64, 64]
  spatial_dims: 2
  mode: "bilinear"
  padding_mode: "border"

optimizer:
  _target_: torch.optim.Adam
  params: "$@net.parameters()"
  lr: 0.00001

# create a validation evaluator
val:
  # same pairing scheme as the training `cache_datasets`, on `val_datalist`
  cache_datasets:
    - _target_: ShuffleBuffer
      data:
        _target_: CacheDataset
        data: "@val_datalist"
        transform: "$@preprocessing.set_random_state(123)"
        hash_as_key: true
        runtime_cache: threads
      epochs: -1  # infinite
      seed: "$int(3) if @cross_subjects else int(2)"
    - _target_: ShuffleBuffer
      data:
        _target_: CacheDataset
        data: "@val_datalist"
        transform: "$@preprocessing.set_random_state(234)"
        hash_as_key: true
        runtime_cache: threads
      epochs: -1  # infinite
      seed: 2

  zip_dataset:
    _target_: IterableDataset
    data: "$map(lambda t: dict(image=monai.transforms.concatenate(t), label=t[1]), zip(*@val#cache_datasets))"

  data_loader:
    _target_: ThreadDataLoader
    dataset: "@val#zip_dataset"
    batch_size: 64
    num_workers: 0

  evaluator:
    _target_: SupervisedEvaluator
    device: "@device"
    val_data_loader: "@val#data_loader"
    network: "@net"
    # the buffers above never end, so the epoch length is set explicitly
    epoch_length: "$len(@val_datalist) // @val#data_loader#batch_size"
    inferer: "$monai.inferers.SimpleInferer()"
    # presumably called as (current, best), so a lower val_mse counts as an
    # improvement -- confirm against the evaluator documentation
    metric_cmp_fn: "$lambda x, y: x < y"
    key_val_metric:
      val_mse:
        _target_: MeanSquaredError
        output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
    additional_metrics:
      "mutual info loss": "@loss_metric#metric_handler"
    val_handlers:
      - _target_: StatsHandler
        iteration_log: false
      - _target_: CheckpointSaver
        save_dir: "@ckpt_dir"
        save_dict: {model: "@net"}
        save_key_metric: true
        key_metric_negative_sign: true
        # key_metric_filename: "model.pt"

# training handlers
handlers:
  - _target_: StatsHandler
    tag_name: "train_loss"
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
  - _target_: ValidationHandler
    validator: "@val#evaluator"
    epoch_level: true
    interval: "@val_interval"

# wraps `mutual_info_loss` as a metric; used by `val#evaluator`
loss_metric:
  metric_handler:
    _target_: IgniteMetric
    output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
    metric_fn:
      _target_: LossMetric
      loss_fn: "@mutual_info_loss"
      get_not_nans: true

# appended to the training handlers only when `ckpt` is non-empty
ckpt_loader:
  - _target_: CheckpointLoader
    load_path: "@ckpt"
    load_dict: {model: "@net"}

# training loss (see `trainer#loss_function`)
lncc_loss:
  _target_: LocalNormalizedCrossCorrelationLoss
  spatial_dims: 2
  kernel_size: 5
  kernel_type: rectangular
  reduction: mean

# validation-only metric (see `loss_metric`)
mutual_info_loss:
  _target_: GlobalMutualInformationLoss

# create the primary trainer
trainer:
  _target_: SupervisedTrainer
  device: "@device"
  train_data_loader: "@data_loader"
  network: "@net"
  max_epochs: "@max_epochs"
  epoch_length: "$len(@datalist) // @data_loader#batch_size"
  loss_function: "@lncc_loss"
  optimizer: "@optimizer"
  train_handlers: "$@handlers + @ckpt_loader if @ckpt else @handlers"

# entry point: seed, enable cudnn benchmark mode, run the trainer
training:
  - $monai.utils.set_determinism(seed=23)
  - "$setattr(torch.backends.cudnn, 'benchmark', True)"
  - $@trainer.run()