---
imports:
- $import glob
- $import matplotlib.pyplot as plt

dataset_dir: "../MedNIST/Hand"
# inference with 10 images, modify the indices to run it with different image inputs
datalist: "$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))[8500:8510]"
bundle_root: "./"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
output_dir: "$@bundle_root + '/eval'"
ckpt: "$@bundle_root + '/models/model.pt'"
cross_subjects: false  # whether the two images in each pair come from different subjects

# deterministic loading: read the image, make it channel-first, rescale intensities to [0, 1]
image_load:
- _target_: LoadImage
  image_only: True
  ensure_channel_first: True
- _target_: ScaleIntensityRange
  a_min: 0.0
  a_max: 255.0
  b_min: 0.0
  b_max: 1.0
- _target_: EnsureType
  device: "@device"

# random affine perturbation used to generate the second, misaligned image of each pair
image_aug:
- _target_: RandAffine
  spatial_size: [64, 64]
  translate_range: 5
  scale_range: [-0.15, 0.15]
  prob: 1.0
  rotate_range: $np.pi / 8
  mode: bilinear
  padding_mode: border
  cache_grid: True
  device: "@device"

preprocessing:
  _target_: Compose
  transforms: "$@image_load + @image_aug"

# two shuffled streams over the same file list: with matching seeds each image is paired
# with an augmented copy of itself; with different seeds (cross_subjects) the two images
# of a pair come from different subjects
datasets:
- _target_: ShuffleBuffer
  data:
    _target_: Dataset
    data: "@datalist"
    transform: {_target_: Compose, transforms: "@image_load"}
  seed: "$int(3) if @cross_subjects else int(2)"
- _target_: ShuffleBuffer
  data:
    _target_: Dataset
    data: "@datalist"
    transform: $@preprocessing.set_random_state(3)
  seed: 2

# pair the two streams: channel 0 is the un-augmented (moving) image, channel 1 the augmented (fixed) image
zip_dataset:
  _target_: IterableDataset
  data: "$map(lambda t: dict(image=monai.transforms.concatenate(t), m_img=t[0], label=t[1]), zip(*@datasets))"

data_loader:
  _target_: ThreadDataLoader
  dataset: "@zip_dataset"
  batch_size: 1
  num_workers: 0

# components for debugging: show the first moving/fixed pair and their absolute difference
first_pair: $monai.utils.misc.first(@data_loader)
display:
- $monai.utils.set_determinism(seed=23)
- $print(@first_pair.keys())
- $plt.subplot(1,3,1)
- $plt.imshow(@first_pair['image'][0, 0], cmap="gray")
- $plt.subplot(1,3,2)
- $plt.imshow(@first_pair['image'][0, 1], cmap="gray")
- $plt.subplot(1,3,3)
- $plt.imshow(np.abs(@first_pair['image'][0, 0] - @first_pair['image'][0, 1]), cmap="gray")
- $plt.show()

# network definition
network_def:
  _target_: scripts.net.RegResNet
  image_size: [64, 64]
  spatial_dims: 2
  mode: "bilinear"
  padding_mode: "border"

# create the primary evaluator
handlers:
- _target_: CheckpointLoader
  load_path: "@ckpt"
  load_dict: {model: "@network_def"}
- _target_: StatsHandler
  iteration_log: false

inferer: {_target_: SimpleInferer}

evaluator:
  _target_: SupervisedEvaluator
  device: "@device"
  val_data_loader: "@data_loader"
  network: "@network_def"
  epoch_length: "$len(@datalist) // @data_loader#batch_size"
  inferer: "@inferer"
  val_handlers: "@handlers"
  # save the moving, fixed and predicted (registered) images as PNGs under @output_dir
  postprocessing:
    _target_: Compose
    transforms:
    - _target_: SaveImaged
      keys: [m_img]
      resample: False
      output_dir: "@output_dir"
      output_ext: "png"
      output_postfix: "moving"
      output_dtype: "$np.uint8"
      scale: 255
      separate_folder: False
      writer: "PILWriter"
      output_name_formatter: "$lambda x, s: dict(idx=s._data_index, subject=x['filename_or_obj'])"
    - _target_: SaveImaged
      keys: [label]
      resample: False
      output_dir: "@output_dir"
      output_ext: "png"
      output_postfix: "fixed"
      output_dtype: "$np.uint8"
      scale: 255
      separate_folder: False
      writer: "PILWriter"
      output_name_formatter: "$lambda x, s: dict(idx=s._data_index, subject=x['filename_or_obj'])"
    - _target_: SaveImaged
      keys: [pred]
      resample: False
      output_dir: "@output_dir"
      output_ext: "png"
      output_postfix: "pred"
      output_dtype: "$np.uint8"
      scale: 255
      separate_folder: False
      writer: "PILWriter"
      output_name_formatter: "$lambda x, s: dict(idx=s._data_index, subject=x['filename_or_obj'])"

# run the inference workflow
eval:
- $monai.utils.set_determinism(seed=123)
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
- $@evaluator.run()